1. Before you begin
TensorFlow.js model usage has grown rapidly over the past few years, and many JavaScript developers now want to take existing state-of-the-art models and retrain them to work with custom data that is unique to their industry. Taking an existing model (often referred to as a base model) and using it on a similar but different domain is known as transfer learning.
Transfer learning has many advantages over starting from a completely blank model. You can reuse knowledge already learned by a previously trained model, and you need fewer examples of the new item you wish to classify. Training is also often significantly faster, because you only have to retrain the final few layers of the model architecture instead of the whole network. This makes transfer learning well suited to the web browser, where available resources vary with the device the code runs on, but where you also have direct access to sensors for easy data acquisition.
This codelab shows you how to build a web app from a blank canvas, recreating Google's popular "Teachable Machine" website. The web app you build lets any user recognize a custom object with just a few example images from their webcam. The website is purposely kept minimal so that you can focus on the machine learning aspects of this codelab. As with the original Teachable Machine website, however, there is plenty of scope to apply your existing web developer experience to improve the UX.
Prerequisites
This codelab is written for web developers who are somewhat familiar with TensorFlow.js pre-made models and basic API usage, and who want to get started with transfer learning in TensorFlow.js.
- Basic familiarity with TensorFlow.js, HTML5, CSS, and JavaScript is assumed for this lab.
If you are new to TensorFlow.js, consider taking this free zero to hero course first. It assumes no background in machine learning or TensorFlow.js and teaches you everything you need to know in smaller steps.
What you'll learn
- What TensorFlow.js is and why you should use it in your next web app.
- How to build a simplified HTML/CSS/JS webpage that replicates the Teachable Machine user experience.
- How to use TensorFlow.js to load a pre-trained base model, specifically MobileNet, to generate image features that can be used in transfer learning.
- How to gather data from a user's webcam for multiple classes of data that you want to recognize.
- How to create and define a multi-layer perceptron that takes the image features and learns to classify new objects using them.
Let's get hacking...
What you'll need
- A Glitch.com account is preferred to follow along, or you can use a web serving environment you are comfortable editing and running yourself.
2. What is TensorFlow.js?
TensorFlow.js is an open source machine learning library that can run anywhere JavaScript can. It's based upon the original TensorFlow library written in Python and aims to re-create this developer experience and set of APIs for the JavaScript ecosystem.
Where can it be used?
Given the portability of JavaScript, you can now write in one language and perform machine learning across all of the following platforms with ease:
- Client side in the web browser using vanilla JavaScript
- Server side and even IoT devices like Raspberry Pi using Node.js
- Desktop apps using Electron
- Native mobile apps using React Native
TensorFlow.js also supports multiple backends within each of these environments, to ensure compatibility and to keep things running fast. A "backend" in this context is the hardware-based environment the computations execute in, such as the CPU or WebGL. It does not mean a server-side environment: the backend can run entirely on the client, as WebGL does. Currently TensorFlow.js supports:
- WebGL execution on the device's graphics card (GPU) - this is the fastest way to execute larger models (over 3MB in size) with GPU acceleration.
- WebAssembly (WASM) execution on CPU - to improve CPU performance across devices, including older-generation mobile phones. This is better suited to smaller models (less than 3MB in size), which can actually execute faster on CPU with WASM than with WebGL due to the overhead of uploading content to a graphics processor.
- CPU execution - the fallback should none of the other environments be available. This is the slowest of the three but is always there for you.
Note: You can choose to force one of these backends if you know what device you will be executing on, or you can simply let TensorFlow.js decide for you if you do not specify this.
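The rule of thumb above can be sketched as a small decision function: WebGL for models over about 3MB, WASM for smaller ones, and CPU as the fallback. The chooseBackend() helper is hypothetical and not part of the TensorFlow.js API. It also takes the availability flags as plain inputs rather than detecting them. In a real app you would pass the chosen name to tf.setBackend(), which returns a Promise. Note that the 'wasm' backend is only available once the separate @tensorflow/tfjs-backend-wasm package is loaded.

```javascript
// Hypothetical helper (not a TensorFlow.js API): picks a backend name using
// the rough 3MB rule of thumb. Availability is passed in, not detected.
const LARGE_MODEL_THRESHOLD_MB = 3;

function chooseBackend({ modelSizeMB, webglAvailable, wasmAvailable }) {
  const isLarge = modelSizeMB > LARGE_MODEL_THRESHOLD_MB;
  // Larger models: GPU acceleration usually wins, so prefer WebGL.
  if (isLarge && webglAvailable) return 'webgl';
  // Smaller models: WASM avoids the cost of uploading data to the GPU.
  if (!isLarge && wasmAvailable) return 'wasm';
  // Otherwise take whichever accelerated backend exists, then fall back to CPU.
  if (webglAvailable) return 'webgl';
  if (wasmAvailable) return 'wasm';
  return 'cpu';
}

console.log(chooseBackend({ modelSizeMB: 16, webglAvailable: true, wasmAvailable: true }));   // → webgl
console.log(chooseBackend({ modelSizeMB: 1.5, webglAvailable: true, wasmAvailable: true }));  // → wasm
console.log(chooseBackend({ modelSizeMB: 1.5, webglAvailable: false, wasmAvailable: false })); // → cpu
```

The 3MB threshold is only the approximate figure quoted above. Benchmarking on your target devices is the reliable way to choose.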
Client side super powers
Running TensorFlow.js in the web browser on the client machine can lead to several benefits that are worth considering.
Privacy
You can both train and classify data on the client machine without ever sending data to a third-party web server. This may be required to comply with local laws such as GDPR, or when processing any data that the user wants to keep on their machine rather than send to a third party.
Speed
Because you don't have to send data to a remote server, inference (the act of classifying the data) can be faster. Even better, you have direct access to the device's sensors, such as the camera, microphone, GPS, accelerometer, and more, should the user grant you access.
Reach and scale
Anyone in the world can click a link you send them, open the web page in their browser, and use what you have made. There is no need for a complex server-side Linux setup with CUDA drivers and much more just to use the machine learning system.
Cost
No servers means the only thing you need to pay for is a CDN to host your HTML, CSS, JS, and model files. The cost of a CDN is much cheaper than keeping a server (potentially with a graphics card attached) running 24/7.
Server side features
Leveraging the Node.js implementation of TensorFlow.js enables the following features.
Full CUDA support
On the server side, for graphics card acceleration, you must install the NVIDIA CUDA drivers to enable TensorFlow to work with the graphics card (unlike in the browser, which uses WebGL and needs no install). However, with full CUDA support you can fully leverage the graphics card's lower-level abilities, leading to faster training and inference times. Performance is on par with the Python TensorFlow implementation, as they both share the same C++ backend.
Model Size
For cutting-edge models from research, you may be working with very large models, maybe gigabytes in size. These models cannot currently be run in the web browser due to the memory limits per browser tab. To run these larger models, you can use Node.js on your own server with the hardware specifications you require to run such a model efficiently.
IoT
Node.js is supported on popular single board computers like the Raspberry Pi, which in turn means you can execute TensorFlow.js models on such devices too.
Speed
Node.js runs JavaScript on the V8 engine, which uses just-in-time (JIT) compilation. This means that you may often see performance gains when using Node.js, as code is optimized at runtime, especially for any preprocessing you may be doing. A great example of this can be seen in this case study, which shows how Hugging Face used Node.js to get a 2x performance boost for their natural language processing model.
Now that you understand the basics of TensorFlow.js, where it can run, and some of its benefits, let's start doing useful things with it!
3. Transfer learning
What exactly is transfer learning?
Transfer learning involves taking knowledge that has already been learned to help learn a different but similar thing.
We humans do this all the time. You have a lifetime of experiences contained in your brain that you can use to help recognize new things you have never seen before. Take this willow tree for example:
Depending on where you are in the world there is a chance you may not have seen this type of tree before.
Yet if I ask you to tell me if there are any willow trees in the new image below, you can probably spot them pretty fast, even though they are at a different angle, and slightly different to the original one I showed you.
You already have a bunch of neurons in your brain that know how to identify tree-like objects, and other neurons that are good at finding long straight lines. You can reuse that knowledge to quickly classify a willow tree, which is a tree-like object that has lots of long straight vertical branches.
Similarly, if you have a machine learning model that is already trained on a domain, such as recognizing images, you can re-use it to perform a different but related task.
You can do the same with an advanced model like MobileNet, which is a very popular research model that can perform image recognition on 1000 different object types. From dogs to cars, it was trained on a huge dataset known as ImageNet that has millions of labelled images.
In this animation, you can see the huge number of layers it has in this MobileNet V1 model:
During its training, this model learned how to extract common features that matter for all of those 1000 objects, and many of the lower level features it uses to identify such objects can be useful to detect new objects it has never seen before too. After all, everything is ultimately just a combination of lines, textures, and shapes.
Let's take a look at a traditional Convolutional Neural Network (CNN) architecture (similar to MobileNet) and see how transfer learning can leverage this trained network to learn something new. The image below shows the typical model architecture of a CNN that in this case was trained to recognize handwritten digits from 0 to 9:
Suppose you could separate the pre-trained lower-level layers of an existing trained model (like the one shown on the left) from the classification layers near the end of the model (shown on the right, and sometimes referred to as the classification head of the model). You could then use the lower-level layers to produce output features for any given image, based on the original data the model was trained on. Here is the same network with the classification head removed:
Assuming the new thing you are trying to recognize can also make use of such output features the prior model has learned, then there is a good chance they can be reused for a new purpose.
In the diagram above, this hypothetical model was trained on digits, so maybe what was learned about digits can also be applied to letters like a, b, and c.
So now you could add a new classification head that tries to predict a, b, or c instead, as shown:
Here the lower-level layers are frozen and are not trained. Only the new classification head updates itself, learning from the features provided by the pre-trained, chopped-up model on the left.
The act of doing this is known as transfer learning and is what Teachable Machine does behind the scenes.
You can also see that, because only the multi-layer perceptron at the very end of the network has to be trained, training is much faster than if you had to train the whole network from scratch.
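The frozen-base, trainable-head idea can be shown at toy scale in plain JavaScript, with no TensorFlow.js required. In this sketch the "base model" is a hypothetical pair of hand-picked feature detectors that is never updated. The "head" is a logistic regression trained with gradient descent. Only the head's three numbers (two weights and a bias) change during training, which is why this step is cheap. The sketch illustrates the principle only; it is not how TensorFlow.js or Teachable Machine implement it.

```javascript
// Hypothetical, hand-picked weights standing in for a pre-trained base model.
// They are FROZEN: nothing below ever modifies them.
const BASE_WEIGHTS = [
  [1, 1, 0, 0], // feature 0 responds to brightness in the left half
  [0, 0, 1, 1], // feature 1 responds to brightness in the right half
];

// Frozen base: maps a 4-pixel "image" to a 2-value feature vector.
function extractFeatures(pixels) {
  return BASE_WEIGHTS.map((row) => row.reduce((sum, w, i) => sum + w * pixels[i], 0));
}

function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}

// Trainable head: logistic regression on the frozen features, fitted with
// plain gradient descent on the binary cross-entropy loss.
function trainHead(images, labels, epochs = 500, learningRate = 0.5) {
  const features = images.map(extractFeatures); // computed once; the base never changes
  let weights = [0, 0];
  let bias = 0;
  for (let epoch = 0; epoch < epochs; epoch++) {
    const gradW = [0, 0];
    let gradB = 0;
    features.forEach((f, n) => {
      const p = sigmoid(weights[0] * f[0] + weights[1] * f[1] + bias);
      const error = p - labels[n]; // gradient of the loss w.r.t. the logit
      gradW[0] += error * f[0];
      gradW[1] += error * f[1];
      gradB += error;
    });
    weights = weights.map((w, i) => w - (learningRate * gradW[i]) / features.length);
    bias -= (learningRate * gradB) / features.length;
  }
  // Returns the probability that an image belongs to class 1.
  return (pixels) => {
    const f = extractFeatures(pixels);
    return sigmoid(weights[0] * f[0] + weights[1] * f[1] + bias);
  };
}

const images = [
  [0.9, 0.8, 0.1, 0.0], // class 0: bright left half
  [1.0, 0.7, 0.2, 0.1], // class 0
  [0.1, 0.0, 0.9, 0.8], // class 1: bright right half
  [0.2, 0.1, 0.8, 1.0], // class 1
];
const labels = [0, 0, 1, 1];
const classify = trainHead(images, labels);

console.log(classify([0.8, 0.9, 0.0, 0.1]).toFixed(3)); // expected below 0.5: class 0
console.log(classify([0.0, 0.1, 0.9, 0.9]).toFixed(3)); // expected above 0.5: class 1
```

Teaching the head a new task only requires calling trainHead() again with new labelled examples. extractFeatures() is reused unchanged, just as MobileNet's lower layers are reused in this codelab.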
But how can you get your hands on sub-parts of a model? Head to the next section to find out.
4. TensorFlow Hub - base models
Find a suitable base model to use
For more advanced and popular research models like MobileNet, you can go to TensorFlow Hub and filter for models that are suitable for TensorFlow.js and use the MobileNet v3 architecture. You'll find results like the ones shown here:
Note that some of these results are of type "image classification" (detailed at the top left of each model card result), and others are of type "image feature vector."
These Image Feature Vector results are essentially the pre-chopped up versions of MobileNet that you can use to get the image feature vectors instead of the final classification.
Models like this are often called "base models," which you can then use to perform transfer learning in the same manner as shown in the prior section by adding a new classification head and training it with your own data.
Next, check which TensorFlow.js format your base model of interest is released in. Open the page for one of these feature vector MobileNet v3 models. Its JS documentation shows that it's a graph model, because the example code snippet uses tf.loadGraphModel().
It should also be noted that if you find a model in the layers format instead of the graph format you can choose which layers to freeze and which to unfreeze for training. This can be very powerful when creating a model for a new task, which is often referred to as the "transfer model." For now, however, you'll use the default graph model type for this tutorial, which most TF Hub models are deployed as. To learn more about working with Layers models, check out the zero to hero TensorFlow.js course.
Advantages of transfer learning
What are the advantages of using transfer learning instead of training the whole model architecture from scratch?
First, the training time is a key advantage to using a transfer learning approach as you already have a trained base model to build upon.
Second, you can get away with showing far fewer examples of the new thing you are trying to classify, thanks to the training that has already taken place.
This is really great if you have limited time and resources to gather example data of the thing you want to classify, and need to make a prototype quickly before gathering more training data to make it more robust.
Given the need for less data and the speed of training a smaller network, transfer learning is less resource intensive. That makes it very suitable for the browser environment, taking just tens of seconds on a modern machine instead of hours, days, or weeks for full model training.
Alright! Now that you know the essence of transfer learning, it's time to create your very own version of Teachable Machine. Let's get started!
5. Get set up to code
What you'll need
- A modern web browser.
- Basic knowledge of HTML, CSS, JavaScript, and Chrome DevTools (viewing the console output).
Let's get coding
Boilerplate templates to start from have been created for Glitch.com and Codepen.io. You can clone either template as your base state for this codelab in just one click.
On Glitch, click the "remix this" button to fork it and make a new set of files you can edit.
Alternatively, on Codepen, click "fork" in the bottom right of the screen.
This very simple skeleton provides you with the following files:
- HTML page (index.html)
- Stylesheet (style.css)
- File to write our JavaScript code (script.js)
For your convenience, there is an added import in the HTML file for the TensorFlow.js library. It looks like this:
index.html
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
Alternative: Use your preferred web editor or work locally
If you want to download the code and work locally, or on a different online editor, simply create the 3 files named above in the same directory and copy and paste the code from our Glitch boilerplate into each of them.
6. App HTML boilerplate
Where do I start?
All prototypes require some basic HTML scaffolding you can render your findings on. Set that up now. You are going to add:
- A title for the page.
- Some descriptive text.
- A status paragraph.
- A video to hold the webcam feed once ready.
- Several buttons to start the camera, collect data, or reset the experience.
- Imports for TensorFlow.js and JS files you will code later.
Open index.html and paste over the existing code with the following to set up the above features:
index.html
<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>

    <p id="status">Awaiting TF.js load</p>

    <video id="webcam" autoplay muted></video>

    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train & Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>
Break it down
Let's break some of the above HTML code down to highlight some key things you added.
- You added an <h1> tag for the page title, along with a <p> tag with an ID of 'status'. This is where you will print information as you use different parts of the system, so you can view outputs.
- You added a <video> element with an ID of 'webcam', to which you will render your webcam stream later.
- You added 5 <button> elements. The first, with the ID of 'enableCam', enables the camera. The next two buttons have a class of 'dataCollector', which lets you gather example images for the objects you want to recognize. The code you write later is designed so that you can add any number of these buttons and they will work as intended automatically.
  Note that these buttons also have a special user-defined attribute called data-1hot, with an integer value starting from 0 for the first class. This is the numerical index you'll use to represent a certain class's data. The index lets you encode the output classes as numbers instead of strings, as ML models can only work with numbers.
  There is also a data-name attribute that contains the human-readable name you want to use for this class. This lets you show the user a more meaningful name instead of a numerical index value from the one-hot encoding.
  Finally, you have a train button and a reset button. They kick off the training process once data has been collected, or reset the app, respectively.
- You also added 2 <script> imports: one for TensorFlow.js, and the other for script.js, which you will define shortly.
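Here is a minimal plain-JavaScript one-hot encoder that shows what the data-1hot index turns into. The oneHot() helper is illustrative only. In TensorFlow.js, tf.oneHot(indices, depth) produces the same encoding as a tensor.

```javascript
// Turns a class index (like the value of data-1hot) into a one-hot vector:
// all zeros except a 1 at the position of that class.
function oneHot(classIndex, numClasses) {
  if (!Number.isInteger(classIndex) || classIndex < 0 || classIndex >= numClasses) {
    throw new RangeError(`classIndex ${classIndex} is out of range for ${numClasses} classes`);
  }
  const vector = new Array(numClasses).fill(0);
  vector[classIndex] = 1;
  return vector;
}

// The two dataCollector buttons use data-1hot="0" and data-1hot="1".
console.log(oneHot(parseInt('0', 10), 2)); // [ 1, 0 ]
console.log(oneHot(parseInt('1', 10), 2)); // [ 0, 1 ]
```

Adding a third dataCollector button with data-1hot="2" would simply produce vectors of length 3.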
7. Add style
Element defaults
Add styles for the HTML elements you just added to ensure they render correctly. These styles simply position and size the elements; nothing too special. You could certainly add to this later to make an even better UX, like you saw in the Teachable Machine video.
style.css
body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}

video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size: 150%;
}
Great! That's all you need. If you preview the output right now, it should look something like this:
8. JavaScript: Key constants and listeners
Define key constants
First, add some key constants you'll use throughout the app. Start by replacing the contents of script.js with these constants:
script.js
const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];
Let's break down what these are for:
- STATUS simply holds a reference to the paragraph tag you will write status updates to.
- VIDEO holds a reference to the HTML video element that will render the webcam feed.
- ENABLE_CAM_BUTTON, RESET_BUTTON, and TRAIN_BUTTON grab DOM references to all the key buttons from the HTML page.
- MOBILE_NET_INPUT_WIDTH and MOBILE_NET_INPUT_HEIGHT define the expected input width and height of the MobileNet model, respectively. Keeping these in constants near the top of the file means that, if you decide to use a different version later, you update the values once instead of replacing them in many different places.
- STOP_DATA_GATHER is set to -1. This stores a state value so you know when the user has stopped pressing a button to gather data from the webcam feed. Giving this number a meaningful name makes the code more readable later.
- CLASS_NAMES acts as a lookup and holds the human-readable names for the possible class predictions. This array will be populated later.
OK, now that you have references to key elements, it is time to associate some event listeners to them.
Add key event listeners
Start by adding click event handlers to key buttons as shown:
script.js
ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);

function enableCam() {
  // TODO: Fill this out later in the codelab!
}

function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}

function reset() {
  // TODO: Fill this out later in the codelab!
}
- ENABLE_CAM_BUTTON calls the enableCam function when clicked.
- TRAIN_BUTTON calls trainAndPredict when clicked.
- RESET_BUTTON calls reset when clicked.
Finally in this section, you can find all buttons that have a class of 'dataCollector' using document.querySelectorAll(). This returns a NodeList (an array-like collection) of the elements in the document that match the selector:
script.js
let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}

function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}
Code explanation:
You iterate through the found buttons and attach 2 event listeners to each: one for 'mousedown' and one for 'mouseup'. This lets you keep recording samples for as long as the button is pressed, which is useful for data collection.
Both events call a gatherDataForClass function that you will define later.
At this point, you also push each button's human-readable class name, taken from its data-name attribute, to the CLASS_NAMES array.
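Because both events call the same handler, the handler needs a way to tell a press from a release. One approach is to flip a state variable between the button's class index and the STOP_DATA_GATHER sentinel. Here is a minimal sketch in plain JavaScript, so it runs without a browser. The onDataCollectorEvent() name is hypothetical. It mirrors the toggle used in gatherDataForClass() later in the codelab, but takes the class index directly instead of reading it from a DOM attribute.

```javascript
const STOP_DATA_GATHER = -1;
let gatherDataState = STOP_DATA_GATHER;

// Shared handler for mousedown and mouseup. Since the two events normally
// alternate on a button, flipping between "gathering class N" and "stopped"
// is enough to tell a press from a release.
function onDataCollectorEvent(classNumber) {
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  return gatherDataState;
}

// Simulated press-and-hold on the "Class 2" button (data-1hot="1").
console.log(onDataCollectorEvent(1)); // mousedown → 1 (gathering Class 2)
console.log(onDataCollectorEvent(1)); // mouseup   → -1 (stopped)
```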
Next, add some variables to store key things that will be used later.
script.js
let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;
Let's walk through those.
First, you have a variable mobilenet to store the loaded MobileNet model. Initially, this is set to undefined.
Next, you have a variable called gatherDataState. When a 'dataCollector' button is pressed, this changes to that button's one-hot index, as defined in the HTML, so you know which class of data you are collecting at that moment. Initially, this is set to STOP_DATA_GATHER so that the data gather loop you write later will not gather any data when no buttons are being pressed.
videoPlaying keeps track of whether the webcam stream has loaded successfully, is playing, and is available to use. Initially, this is set to false, as the webcam is not on until you press the ENABLE_CAM_BUTTON.
Next, define 2 arrays, trainingDataInputs and trainingDataOutputs. As you click the 'dataCollector' buttons, these store the gathered training data. The first holds the input features generated by the MobileNet base model, and the second holds the output class sampled.
One final array, examplesCount, keeps track of how many examples have been collected for each class once you start adding them.
Finally, you have a variable called predict that controls your prediction loop. This is set to false initially. No predictions can take place until this is set to true later.
Now that all the key variables have been defined, let's go and load the pre-chopped up MobileNet v3 base model that provides image feature vectors instead of classifications.
9. Load the MobileNet base model
First, define a new function called loadMobileNetFeatureModel as shown below. This must be an async function, as loading a model is asynchronous:
script.js
/**
 * Loads the MobileNet model and warms it up so it is ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL =
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';

  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';

  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();
In this code, you define the URL where the model is located, taken from the TFHub documentation.
You can then load the model using await tf.loadGraphModel(). Remember to set the special property fromTFHub to true, as you are loading a model from this Google website. This extra property only has to be set for models hosted on TF Hub.
Once loading is complete, set the STATUS element's innerText to a message. This lets you see that the model loaded correctly and that you are ready to start gathering data.
The only thing left to do now is to warm up the model. With larger models like this, the first time you use the model it can take a moment to set everything up. Passing zeros through the model once up front avoids this delay later, when timing may be more critical.
You can use tf.zeros() wrapped in a tf.tidy() to ensure tensors are disposed of correctly. The zeros tensor has a batch size of 1 and the height and width that you defined in your constants at the start. Finally, you also specify the color channels, which in this case is 3, as the model expects RGB images.
Next, log the shape of the returned tensor using answer.shape to help you understand the size of the image features this model produces.
After defining this function, you can call it immediately to initiate the model download on page load.
If you view your live preview right now, after a few moments you will see the status text change from "Awaiting TF.js load" to become "MobileNet v3 loaded successfully!" as shown below. Ensure this works before continuing.
You can also check the console output to see the printed size of the output features that this model produces. After running zeros through the MobileNet model, you'll see a shape of [1, 1024] printed. The first item is just the batch size of 1, and you can see it actually returns 1024 features that can then be used to help you classify new objects.
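The warm-up input tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]) is a batch of one 224 by 224 RGB image. The plain-JavaScript sketch below shows how many values that is and where each pixel's channels sit in a flat buffer. It assumes the row-major layout (last axis varies fastest) in which TensorFlow.js returns tensor data from dataSync(). The flatIndex() helper is illustrative, not a TensorFlow.js function.

```javascript
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
// Shape used for the warm-up call: [batch, height, width, color channels].
const shape = [1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3];

// Total number of values is the product of all dimensions.
const size = shape.reduce((a, b) => a * b, 1);

// Row-major offset of element [b, y, x, c] (Horner's scheme over the axes).
function flatIndex(dims, indices) {
  return indices.reduce((offset, idx, axis) => offset * dims[axis] + idx, 0);
}

const zeros = new Float32Array(size); // a Float32Array starts out all zeros
console.log(size);                           // 150528
console.log(zeros.every((v) => v === 0));    // true
console.log(flatIndex(shape, [0, 0, 1, 0])); // 3: next pixel starts after its R, G, B
console.log(flatIndex(shape, [0, 1, 0, 0])); // 672: next row starts after 224 * 3 values
```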
10. Define the new model head
Now it's time to define your model head, which is essentially a very minimal multi-layer perceptron.
script.js
let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam adapts the step size for each parameter during training, which is useful.
  optimizer: 'adam',
  // Use binaryCrossentropy for exactly 2 classes, else categoricalCrossentropy.
  // (With 2 one-hot classes and a softmax output, both give the same loss value.)
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']
});
Let's walk through this code. You start by defining a tf.sequential model to which you will add layers.
Next, add a dense layer as the input layer to this model. This has an input shape of 1024, because the MobileNet v3 feature outputs are of this size. You discovered this in the previous step after passing zeros through the model. This layer has 128 neurons that use the ReLU activation function.
If you are new to activation functions and model layers, consider taking the course detailed at the start of this workshop to understand what these properties do behind the scenes.
The next layer to add is the output layer. The number of neurons should equal the number of classes you are trying to predict. You can get that number from CLASS_NAMES.length, which equals the number of data gather buttons found in the user interface. As this is a classification problem, you use the softmax activation on this output layer. Softmax turns the raw outputs into probabilities that sum to 1, which makes it the standard choice for multi-class classification, as opposed to regression.
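Softmax itself is short enough to write out. This plain-JavaScript sketch subtracts the largest logit before exponentiating, a common trick that avoids overflow without changing the result. The example logits are made up for illustration.

```javascript
// Numerically stable softmax: exp(z - max) / sum(exp(z - max)).
function softmax(logits) {
  const maxLogit = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - maxLogit));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

const probs = softmax([2.0, 0.5]); // made-up raw outputs for Class 1 and Class 2
console.log(probs);                            // roughly [0.82, 0.18]
console.log(probs.reduce((a, b) => a + b, 0)); // 1, up to floating-point rounding
console.log(softmax([1000, 1000]));            // [ 0.5, 0.5 ], no overflow
```

The largest logit always gets the largest probability, so taking the index of the highest softmax output gives the predicted class.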
Next, call model.summary() to print an overview of the newly defined model to the console.
Finally, compile the model so it is ready to be trained. Here the optimizer is set to adam. The loss is binaryCrossentropy if CLASS_NAMES.length is equal to 2, or categoricalCrossentropy if there are 3 or more classes to classify. Accuracy metrics are also requested so they can be monitored in the logs later for debugging purposes.
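With exactly two one-hot classes and a softmax output, binaryCrossentropy and categoricalCrossentropy produce the same loss value. The second output is always 1 minus the first, so both formulas reduce to the same term. The sketch below computes both losses from their textbook definitions for one made-up prediction. It assumes, as in Keras, that binary cross-entropy is averaged over the output units.

```javascript
// Binary cross-entropy, averaged over the output units.
function binaryCrossentropy(yTrue, yPred) {
  const terms = yTrue.map((y, i) => -(y * Math.log(yPred[i]) + (1 - y) * Math.log(1 - yPred[i])));
  return terms.reduce((a, b) => a + b, 0) / terms.length;
}

// Categorical cross-entropy for a single one-hot label.
function categoricalCrossentropy(yTrue, yPred) {
  return -yTrue.reduce((sum, y, i) => sum + y * Math.log(yPred[i]), 0);
}

const label = [0, 1];          // one-hot label for Class 2
const prediction = [0.3, 0.7]; // made-up softmax output, sums to 1
console.log(binaryCrossentropy(label, prediction).toFixed(4));      // 0.3567
console.log(categoricalCrossentropy(label, prediction).toFixed(4)); // 0.3567
```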
In the console you should see something like this:
Note that this has over 130 thousand trainable parameters. But as this is a simple dense layer of regular neurons it will train pretty fast.
As an activity once the project is complete, you could try changing the number of neurons in the first layer to see how low you can make it while still getting decent performance. Machine learning often involves some trial and error to find the parameter values that give the best trade-off between accuracy and resource usage.
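The parameter count reported by model.summary() follows directly from the layer sizes: a dense layer with biases has inputs × units weights plus one bias per unit. The plain-JavaScript helpers below are illustrative. They assume the 1024-feature input and 128-unit hidden layer defined above, and show how the count changes if you shrink the hidden layer as suggested in the activity.

```javascript
// A dense layer with a bias has (inputs * units) weights plus one bias per unit.
function denseParamCount(inputs, units) {
  return inputs * units + units;
}

// Hidden dense layer followed by the softmax output layer.
function headParamCount(featureSize, hiddenUnits, numClasses) {
  return denseParamCount(featureSize, hiddenUnits) + denseParamCount(hiddenUnits, numClasses);
}

console.log(denseParamCount(1024, 128));   // 131200 (hidden layer)
console.log(denseParamCount(128, 2));      // 258 (output layer, 2 classes)
console.log(headParamCount(1024, 128, 2)); // 131458 in total
console.log(headParamCount(1024, 64, 2));  // 65730 with half the hidden neurons
```

Almost all of the parameters sit in the first layer, so halving its neurons roughly halves the size of the model head.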
11. Enable the webcam
It's now time to flesh out the enableCam() function you defined earlier. Add a new function named hasGetUserMedia() as shown below, and then replace the contents of the previously defined enableCam() function with the corresponding code below.
script.js
function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUserMedia parameters: video only, preferably 640x480.
    // Width and height must be nested inside video to take effect.
    const constraints = {
      video: {
        width: 640,
        height: 480
      }
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    }).catch(function(error) {
      // For example, the user denied permission or no camera was found.
      console.error('Could not access the webcam:', error);
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}
First, create a function named hasGetUserMedia() to check whether the browser supports getUserMedia(), by checking for the existence of the key browser API properties.
In the enableCam() function, use the hasGetUserMedia() function you just defined to check whether it is supported. If it isn't, print a warning to the console.
If it is supported, define some constraints for your getUserMedia() call: you want the video stream only, and you prefer a video width of 640 pixels and a height of 480 pixels. Why? There is not much point getting a video larger than this, as it would need to be resized to 224 by 224 pixels to be fed into the MobileNet model. You may as well save some computing resources by requesting a smaller resolution. Most cameras support a resolution of this size.
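Resolution preferences belong inside the video member of the constraints object (a MediaTrackConstraints dictionary). Width and height keys at the top level are not part of the getUserMedia() constraints format. The hypothetical buildWebcamConstraints() helper below sketches that shape. Using { ideal: value } marks each size as a preference rather than a hard requirement, so the browser can pick the closest resolution the camera supports.

```javascript
// Hypothetical helper that builds a getUserMedia() constraints object.
// Resolution hints go inside `video`; `ideal` marks them as preferences.
function buildWebcamConstraints(width, height) {
  return {
    video: {
      width: { ideal: width },
      height: { ideal: height },
    },
    audio: false,
  };
}

const constraints = buildWebcamConstraints(640, 480);
console.log(JSON.stringify(constraints));
// {"video":{"width":{"ideal":640},"height":{"ideal":480}},"audio":false}
```

In the browser, you would pass the result straight to navigator.mediaDevices.getUserMedia(constraints).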
Next, call navigator.mediaDevices.getUserMedia() with the constraints detailed above, and wait for the stream to be returned. Once the stream is returned, set it as the srcObject value of your VIDEO element so that the element plays it.

You should also add an event listener on the VIDEO element to know when the stream has loaded and is playing successfully.
Once the stream loads, set videoPlaying to true. Then add the "removed" class to the ENABLE_CAM_BUTTON so that it can't be clicked again.
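If you prefer async/await, you can write the same flow as in the sketch below. This is not the codelab's code. startWebcam() is a hypothetical name, and it takes the media devices object and the video element as parameters instead of reading navigator and VIDEO, so you can test the logic without a real camera. It also catches the rejection from getUserMedia(), which happens when, for example, the user denies camera access.

```javascript
// Hypothetical async version of the enableCam() flow.
// mediaDevices: an object with a getUserMedia(constraints) method
//   (in the browser, pass navigator.mediaDevices).
// videoElement: the <video> element that will show the stream.
// Resolves to true once the first frame has loaded, or false on failure.
async function startWebcam(mediaDevices, videoElement, constraints) {
  try {
    const stream = await mediaDevices.getUserMedia(constraints);
    videoElement.srcObject = stream;
    // Wait for the 'loadeddata' event; { once: true } removes the listener
    // after it fires.
    await new Promise(function(resolve) {
      videoElement.addEventListener('loadeddata', resolve, { once: true });
    });
    return true;
  } catch (err) {
    // For example, a NotAllowedError if the user blocks camera access.
    console.warn('Could not start the webcam: ' + err.name);
    return false;
  }
}

// Possible usage in the browser:
// startWebcam(navigator.mediaDevices, VIDEO, { video: true }).then(function(ok) {
//   if (ok) {
//     videoPlaying = true;
//     ENABLE_CAM_BUTTON.classList.add('removed');
//   }
// });
```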
Now run your code, click the enable camera button, and allow access to the webcam. The first time you do this, the browser asks for permission. Once you allow it, you should see yourself rendered to the video element on the page as shown:
Ok, now it is time to add a function to deal with the dataCollector button clicks.
12. Data collection button event handler
Now it's time to fill out your currently empty function called gatherDataForClass(). This is what you assigned as your event handler function for the dataCollector buttons at the start of the codelab.
script.js
/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}
First, read the data-1hot attribute on the currently clicked button by calling this.getAttribute() with the attribute's name, in this case data-1hot, as the parameter. As this returns a string, use parseInt() to convert it to an integer and assign the result to a variable named classNumber.
Next, set the gatherDataState variable accordingly. If the current gatherDataState is equal to STOP_DATA_GATHER (which you set to -1), you are not currently gathering any data and the event that fired was a mousedown. In that case, set gatherDataState to the classNumber you just found.
Otherwise, you are currently gathering data and the event that fired was a mouseup, so you now want to stop gathering data for that class. Set gatherDataState back to STOP_DATA_GATHER to end the data gathering loop you will define shortly.
Finally, kick off the call to dataGatherLoop(), which actually records the class data.
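The toggle in gatherDataForClass() can be extracted into a pure function. This makes the mousedown/mouseup behavior easy to check in isolation. The sketch below uses a hypothetical nextGatherState() helper and assumes STOP_DATA_GATHER is -1, as set earlier in the codelab.

```javascript
const STOP_DATA_GATHER = -1;

// Hypothetical helper mirroring the ternary in gatherDataForClass():
// if nothing is being gathered (mousedown), start gathering for classNumber;
// otherwise (mouseup), stop gathering.
function nextGatherState(currentState, classNumber) {
  return currentState === STOP_DATA_GATHER ? classNumber : STOP_DATA_GATHER;
}

let state = STOP_DATA_GATHER;
state = nextGatherState(state, 1); // mousedown on the class 1 button
console.log(state); // 1
state = nextGatherState(state, 1); // mouseup on the same button
console.log(state); // -1
```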
13. Data collection
Now, define the dataGatherLoop() function. This function samples images from the webcam video, passes them through the MobileNet model, and captures the model's output (a feature vector of 1024 values). It then stores each feature vector along with the gatherDataState ID of the button currently being pressed, so you know what class the data represents.
Let's walk through it:
script.js
function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);

    // Initialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}
This function only continues if videoPlaying is true and gatherDataState is not equal to STOP_DATA_GATHER. Together, these conditions mean that the webcam is active and a class data gathering button is currently being pressed.
Next, wrap your code in a tf.tidy() to dispose of any tensors created in the code that follows. The tensor returned from the tf.tidy() function is kept, not disposed, and is stored in a variable called imageFeatures.
You can now grab a frame of the webcam VIDEO using tf.browser.fromPixels(). The resulting tensor containing the image data is stored in a variable called videoFrameAsTensor.
Next, resize the videoFrameAsTensor variable to the correct shape for the MobileNet model's input. Call tf.image.resizeBilinear() with the tensor you want to resize as the first parameter. The second parameter is the new height and width, as defined by the constants you created earlier. Finally, pass true as the third parameter to set alignCorners to true, which avoids alignment issues when resizing. The result of this resize is stored in a variable called resizedTensorFrame.
Note that this primitive resize stretches the image. Your webcam image is 640 by 480 pixels in size, but the model needs a square image of 224 by 224 pixels.

For the purposes of this demo, this should work fine. However, once you complete this codelab, you may want to try cropping a square from the image instead, for even better results in any production system you create later.
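If you want to try cropping, first work out the largest centered square that fits in the frame. The sketch below does only that arithmetic, in plain JavaScript, using a hypothetical centerSquareCrop() helper.

It returns the box in two forms. The pixel form is suitable for, for example, slicing the frame tensor. The normalized form uses the 0 to 1 range in [y1, x1, y2, x2] order, which is the box format that tf.image.cropAndResize() expects.

```javascript
// Hypothetical helper: computes the largest square centered in a frame of
// the given size.
function centerSquareCrop(frameWidth, frameHeight) {
  const size = Math.min(frameWidth, frameHeight);
  const x = Math.floor((frameWidth - size) / 2);
  const y = Math.floor((frameHeight - size) / 2);
  return {
    // Pixel coordinates of the top-left corner plus the square's side length.
    pixels: { x: x, y: y, size: size },
    // The same box normalized to [0, 1] in [y1, x1, y2, x2] order.
    normalizedBox: [
      y / frameHeight,
      x / frameWidth,
      (y + size) / frameHeight,
      (x + size) / frameWidth
    ]
  };
}

const crop = centerSquareCrop(640, 480);
console.log(crop.pixels);        // { x: 80, y: 0, size: 480 }
console.log(crop.normalizedBox); // [ 0, 0.125, 1, 0.875 ]
```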
Next, normalize the image data. Image data from tf.browser.fromPixels() is always in the range 0 to 255. Divide resizedTensorFrame by 255 so that all values fall between 0 and 1, which is the input range the MobileNet model expects.
Finally, in the tf.tidy() section of the code, push this normalized tensor through the loaded model by calling mobilenet.predict(). Pass it the expanded version of normalizedTensorFrame, created with expandDims(), so that the input is a batch of 1, because the model expects a batch of inputs for processing.

Once the result comes back, immediately call squeeze() on it to squash it back down to a 1D tensor. Return this tensor, which is assigned to the imageFeatures variable that captures the result from tf.tidy().
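It can help to follow the tensor shapes through these steps. The plain-JavaScript sketch below tracks only the shapes, not the data, and the helper names are hypothetical. The [1, 1024] model output is an assumption based on the 1024-value feature vector described above. A 640 by 480 webcam frame has shape [480, 640, 3] (height, width, RGB channels), and dividing by 255 does not change the shape.

```javascript
// Hypothetical shape helpers that mimic what the TensorFlow.js ops do to a
// tensor's shape (they do not touch any data).
function resizeShape(shape, newHeight, newWidth) {
  // resizeBilinear keeps the channel dimension and replaces height and width.
  return [newHeight, newWidth, shape[2]];
}

function expandDimsShape(shape) {
  // expandDims() with the default axis of 0 adds a leading batch dimension.
  return [1].concat(shape);
}

function squeezeShape(shape) {
  // squeeze() with no axis argument removes every dimension of size 1.
  return shape.filter(function(dim) { return dim !== 1; });
}

const frameShape = [480, 640, 3];                       // from fromPixels()
const resizedShape = resizeShape(frameShape, 224, 224); // [224, 224, 3]
const batchShape = expandDimsShape(resizedShape);       // [1, 224, 224, 3]
const modelOutputShape = [1, 1024];                     // assumed model output
const featureShape = squeezeShape(modelOutputShape);    // [1024]
console.log(frameShape, resizedShape, batchShape, featureShape);
```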
Now that you have the imageFeatures from the MobileNet model, you can record them by pushing them onto the trainingDataInputs array that you defined previously.

You can also record what this input represents by pushing the current gatherDataState onto the trainingDataOutputs array.
Note that the gatherDataState variable was set to the numerical ID of the class you are recording data for when the button was clicked, in the previously defined gatherDataForClass() function.
At this point you can also increment the number of examples you have for a given class. First, check whether that class's index in the examplesCount array has been initialized. If it is undefined, set it to 0 to initialize the counter for that class's numerical ID. Then increment the examplesCount entry for the current gatherDataState.
Now update the STATUS element's text on the web page to show the current count for each class as samples are captured. To do this, loop through the CLASS_NAMES array, and print each human-readable name combined with the data count at the same index in examplesCount.
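Until a class has at least one sample, its entry in examplesCount is undefined. A status line built directly from the array therefore shows the word "undefined" for that class.

The hypothetical buildStatusText() helper below is a small variation that defaults missing counts to 0. The class names and counts are placeholders.

```javascript
// Hypothetical helper: builds the status text from class names and counts,
// treating classes with no samples yet (undefined entries) as 0.
function buildStatusText(classNames, counts) {
  let text = '';
  for (let n = 0; n < classNames.length; n++) {
    const count = counts[n] === undefined ? 0 : counts[n];
    text += classNames[n] + ' data count: ' + count + '. ';
  }
  return text;
}

// Placeholder data: 12 samples for the first class, none for the second.
const examplesCount = [];
examplesCount[0] = 12;
console.log(buildStatusText(['Class 1', 'Class 2'], examplesCount));
// Class 1 data count: 12. Class 2 data count: 0.
```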
Finally, call window.requestAnimationFrame() with dataGatherLoop as the parameter. This schedules the function to run again on the next animation frame. The loop keeps sampling frames from the video until the button's mouseup is detected and gatherDataState is set to STOP_DATA_GATHER, at which point the data gathering loop ends.
If you run your code now, you should be able to click the enable camera button, wait for the webcam to load, and then click and hold each of the data gathering buttons to gather examples for each class. Here you see me gather data for my mobile phone and my hand, respectively.
You should see the status text update as the tensors are stored in memory, as shown in the screen capture above.
14. Train and predict
The next step is to implement code for your currently empty trainAndPredict() function, which is where the transfer learning takes place. Let's take a look at the code:
script.js
async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);

  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
      callbacks: {onEpochEnd: logProgress} });

  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}
First, stop any current predictions from taking place by setting predict to false.
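Next, shuffle your input and output arrays using tf.util.shuffleCombo() so the order of the samples does not cause issues in training. This function shuffles both arrays in place using the same permutation, so each input stays paired with its output.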
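To show the idea, here is a plain-JavaScript Fisher-Yates shuffle applied to two arrays in lockstep. shuffleTogether() is a hypothetical helper, not the library's implementation.

```javascript
// Hypothetical helper: Fisher-Yates shuffle that applies the same swaps to
// two arrays of equal length, keeping inputs[i] paired with outputs[i].
function shuffleTogether(inputs, outputs) {
  if (inputs.length !== outputs.length) {
    throw new Error('Arrays must have the same length');
  }
  for (let i = inputs.length - 1; i > 0; i--) {
    // Pick a random index from 0 to i (inclusive).
    const j = Math.floor(Math.random() * (i + 1));
    [inputs[i], inputs[j]] = [inputs[j], inputs[i]];
    [outputs[i], outputs[j]] = [outputs[j], outputs[i]];
  }
}

const inputs = ['a0', 'a1', 'b0', 'b1'];
const outputs = [0, 0, 1, 1];
shuffleTogether(inputs, outputs);
console.log(inputs, outputs);
```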
Convert your output array, trainingDataOutputs, to a tensor1d of type int32 so it is ready to be used in a one-hot encoding. This is stored in a variable named outputsAsTensor.
Use the tf.oneHot() function with this outputsAsTensor variable, along with the maximum number of classes to encode, which is CLASS_NAMES.length. Your one-hot encoded outputs are now stored in a new tensor called oneHotOutputs.
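One-hot encoding turns each class index into a vector with a 1 at that index and 0 everywhere else. The plain-JavaScript sketch below uses a hypothetical oneHotArrays() helper to show the values tf.oneHot() produces for the labels [0, 1, 1] with two classes.

```javascript
// Hypothetical helper: plain-array equivalent of one-hot encoding.
function oneHotArrays(labels, numClasses) {
  return labels.map(function(label) {
    const row = new Array(numClasses).fill(0);
    row[label] = 1;
    return row;
  });
}

console.log(oneHotArrays([0, 1, 1], 2)); // [ [ 1, 0 ], [ 0, 1 ], [ 0, 1 ] ]
```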
Note that trainingDataInputs is currently an array of recorded tensors. To use these for training, you need to convert the array of tensors into a regular 2D tensor.

The TensorFlow.js library has a useful function for this, called tf.stack(). It takes an array of tensors and stacks them together to produce a higher-dimensional tensor as output. In this case it returns a 2D tensor: a batch of 1D inputs, each 1024 values long, containing the recorded features. This is what you need for training.
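With plain arrays standing in for tensors, stacking N feature vectors of length 1024 produces one [N, 1024] structure. The hypothetical stackShape() helper below only reports the resulting shape. It also checks that every vector has the same length, which tf.stack() also requires of its input tensors.

```javascript
// Hypothetical helper: reports the shape that stacking equal-length
// 1D vectors would produce, i.e. [numVectors, vectorLength].
function stackShape(vectors) {
  if (vectors.length === 0) {
    throw new Error('Nothing to stack');
  }
  const length = vectors[0].length;
  for (const v of vectors) {
    if (v.length !== length) {
      throw new Error('All vectors must have the same length to be stacked');
    }
  }
  return [vectors.length, length];
}

// Three placeholder feature vectors of 1024 zeros each.
const features = [0, 1, 2].map(function() { return new Array(1024).fill(0); });
console.log(stackShape(features)); // [ 3, 1024 ]
```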
Next, await model.fit() to train the custom model head. Pass inputsAsTensor as the example inputs and oneHotOutputs as the target outputs. In the configuration object passed as the third parameter, set shuffle to true, batchSize to 5, and epochs to 10. Then set the onEpochEnd callback to the logProgress function, which you will define shortly.
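The batchSize and epochs settings determine how many weight updates happen during training. Each epoch processes the data in ceil(numExamples / batchSize) batches, with one update per batch.

As a worked example, assume you gathered 30 samples for each of two classes, or 60 in total. The hypothetical helper below computes the number of updates for the settings above.

```javascript
// Hypothetical helper: counts the batches (weight updates) that a training
// run performs, given the number of examples, batch size and epoch count.
function trainingSteps(numExamples, batchSize, epochs) {
  const batchesPerEpoch = Math.ceil(numExamples / batchSize);
  return { batchesPerEpoch: batchesPerEpoch, totalBatches: batchesPerEpoch * epochs };
}

// Assumed 60 examples, with batchSize 5 and 10 epochs as in trainAndPredict().
console.log(trainingSteps(60, 5, 10)); // { batchesPerEpoch: 12, totalBatches: 120 }
```

A partial final batch still counts as one batch, which is why the helper rounds up.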
Finally, dispose of the created tensors now that the model is trained. Then set predict back to true to allow predictions to take place again, and call the predictLoop() function to start predicting live webcam images.
You can also define the logProgress() function to log the state of training. model.fit() calls it above to print results to the console after each epoch.
You're almost there! Time to add the predictLoop() function to make predictions.
Core prediction loop
Here you implement the main prediction loop. It samples frames from the webcam and continuously predicts what is in each frame, showing real-time results in the browser.
Let's check the code:
script.js
function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}
First, check that predict is true, so that predictions are made only after a model has been trained and is available to use.
Next, get the image features for the current frame just as you did in the dataGatherLoop() function. You grab a frame from the webcam using tf.browser.fromPixels(), normalize it, and resize it to 224 by 224 pixels. Then you pass that data through the MobileNet model to get the resulting image features.
Now, however, you can use your newly trained model head to perform a prediction. Pass the resulting imageFeatures through the trained model's predict() function. Then squeeze the resulting tensor to make it 1-dimensional again, and assign it to a variable called prediction.
With this prediction, you can find the index of the highest value using argMax(). Then call arraySync() on the resulting scalar tensor to read its value into JavaScript as a plain number, which is the position of the highest-valued element. This value is stored in the variable called highestIndex.
You can also get the actual prediction confidence scores in the same way, by calling arraySync() on the prediction tensor directly.
You now have everything you need to update the STATUS text with the prediction data. To get the human-readable string for the class, look up highestIndex in the CLASS_NAMES array, and then grab the confidence value from predictionArray. To make the confidence easier to read as a percentage, multiply it by 100 and apply Math.floor() to the result.
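Once the prediction values are in a plain array, you can write the same lookup in plain JavaScript, which lets you test the formatting logic on its own. In this sketch, formatPrediction() is a hypothetical helper and the probabilities are made-up example values.

```javascript
// Hypothetical helper: finds the highest-scoring class and formats it the
// same way as the status text in predictLoop().
function formatPrediction(classNames, probabilities) {
  let highestIndex = 0;
  for (let i = 1; i < probabilities.length; i++) {
    if (probabilities[i] > probabilities[highestIndex]) {
      highestIndex = i;
    }
  }
  const confidence = Math.floor(probabilities[highestIndex] * 100);
  return 'Prediction: ' + classNames[highestIndex] + ' with ' + confidence + '% confidence';
}

console.log(formatPrediction(['Class 1', 'Class 2'], [0.25, 0.75]));
// Prediction: Class 2 with 75% confidence
```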
Finally, use window.requestAnimationFrame() to call predictLoop() again once the browser is ready, which gives you real-time classification of your video stream. This continues until predict is set to false, for example when you choose to train a new model with new data.
That brings you to the final piece of the puzzle: implementing the reset button.
15. Implement the reset button
Almost complete! The final piece of the puzzle is to implement a reset button so you can start over. The code for your currently empty reset() function is below. Go ahead and update it as follows:
script.js
/**
 * Purge data and start over. Note this does not dispose of the loaded
 * MobileNet model and MLP head tensors as you will need to reuse
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';

  console.log('Tensors in memory: ' + tf.memory().numTensors);
}
First, stop any running prediction loops by setting predict to false. Next, delete all contents of the examplesCount array by setting its length to 0, which is a handy way to clear an array.
Now go through all the recorded trainingDataInputs and call dispose() on each tensor to free up memory again. This is necessary because the JavaScript garbage collector does not clean up tensors.
Once that is done, you can safely set the length of both the trainingDataInputs and trainingDataOutputs arrays to 0 to clear them too.
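Setting length to 0 clears the array in place, so every reference to the same array sees it as empty. Assigning a new empty array instead would leave other references pointing at the old contents.

The sketch below shows this, together with the dispose-then-clear pattern. It uses hypothetical fake tensor objects that only record whether dispose() was called.

```javascript
// Hypothetical stand-in for a tensor: it only records whether it was disposed.
function makeFakeTensor() {
  return {
    disposed: false,
    dispose: function() { this.disposed = true; }
  };
}

const recorded = [makeFakeTensor(), makeFakeTensor()];
const sameArray = recorded; // a second reference to the same array
const firstTensor = recorded[0];

// Dispose every element first, then clear the array in place.
for (let i = 0; i < recorded.length; i++) {
  recorded[i].dispose();
}
recorded.length = 0;

console.log(sameArray.length);     // 0
console.log(firstTensor.disposed); // true
```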
Finally, set the STATUS text to something sensible, and print the number of tensors left in memory as a sanity check.
Note that there will be a few hundred tensors still in memory as both the MobileNet model and the multi-layer perceptron you defined are not disposed of. You will need to reuse them with new training data if you decide to train again after this reset.
16. Let's try it out
It's time to test out your very own version of Teachable Machine!
Head to the live preview, enable the webcam, gather at least 30 samples for class 1 for some object in your room, and then do the same for class 2 for a different object, click train, and check the console log to see progress. It should train pretty fast:
Once trained, show the objects to the camera to get live predictions that will be printed to the status text area on the web page near the top. If you are having trouble, check my completed working code to see if you missed copying over anything.
17. Congratulations
Congratulations! You have just completed your very first transfer learning example using TensorFlow.js live in the browser.
Try it out and test it on a variety of objects. You may notice that some things are harder to recognize than others, especially if they are similar to something else. You may need to add more classes or training data to tell them apart.
Recap
In this codelab you learned:
- What transfer learning is, and its advantages over training a full model.
- How to get models for re-use from TensorFlow Hub.
- How to set up a web app suitable for transfer learning.
- How to load and use a base model to generate image features.
- How to train a new prediction head that can recognize custom objects from webcam imagery.
- How to use the resulting models to classify data in real time.
What's next?
Now that you have a working base to start from, what creative ideas can you come up with to extend this machine learning model boilerplate for a real-world use case you may be working on? Maybe you could revolutionize the industry you currently work in by helping people at your company train models to classify things that are important in their day-to-day work. The possibilities are endless.
To go further, consider taking this full course for free, which shows you how to combine the two models you currently have in this codelab into a single model for efficiency.
Also, if you are curious about the theory behind the original Teachable Machine application, check out this tutorial.
Share what you make with us
You can easily extend what you made today for other creative use cases too and we encourage you to think outside the box and keep hacking.
Remember to tag us on social media using the #MadeWithTFJS hashtag for a chance for your project to be featured on our TensorFlow blog or even future events. We would love to see what you make.
Websites to check out
- TensorFlow.js official website
- TensorFlow.js pre-made models
- TensorFlow.js API
- TensorFlow.js Show & Tell — get inspired and see what others have made.