TensorFlow.js: Make your own "Teachable Machine" using transfer learning with TensorFlow.js

1. Before you begin

TensorFlow.js model usage has grown exponentially over the past few years and many JavaScript developers are now looking to take existing state-of-the-art models and retrain them to work with custom data that is unique to their industry. The act of taking an existing model (often referred to as a base model), and using it on a similar but different domain is known as transfer learning.

Transfer learning has many advantages over starting from a completely blank model. You can reuse knowledge already learned from a prior trained model, and you require fewer examples of the new item you wish to classify. Also, training is often significantly faster because only the final few layers of the model architecture need to be retrained instead of the whole network. For this reason, transfer learning is very well suited to the web browser environment, where resources may vary based on the device of execution, but which also has direct access to the device's sensors for easy data acquisition.

This codelab shows you how to build a web app from a blank canvas, recreating Google's popular "Teachable Machine" website. You will end up with a functional web app that any user can use to recognize a custom object with just a few example images from their webcam. The web app is purposely kept minimal so that you can focus on the Machine Learning aspects of this codelab. As with the original Teachable Machine website, however, there is plenty of scope to apply your existing web developer experience to improve the UX.

Prerequisites

This codelab is written for web developers who are somewhat familiar with TensorFlow.js pre-made models and basic API usage, and who want to get started with transfer learning in TensorFlow.js.

  • Basic familiarity with TensorFlow.js, HTML5, CSS, and JavaScript is assumed for this lab.

If you are new to TensorFlow.js, consider taking this free zero to hero course first, which assumes no background with Machine Learning or TensorFlow.js, and teaches you everything you need to know in smaller steps.

What you'll learn

  • What TensorFlow.js is and why you should use it in your next web app.
  • How to build a simplified HTML/CSS/JS webpage that replicates the Teachable Machine user experience.
  • How to use TensorFlow.js to load a pre-trained base model, specifically MobileNet, to generate image features that can be used in transfer learning.
  • How to gather data from a user's webcam for multiple classes of data that you want to recognize.
  • How to create and define a multi-layer perceptron that takes the image features and learns to classify new objects using them.

Let's get hacking...

What you'll need

  • A Glitch.com account is preferred to follow along, or you can use a web serving environment you are comfortable editing and running yourself.

2. What is TensorFlow.js?

54e81d02971f53e8.png

TensorFlow.js is an open source machine learning library that can run anywhere JavaScript can. It's based upon the original TensorFlow library written in Python and aims to re-create this developer experience and set of APIs for the JavaScript ecosystem.

Where can it be used?

Given the portability of JavaScript, you can now write in one language and perform machine learning across all of the following platforms with ease:

  • Client side in the web browser using vanilla JavaScript
  • Server side and even IoT devices like Raspberry Pi using Node.js
  • Desktop apps using Electron
  • Native mobile apps using React Native

TensorFlow.js also supports multiple backends within each of these environments - the actual hardware-based environments it can execute within, such as the CPU or WebGL. (A "backend" in this context does not mean a server-side environment - the backend for execution could be client side, in WebGL, for example.) Supporting multiple backends ensures compatibility and keeps things running fast. Currently TensorFlow.js supports:

  • WebGL execution on the device's graphics card (GPU) - this is the fastest way to execute larger models (over 3MB in size) with GPU acceleration.
  • Web Assembly (WASM) execution on CPU - to improve CPU performance across devices including older generation mobile phones for example. This is better suited to smaller models (less than 3MB in size) which can actually execute faster on CPU with WASM than with WebGL due to the overhead of uploading content to a graphics processor.
  • CPU execution - the fallback should none of the other environments be available. This is the slowest of the three but is always there for you.

Note: You can choose to force one of these backends if you know what device you will be executing on, or you can simply let TensorFlow.js decide for you if you do not specify this.
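
If you do want to force a particular backend yourself, the calls look roughly like this. This is a minimal sketch rather than part of the codelab's code, and it assumes the default bundle ('webgl' and 'cpu' are built in; 'wasm' requires the extra @tensorflow/tfjs-backend-wasm package to be loaded first):

// Sketch: explicitly pick a backend before using TensorFlow.js.
await tf.setBackend('webgl');   // or 'wasm' / 'cpu'
await tf.ready();               // resolves once the backend is initialized
console.log('Using backend: ' + tf.getBackend());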

Client side super powers

Running TensorFlow.js in the web browser on the client machine can lead to several benefits that are worth considering.

Privacy

You can both train and classify data on the client machine without ever sending data to a third party web server. At times this may be a requirement to comply with local laws, such as GDPR, or when processing data that the user wants to keep on their machine rather than send to a third party.

Speed

As you don't have to send data to a remote server, inference (the act of classifying the data) can be faster. Even better, you have direct access to the device's sensors such as the camera, microphone, GPS, accelerometer, and more, should the user grant you access.

Reach and scale

Anyone in the world can click a link you send them, open the web page in their browser, and use what you have made. No need for a complex server-side Linux setup with CUDA drivers and much more just to use the machine learning system.

Cost

No servers means the only thing you need to pay for is a CDN to host your HTML, CSS, JS, and model files. The cost of a CDN is much cheaper than keeping a server (potentially with a graphics card attached) running 24/7.

Server side features

Leveraging the Node.js implementation of TensorFlow.js enables the following features.

Full CUDA support

On the server side, you must install the NVIDIA CUDA drivers for graphics card acceleration to enable TensorFlow to work with the graphics card (unlike in the browser, which uses WebGL - no install needed). However, with full CUDA support you can fully leverage the graphics card's lower level abilities, leading to faster training and inference times. Performance is on par with the Python TensorFlow implementation as they both share the same C++ backend.
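
For reference, switching a Node.js project to the CUDA-accelerated binding is mostly a matter of which package you import. This is a rough sketch rather than part of this codelab (the package names are the official @tensorflow/tfjs-node variants, and CUDA plus cuDNN must already be installed for the GPU version to load):

// CPU-only Node.js binding:
// const tf = require('@tensorflow/tfjs-node');

// CUDA-accelerated Node.js binding (requires NVIDIA drivers, CUDA, and cuDNN):
const tf = require('@tensorflow/tfjs-node-gpu');

// The rest of your TensorFlow.js code stays the same regardless of binding.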

Model Size

For cutting edge models from research, you may be working with very large models, possibly gigabytes in size. These models cannot currently run in the web browser due to the limitations of memory usage per browser tab. To run these larger models, you can use Node.js on your own server with the hardware specifications required to run such a model efficiently.

IoT

Node.js is supported on popular single board computers like the Raspberry Pi, which in turn means you can execute TensorFlow.js models on such devices too.

Speed

Node.js executes JavaScript, which benefits from just-in-time (JIT) compilation. This means you may often see performance gains when using Node.js, as the code is optimized at runtime, especially for any preprocessing you may be doing. A great example of this can be seen in this case study, which shows how Hugging Face used Node.js to get a 2x performance boost for their natural language processing model.

Now that you understand the basics of TensorFlow.js, where it can run, and some of its benefits, let's start doing useful things with it!

3. Transfer learning

What exactly is transfer learning?

Transfer learning involves taking knowledge that has already been learned to help learn a different but similar thing.

We humans do this all the time. You have a lifetime of experiences contained in your brain that you can use to help recognize new things you have never seen before. Take this willow tree for example:

e28070392cd4afb9.png

Depending on where you are in the world there is a chance you may not have seen this type of tree before.

Yet if I ask you to tell me if there are any willow trees in the new image below, you can probably spot them pretty fast, even though they are at a different angle, and slightly different to the original one I showed you.

d9073a0d5df27222.png

You already have a bunch of neurons in your brain that know how to identify tree-like objects, and other neurons that are good at finding long straight lines. You can reuse that knowledge to quickly classify a willow tree, which is a tree-like object that has lots of long straight vertical branches.

Similarly, if you have a machine learning model that is already trained on a domain, such as recognizing images, you can re-use it to perform a different but related task.

You can do the same with an advanced model like MobileNet, a very popular research model that can perform image recognition on 1,000 different object types, from dogs to cars. It was trained on a huge dataset known as ImageNet, which has millions of labelled images.

In this animation, you can see the huge number of layers it has in this MobileNet V1 model:

7d4e1e35c1a89715.gif

During its training, this model learned how to extract common features that matter for all of those 1000 objects, and many of the lower level features it uses to identify such objects can be useful to detect new objects it has never seen before too. After all, everything is ultimately just a combination of lines, textures, and shapes.

Let's take a look at a traditional Convolutional Neural Network (CNN) architecture (similar to MobileNet) and see how transfer learning can leverage this trained network to learn something new. The image below shows the typical model architecture of a CNN that in this case was trained to recognize handwritten digits from 0 to 9:

baf4e3d434576106.png

If you could separate the pre-trained lower level layers of an existing trained model like the one shown on the left from the classification layers near the end of the model shown on the right (sometimes referred to as the classification head of the model), you could use the lower level layers to produce output features for any given image, based on the original data it was trained on. Here is the same network with the classification head removed:

369a8a9041c6917d.png

Assuming the new thing you are trying to recognize can also make use of such output features the prior model has learned, then there is a good chance they can be reused for a new purpose.

In the diagram above, this hypothetical model was trained on digits, so maybe what was learned about digits can also be applied to letters like a, b, and c.

So now you could add a new classification head that tries to predict a, b, or c instead, as shown:

db97e5e60ae73bbd.png

Here the lower level layers are frozen and are not trained; only the new classification head updates itself to learn from the features provided by the pre-trained, chopped-up model on the left.

The act of doing this is known as transfer learning and is what Teachable Machine does behind the scenes.

You can also see that by only having to train the multi-layer perceptron at the very end of the network, it trains much faster than if you had to train the whole network from scratch.

But how can you get your hands on sub-parts of a model? Head to the next section to find out.

4. TensorFlow Hub - base models

Find a suitable base model to use

For more advanced and popular research models like MobileNet, you can go to TensorFlow Hub and filter for models suitable for TensorFlow.js that use the MobileNet v3 architecture, to find results like the ones shown here:

c5dc1420c6238c14.png

Note that some of these results are of type "image classification" (detailed at the top left of each model card result), and others are of type "image feature vector."

These Image Feature Vector results are essentially the pre-chopped up versions of MobileNet that you can use to get the image feature vectors instead of the final classification.

Models like this are often called "base models," which you can then use to perform transfer learning in the same manner as shown in the prior section by adding a new classification head and training it with your own data.

The next thing to check, for a given base model of interest, is what TensorFlow.js format the model is released in. If you open the page for one of these feature vector MobileNet v3 models, you can see from the JS documentation that it's in the form of a graph model, as the example code snippet in the documentation uses tf.loadGraphModel().

f97d903d2e46924b.png

It should also be noted that if you find a model in the Layers format instead of the graph format, you can choose which layers to freeze and which to unfreeze during training. This can be very powerful when creating a model for a new task, which is often referred to as the "transfer model." For now, however, you'll use the default graph model type for this tutorial, which is how most TF Hub models are deployed. To learn more about working with Layers models, check out the zero to hero TensorFlow.js course.
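
To make that distinction concrete, here is a rough sketch of what freezing layers could look like if you were working with a Layers-format model instead. This is not needed for this codelab, and the model URL below is just a placeholder:

// Hypothetical sketch: load a Layers model and freeze all but the last few layers.
const layersModel = await tf.loadLayersModel('https://example.com/model.json');

// Keep only the final 3 layers trainable so the rest of the network stays frozen.
layersModel.layers.forEach(function (layer, i) {
  layer.trainable = i >= layersModel.layers.length - 3;
});

layersModel.compile({optimizer: 'adam', loss: 'categoricalCrossentropy'});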

Advantages of transfer learning

What are the advantages of using transfer learning instead of training the whole model architecture from scratch?

First, training time is a key advantage of the transfer learning approach, as you already have a trained base model to build upon.

Secondly, you can get away with showing far fewer examples of the new thing you are trying to classify due to the training that has already taken place.

This is really great if you have limited time and resources to gather example data of the thing you want to classify, and need to make a prototype quickly before gathering more training data to make it more robust.

Given the need for less data and the speed of training a smaller network, transfer learning is less resource intensive. That makes it very suitable for the browser environment, taking just tens of seconds on a modern machine instead of hours, days, or weeks for full model training.

Alright! Now that you know the essence of what transfer learning is, it's time to create your very own version of Teachable Machine. Let's get started!

5. Get set up to code

What you'll need

  • A modern web browser.
  • Basic knowledge of HTML, CSS, JavaScript, and Chrome DevTools (viewing the console output).

Let's get coding

Boilerplate templates to start from have been created for Glitch.com and Codepen.io. You can clone either template as your base state for this codelab with just one click.

On Glitch, click the "remix this" button to fork it and make a new set of files you can edit.

Alternatively, on Codepen, click "fork" in the bottom right of the screen.

This very simple skeleton provides you with the following files:

  • HTML page (index.html)
  • Stylesheet (style.css)
  • File to write our JavaScript code (script.js)

For your convenience, there is an added import in the HTML file for the TensorFlow.js library. It looks like this:

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

Alternative: Use your preferred web editor or work locally

If you want to download the code and work locally, or on a different online editor, simply create the 3 files named above in the same directory and copy and paste the code from our Glitch boilerplate into each of them.

6. App HTML boilerplate

Where do I start?

All prototypes require some basic HTML scaffolding you can render your findings on. Set that up now. You are going to add:

  • A title for the page.
  • Some descriptive text.
  • A status paragraph.
  • A video to hold the webcam feed once ready.
  • Several buttons to start the camera, collect data, or reset the experience.
  • Imports for TensorFlow.js and JS files you will code later.

Open index.html and paste over the existing code with the following to set up the above features:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

Break it down

Let's break some of the above HTML code down to highlight some key things you added.

  • You added an <h1> tag for the page title along with a <p> tag with an ID of 'status', which is where you will print status information as you use different parts of the system and view outputs.
  • You added a <video> element with an ID of 'webcam', to which you will render your webcam stream later.
  • You added 5 <button> elements. The first, with the ID of 'enableCam', enables the camera. The next two buttons have a class of 'dataCollector', which lets you gather example images for the objects you want to recognize. The code you write later will be designed so that you can add any number of these buttons and they will work as intended automatically.

Note that these buttons also have a special user-defined attribute called data-1hot, with an integer value starting from 0 for the first class. This is the numerical index you'll use to represent a certain class's data. The index will be used to encode the output classes correctly with a numerical representation instead of a string, as ML models can only work with numbers.

There is also a data-name attribute that contains the human readable name you want to use for this class, which lets you provide a more meaningful name to the user instead of a numerical index value from the 1 hot encoding.
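
To see what that data-1hot index turns into later, here is a tiny standalone sketch of one hot encoding. The codelab does this for real with tf.oneHot() in the training step, so this snippet is purely illustrative:

// With 3 classes, the index 1 (the second class) becomes the vector [0, 1, 0].
const classIndex = tf.tensor1d([1], 'int32');
tf.oneHot(classIndex, 3).print();  // prints [[0, 1, 0]]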

Finally, you have a train and reset button to kick off the training process once data has been collected, or to reset the app respectively.

  • You also added 2 <script> imports. One for TensorFlow.js, and the other for script.js that you will define shortly.

7. Add style

Element defaults

Add styles for the HTML elements you just added to ensure they render correctly. Here are some styles to position and size elements correctly. Nothing too special. You could certainly add to this later to make an even better UX, like you saw in the Teachable Machine video.

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

Great! That's all you need. If you preview the output right now, it should look something like this:

81909685d7566dcb.png

8. JavaScript: Key constants and listeners

Define key constants

First, add some key constants you'll use throughout the app. Start by replacing the contents of script.js with these constants:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

Let's break down what these are for:

  • STATUS simply holds a reference to the paragraph tag you will write status updates to.
  • VIDEO holds a reference to the HTML video element that will render the webcam feed.
  • ENABLE_CAM_BUTTON, RESET_BUTTON, and TRAIN_BUTTON grab DOM references to all the key buttons from the HTML page.
  • MOBILE_NET_INPUT_WIDTH and MOBILE_NET_INPUT_HEIGHT define the expected input width and height of the MobileNet model respectively. By storing these in constants near the top of the file like this, if you decide to use a different version later, it is easier to update the values once instead of having to replace them in many different places.
  • STOP_DATA_GATHER is set to -1. This stores a state value so you know when the user has stopped clicking a button to gather data from the webcam feed. By giving this number a more meaningful name, it makes the code more readable later.
  • CLASS_NAMES acts as a lookup and holds the human readable names for the possible class predictions. This array will be populated later.

OK, now that you have references to key elements, it is time to associate some event listeners to them.

Add key event listeners

Start by adding click event handlers to key buttons as shown:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - calls the enableCam function when clicked.

TRAIN_BUTTON - calls trainAndPredict when clicked.

RESET_BUTTON - calls reset when clicked.

Finally, in this section you can find all buttons that have a class of 'dataCollector' using document.querySelectorAll(). This returns a NodeList of matching elements found in the document:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

Code explanation:

You iterate through the found buttons and associate two event listeners with each: one for 'mousedown' and one for 'mouseup'. This lets you keep recording samples for as long as a button is pressed, which is useful for data collection.

Both events call a gatherDataForClass function that you will define later.

At this point, you can also push the found human readable class names from the HTML button attribute data-name to the CLASS_NAMES array.

Next, add some variables to store key things that will be used later.

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

Let's walk through those.

First, you have a variable mobilenet to store the loaded mobilenet model. Initially set this to undefined.

Next, you have a variable called gatherDataState. If a ‘dataCollector' button is pressed, this changes to be the 1 hot ID of that button instead, as defined in the HTML, so you know what class of data you are collecting at that moment. Initially, this is set to STOP_DATA_GATHER so that the data gather loop you write later will not gather any data when no buttons are being pressed.

videoPlaying keeps track of whether the webcam stream is successfully loaded and playing and is available to use. Initially, this is set to false as the webcam is not on until you press the ENABLE_CAM_BUTTON.

Next, define 2 arrays, trainingDataInputs and trainingDataOutputs. These store the training data gathered as you click the 'dataCollector' buttons: the input features generated by the MobileNet base model, and the output class sampled, respectively.

One final array, examplesCount, is then defined to keep track of how many examples are contained for each class once you start adding them.

Finally, you have a variable called predict that controls your prediction loop. This is set to false initially. No predictions can take place until this is set to true later.

Now that all the key variables have been defined, let's go and load the pre-chopped up MobileNet v3 base model that provides image feature vectors instead of classifications.

9. Load the MobileNet base model

First, define a new function called loadMobileNetFeatureModel as shown below. This must be an async function as the act of loading a model is asynchronous:

script.js

/**
 * Loads the MobileNet model and warms it up so it's ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

In this code you define the URL of the model to load, taken from the TFHub documentation.

You can then load the model using await tf.loadGraphModel(), remembering to set the special property fromTFHub to true as you are loading a model from this Google website. This is a special case only for using models hosted on TF Hub where this extra property has to be set.

Once loading is complete you can set the STATUS element's innerText with a message so you can visually see it has loaded correctly and you are ready to start gathering data.

The only thing left to do now is to warm up the model. With larger models like this, the first time you use the model, it can take a moment to set everything up. Therefore it helps to pass zeros through the model to avoid any waiting in the future where timing may be more critical.

You can use tf.zeros() wrapped in a tf.tidy() to ensure tensors are disposed of correctly, with a batch size of 1, and the correct height and width that you defined in your constants at the start. Finally, you also specify the color channels, which in this case is 3 as the model expects RGB images.

Next, log the resulting shape of the returned tensor using answer.shape to help you understand the size of the image features this model produces.

After defining this function, you can then call it immediately to initiate the model download on the page load.

If you view your live preview right now, after a few moments you will see the status text change from "Awaiting TF.js load" to become "MobileNet v3 loaded successfully!" as shown below. Ensure this works before continuing.

a28b734e190afff.png

You can also check the console output to see the printed size of the output features that this model produces. After running zeros through the MobileNet model, you'll see a shape of [1, 1024] printed. The first item is just the batch size of 1, and you can see it actually returns 1024 features that can then be used to help you classify new objects.

10. Define the new model head

Now it's time to define your model head, which is essentially a very minimal multi-layer perceptron.

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

Let's walk through this code. You start by defining a tf.sequential model to which you will add model layers.

Next, add a dense layer as the input layer to this model. This has an input shape of 1024, as the outputs from the MobileNet v3 features are of this size. You discovered this in the previous step after passing zeros through the model. This layer has 128 neurons that use the ReLU activation function.

If you are new to activation functions and model layers, consider taking the course detailed at the start of this workshop to understand what these properties do behind the scenes.

The next layer to add is the output layer. The number of neurons should equal the number of classes you are trying to predict. To do that you can use CLASS_NAMES.length to find how many classes you are planning to classify, which is equal to the number of data gather buttons found in the user interface. As this is a classification problem, you use the softmax activation on this output layer, which must be used when trying to create a model to solve classification problems instead of regression.

Now call model.summary() to print an overview of the newly defined model to the console.

Finally, compile the model so it is ready to be trained. Here the optimizer is set to adam, and the loss will either be binaryCrossentropy if CLASS_NAMES.length is equal to 2, or it will use categoricalCrossentropy if there are 3 or more classes to classify. Accuracy metrics are also requested so they can be monitored in the logs later for debugging purposes.

In the console you should see something like this:

22eaf32286fea4bb.png

Note that this has over 130 thousand trainable parameters. But as this is a simple dense layer of regular neurons it will train pretty fast.

As an activity to do once the project is complete, you could try changing the number of neurons in the first layer to see how low you can make it while still getting decent performance. Often with machine learning there is some level of trial and error to find optimal parameter values to give you the best trade off between resource usage and speed.
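
As a starting point for that experiment, the only change needed is the units value of the first dense layer. This is just a sketch of the kind of variation to try; 32 is an arbitrary example value, not a recommendation:

// Experiment: a smaller hidden layer to compare training speed and accuracy.
model.add(tf.layers.dense({inputShape: [1024], units: 32, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));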

11. Enable the webcam

It's now time to flesh out the enableCam() function you defined earlier. Add a new function named hasGetUserMedia() as shown below and then replace the contents of the previously defined enableCam() function with the corresponding code below.

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUserMedia constraints: request a video-only stream at 640x480.
    const constraints = {
      video: {
        width: 640,
        height: 480
      }
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

First, create a function named hasGetUserMedia() to check if the browser supports getUserMedia() by checking for the existence of the key browser API properties.

In the enableCam() function use the hasGetUserMedia() function you just defined above to check if it is supported. If it isn't, print a warning to the console.

If it is supported, define some constraints for your getUserMedia() call, such as that you only want the video stream, and that you prefer the video width to be 640 pixels and the height 480 pixels. Why? Well, there is not much point getting a video larger than this, as it would need to be resized to 224 by 224 pixels to be fed into the MobileNet model. You may as well save some computing resources by requesting a smaller resolution. Most cameras support a resolution of this size.

Next, call navigator.mediaDevices.getUserMedia() with the constraints detailed above, and then wait for the stream to be returned. Once the stream is returned you can get your VIDEO element to play the stream by setting it as its srcObject value.

You should also add an event listener to the VIDEO element to know when the stream has loaded and is playing successfully.

Once the stream loads, you can set videoPlaying to true and hide the ENABLE_CAM_BUTTON, preventing it from being clicked again, by adding the "removed" class to it.

Now run your code, click the enable camera button, and allow access to the webcam. If it is your first time doing this, you should see yourself rendered to the video element on the page as shown:

b378eb1affa9b883.png

Ok, now it is time to add a function to deal with the dataCollector button clicks.

12. Data collection button event handler

Now it's time to fill out your currently empty function called gatherDataForClass(). This is what you assigned as your event handler function for dataCollector buttons at the start of the codelab.

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

First, read the data-1hot attribute of the currently clicked button by calling this.getAttribute() with the attribute's name, in this case data-1hot, as the parameter. As this is returned as a string, use parseInt() to cast it to an integer and assign the result to a variable named classNumber.

Next, set the gatherDataState variable accordingly. If the current gatherDataState is equal to STOP_DATA_GATHER (which you set to be -1), then that means you are not currently gathering any data and it was a mousedown event that fired. Set the gatherDataState to be the classNumber you just found.

Otherwise, it means that you are currently gathering data and the event that fired was a mouseup event, and you now want to stop gathering data for that class. Just set it back to the STOP_DATA_GATHER state to end the data gathering loop you will define shortly.

Finally, kick off the call to dataGatherLoop(), which actually performs the recording of class data.

13. Data collection

Now, define the dataGatherLoop() function. This function is responsible for sampling images from the webcam video, passing them through the MobileNet model, and capturing the outputs of that model (the 1024 feature vectors).

It then stores them along with the gatherDataState ID of the button that is currently being pressed so you know what class this data represents.

Let's walk through it:

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Initialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

You are only going to continue this function's execution if videoPlaying is true, meaning that the webcam is active, and gatherDataState is not equal to STOP_DATA_GATHER, meaning a button for class data gathering is currently being pressed.

Next, wrap your code in a tf.tidy() to dispose of any created tensors in the code that follows. The result of this tf.tidy() code execution is stored in a variable called imageFeatures.

You can now grab a frame of the webcam VIDEO using tf.browser.fromPixels(). The resulting tensor containing the image data is stored in a variable called videoFrameAsTensor.

Next, resize the videoFrameAsTensor variable to be of the correct shape for the MobileNet model's input. Use a tf.image.resizeBilinear() call with the tensor you want to reshape as the first parameter, and then a shape that defines the new height and width as defined by the constants you already created earlier. Finally, set align corners to true by passing the third parameter to avoid any alignment issues when resizing. The result of this resize is stored in a variable called resizedTensorFrame.

Note that this primitive resize stretches the image, as your webcam image is 640 by 480 pixels in size, and the model needs a square image of 224 by 224 pixels.

For the purposes of this demo this should work fine. However, once you complete this codelab, you may want to try and crop a square from this image instead for even better results for any production system you may create later.
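
If you do want to try that cropping approach later, a rough sketch could look like the following. It slices a centered 480 by 480 square out of the 640 by 480 frame before resizing, and is not part of this codelab's code:

// Sketch: center-crop a square from the 640x480 webcam frame, then resize it.
// Assumes videoFrameAsTensor has shape [480, 640, 3] (height, width, channels).
let squareCrop = tf.slice(videoFrameAsTensor, [0, 80, 0], [480, 480, 3]);
let resizedSquare = tf.image.resizeBilinear(
    squareCrop, [MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH], true);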

Next, normalize the image data. Image data is always in the range of 0 to 255 when using tf.browser.fromPixels(), so you can simply divide resizedTensorFrame by 255 to ensure all values are between 0 and 1 instead, which is what the MobileNet model expects as inputs.

Finally, in the tf.tidy() section of the code, push this normalized tensor through the loaded model by calling mobilenet.predict(), to which you pass the expanded version of the normalizedTensorFrame using expandDims() so that it is a batch of 1, as the model expects a batch of inputs for processing.

Once the result comes back, you can then immediately call squeeze() on that returned result to squash it back down to a 1D tensor, which you then return and assign to the imageFeatures variable that captures the result from tf.tidy().

Now that you have the imageFeatures from the MobileNet model, you can record those by pushing them onto the trainingDataInputs array that you defined previously.

You can also record what this input represents by pushing the current gatherDataState to the trainingDataOutputs array too.

Note that the gatherDataState variable would have been set to the current class's numerical ID you are recording data for when the button was clicked in the previously defined gatherDataForClass() function.

At this point you can also increment the number of examples you have for a given class. To do this, first check if the index within the examplesCount array has been initialized before or not. If it is undefined, set it to 0 to initialize the counter for a given class's numerical ID, and then you can increment the examplesCount for the current gatherDataState.

Now update the STATUS element's text on the web page to show the current counts for each class as they're captured. To do this, loop through the CLASS_NAMES array, and print the human readable name combined with the data count at the same index in examplesCount.

Finally, call window.requestAnimationFrame() with dataGatherLoop passed as a parameter, to recursively call this function again. This will continue to sample frames from the video until the button's mouseup is detected, and gatherDataState is set to STOP_DATA_GATHER, at which point the data gather loop will end.

If you run your code now, you should be able to click the enable camera button, await the webcam to load, and then click and hold each of the data gather buttons to gather examples for each class of data. Here you see me gather data for my mobile phone and my hand respectively.

541051644a45131f.gif

You should see the status text updated as it stores all the tensors in memory as shown in the screen capture above.

14. Train and predict

The next step is to implement code for your currently empty trainAndPredict() function, which is where the transfer learning takes place. Let's take a look at the code:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

First, ensure you stop any current predictions from taking place by setting predict to false.

Next, shuffle your input and output arrays using tf.util.shuffleCombo() to ensure the order does not cause issues in training.

Convert your output array, trainingDataOutputs, to be a tensor1d of type int32 so it is ready to be used in a one hot encoding. This is stored in a variable named outputsAsTensor.

Use the tf.oneHot() function with this outputsAsTensor variable along with the max number of classes to encode, which is just the CLASS_NAMES.length. Your one hot encoded outputs are now stored in a new tensor called oneHotOutputs.

Note that currently trainingDataInputs is an array of recorded tensors. In order to use these for training you will need to convert the array of tensors to become a regular 2D tensor.

To do that there is a handy function within the TensorFlow.js library called tf.stack(), which takes an array of tensors and stacks them together to produce a higher dimensional tensor as output. In this case a 2D tensor is returned: a batch of 1-dimensional inputs, each 1024 features in length, containing the recorded features, which is what you need for training.

Next, await model.fit() to train the custom model head. Here you pass your inputsAsTensor variable along with the oneHotOutputs to represent the training data to use for example inputs and target outputs respectively. In the configuration object for the 3rd parameter, set shuffle to true, use batchSize of 5, with epochs set to 10, and then specify a callback for onEpochEnd to the logProgress function that you will define shortly.

Finally, you can dispose of the created tensors as the model is now trained. You can then set predict back to true to allow predictions to take place again, and then call the predictLoop() function to start predicting live webcam images.

You can also define the logProgress() function to log the state of training, which is used by model.fit() above and prints results to the console after each epoch of training.

You're almost there! Time to add the predictLoop() function to make predictions.

Core prediction loop

Here you implement the main prediction loop that samples frames from a webcam and continuously predicts what is in each frame with real time results in the browser.

Let's check the code:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

First, check that predict is true, so that predictions are only made after a model is trained and is available to use.

Next, you can get the image features for the current image just like you did in the dataGatherLoop() function. Essentially, you grab a frame from the webcam using tf.browser.fromPixels(), normalize it, resize it to be 224 by 224 pixels in size, and then pass that data through the MobileNet model to get the resulting image features.

Now, however, you can use your newly trained model head to actually perform a prediction by passing the resulting imageFeatures you just found through the trained model's predict() function. You can then squeeze the resulting tensor to make it 1 dimensional again and assign it to a variable called prediction.

With this prediction you can find the index that has the highest value using argMax(), and then convert this resulting tensor to an array using arraySync() to get at the underlying data in JavaScript to discover the position of the highest valued element. This value is stored in the variable called highestIndex.

You can also get the actual prediction confidence scores in the same way by calling arraySync() on the prediction tensor directly.

You now have everything you need to update the STATUS text with the prediction data. To get the human readable string for the class, you can just look up the highestIndex in the CLASS_NAMES array, and then grab the confidence value from the predictionArray. To make it more readable as a percentage, just multiply by 100 and Math.floor() the result.

Finally, you can use window.requestAnimationFrame() to call predictLoop() again once ready, to get real time classification of your video stream. This continues until predict is set to false, which happens if you choose to train a new model with new data.

Which brings you to the final piece of the puzzle: implementing the reset button.

15. Implement the reset button

Almost complete! The final piece of the puzzle is to implement a reset button to start over. The code for your currently empty reset() function is below. Go ahead and update it as follows:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

First, stop any running prediction loops by setting predict to false. Next, delete all contents in the examplesCount array by setting its length to 0, which is a handy way to clear all contents from an array.

Now go through all the current recorded trainingDataInputs and ensure you dispose() of each tensor contained within it to free up memory again, as Tensors are not cleaned up by the JavaScript garbage collector.

Once that is done you can now safely set the array length to 0 on both the trainingDataInputs and trainingDataOutputs arrays to clear those too.

Finally set the STATUS text to something sensible, and print out the tensors left in memory as a sanity check.

Note that there will be a few hundred tensors still in memory as both the MobileNet model and the multi-layer perceptron you defined are not disposed of. You will need to reuse them with new training data if you decide to train again after this reset.

16. Let's try it out

It's time to test out your very own version of Teachable Machine!

Head to the live preview, enable the webcam, gather at least 30 samples for class 1 for some object in your room, and then do the same for class 2 for a different object, click train, and check the console log to see progress. It should train pretty fast:

bf1ac3cc5b15740.gif

Once trained, show the objects to the camera to get live predictions that will be printed to the status text area on the web page near the top. If you are having trouble, check my completed working code to see if you missed copying over anything.

17. Congratulations

Congratulations! You have just completed your very first transfer learning example using TensorFlow.js live in the browser.

Try it out and test it on a variety of objects. You may notice some things are harder to recognize than others, especially if they are similar to something else. You may need to add more classes or training data to be able to tell them apart.

Recap

In this codelab you learned:

  1. What transfer learning is, and its advantages over training a full model.
  2. How to get models for re-use from TensorFlow Hub.
  3. How to set up a web app suitable for transfer learning.
  4. How to load and use a base model to generate image features.
  5. How to train a new prediction head that can recognize custom objects from webcam imagery.
  6. How to use the resulting models to classify data in real time.

What's next?

Now that you have a working base to start from, what creative ideas can you come up with to extend this machine learning model boilerplate for a real world use case you may be working on? Maybe you could revolutionize the industry that you currently work in by helping folks at your company train models to classify things that are important in their day-to-day work? The possibilities are endless.

To go further, consider taking this full course for free, which shows you how to combine the 2 models you currently have in this codelab into a single model for efficiency.

Also, if you are curious about the theory behind the original Teachable Machine application, check out this tutorial.

Share what you make with us

You can easily extend what you made today for other creative use cases too and we encourage you to think outside the box and keep hacking.

Remember to tag us on social media using the #MadeWithTFJS hashtag for a chance for your project to be featured on our TensorFlow blog or even future events. We would love to see what you make.

Websites to check out