TensorFlow.js 迁移学习图片分类器

在此 Codelab 中,您将学习如何构建一个简单的“会学习的机器”,即一个可使用 TensorFlow.js(一款适用于 JavaScript 的功能强大且灵活的机器学习库)在浏览器中进行实时训练的分类器。首先,您将加载并运行一个名为 MobileNet 的常用预训练模型,以用于在浏览器中进行图片分类。然后,您将使用一项名为“迁移学习”的技术,该技术使用预训练的 MobileNet 模型对我们的训练进行 Bootstrap 处理,并自定义该模型以对您的应用进行训练。

此 Codelab 涵盖超出会学习的机器应用以外的理论。如果您想了解相关信息,请查看此教程

要学习的内容

  • 如何加载预训练的 MobileNet 模型并利用新数据进行预测
  • 如何通过网络摄像头进行预测
  • 如何使用 MobileNet 的中间激活功能,使用网络摄像头即时为您定义的一组新类别执行迁移学习

下面我们开始步入正题!

要完成本 Codelab,您需要:

  1. 最新版本的 Chrome 或其他现代浏览器。
  2. 文本编辑器(可以是在机器中本地运行的,或者通过 CodepenGlitch 等在网络上运行)。
  3. 了解 HTML、CSS、JavaScript 和 Chrome 开发者工具(或您的首选浏览器开发者工具)。
  4. 大致了解神经网络的概念。如果您需要了解简介或回顾相关内容,请观看这部由 3blue1brown 制作的视频,或这部由 Ashi Krishnan 制作的有关基于 JavaScript 的深度学习的视频

在编辑器中打开 index.html,然后添加以下内容:

<html>
  <head>
    <!-- Load the latest version of TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
  </head>
  <body>
    <div id="console"></div>
    <!-- Add an image that we will use to test -->
    <img id="img" crossorigin src="https://i.imgur.com/JlUvsxa.jpg" width="227" height="227"/>
    <!-- Load index.js after the content of the page -->
    <script src="index.js"></script>
  </body>
</html>

接下来,在代码编辑器中打开/创建文件 index.js,并添加以下代码:

let net;

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Make a prediction through the model on our image.
  const imgEl = document.getElementById('img');
  const result = await net.classify(imgEl);
  console.log(result);
}

app();

要运行网页,只需在网络浏览器中打开 index.html 即可。如果您使用的是 Cloud Console,只需刷新预览页面。

您应该会在开发者工具的 JavaScript 控制台中看到一张狗狗的图片,这是 MobileNet 预测的最有可能的内容。请注意,下载模型可能需要一些时间,请耐心等待。

图片的分类是否正确?

另外值得注意的是,此模型在手机上也可以使用。

现在,我们来将这一过程变得互动性和实时性更强一点。设置摄像头,对通过网络摄像头拍摄的图片进行预测。

首先,设置网络摄像头视频元素。打开 index.html 文件,将以下代码行添加到 <body> 部分内,然后删除我们用于加载狗狗图片的 <img> 标记:

<video autoplay playsinline muted id="webcam" width="224" height="224"></video>

打开 index.js 文件,然后将 camwareElement 添加到文件最顶部

const webcamElement = document.getElementById('webcam');

现在,在之前添加的 app() 函数中,您可以移除通过图片获得的预测结果,改为创建一个通过网络摄像头元素进行预测的无限循环。

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);
  while (true) {
    const img = await webcam.capture();
    const result = await net.classify(img);

    document.getElementById('console').innerText = `
      prediction: ${result[0].className}\n
      probability: ${result[0].probability}
    `;
    // Dispose the tensor to release the memory.
    img.dispose();

    // Give some breathing room by waiting for the next animation frame to
    // fire.
    await tf.nextFrame();
  }
}

如果您使用网页打开了控制台,现在应该会看到关于摄像头采集到的每个帧是什么的概率的 MobileNet 预测。

这些预测可能是无意义的,因为 ImageNet 数据集中的内容与网络摄像头通常会捕捉的图片并不相似。要进行测试,一种方法是在手机上显示一张狗狗图片,置于笔记本电脑摄像头前方。

现在,让我们这个模型变得更有用。我们来制作一个即时使用网络摄像头的自定义 3 个类别的对象分类器。我们将通过 MobileNet 进行分类,但这次我们将对特定摄像头图片的模型进行内部表示(激活),并用其进行分类。

我们将使用一个名为“K-最近邻 (KNN) 分类器”的模块,它能有效地将网络摄像头图片(实际上是他们的 MobileNet 激活)归到不同的类别,当用户要求进行预测时,只需选择拥有要与为其进行预测的图片最相似的激活的类别。

index.html 的 <head> 标记导入项末尾处添加 KNN 分类器的导入项(您仍需需要 MobileNet,因此请勿移除该导入项):

...
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
...

为视频元素下的 index.html 中的每个按钮添加 3 个按钮。这些按钮将用于向模型中添加训练图片。

...
<button id="class-a">Add A</button>
<button id="class-b">Add B</button>
<button id="class-c">Add C</button>
...

index.js 的顶部创建分类器:

const classifier = knnClassifier.create();

更新应用函数:

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);

  // Reads an image from the webcam and associates it with a specific class
  // index.
  const addExample = async classId => {
    // Capture an image from the web camera.
    const img = await webcam.capture();

    // Get the intermediate activation of MobileNet 'conv_preds' and pass that
    // to the KNN classifier.
    const activation = net.infer(img, true);

    // Pass the intermediate activation to the classifier.
    classifier.addExample(activation, classId);

    // Dispose the tensor to release the memory.
    img.dispose();
  };

  // When clicking a button, add an example for that class.
  document.getElementById('class-a').addEventListener('click', () => addExample(0));
  document.getElementById('class-b').addEventListener('click', () => addExample(1));
  document.getElementById('class-c').addEventListener('click', () => addExample(2));

  while (true) {
    if (classifier.getNumClasses() > 0) {
      const img = await webcam.capture();

      // Get the activation from mobilenet from the webcam.
      const activation = net.infer(img, 'conv_preds');
      // Get the most likely class and confidence from the classifier module.
      const result = await classifier.predictClass(activation);

      const classes = ['A', 'B', 'C'];
      document.getElementById('console').innerText = `
        prediction: ${classes[result.label]}\n
        probability: ${result.confidences[result.label]}
      `;

      // Dispose the tensor to release the memory.
      img.dispose();
    }

    await tf.nextFrame();
  }
}

现在,当您加载 index.html 页面时,可以使用常见物体或面部/身体手势为这三个类别捕获图片。您每次点击某一个“添加”按钮时,都会将一张图片作为一个示例添加到该类别中。执行此操作时,模型会针对传入的摄像头图片持续进行预测,并实时显示结果。

现在,试试添加另一个不代表任何操作的类别。

在此 Codelab 中,您使用 TensorFlow.js 实现了一个简单的机器学习 Web 应用。您加载并使用了一个预训练的 MobileNet 模型,用于对来自网络摄像头的图片进行分类。然后您对此模型进行了自定义,以将图片分类到三个自定义类别。

请一定记得访问 js.tensorflow.org,以查看更多附有代码的示例和演示,了解您可以如何在自己的应用中使用 TensorFlow.js。