TensorFlow.js:使用 TensorFlow.js 通过迁移学习制作您自己的“会学习的机器”

1. 准备工作

TensorFlow.js 模型使用量在过去几年内呈指数级增长,许多 JavaScript 开发者现在都希望采用现有的先进模型,重新训练这些模型来处理行业。取现有模型(通常称为基本模型)并将其用于类似但不同的领域称为迁移学习。

与从完全空白的模型开始相比,迁移学习有许多优势。您可以重复使用从之前训练的模型中获得的知识,并且需要减少要分类的新项的示例。此外,由于只需重新训练模型架构的最后几层而不是整个网络,因此训练速度通常要快得多。因此,迁移学习非常适合网络浏览器环境,在环境中,资源可能会因执行设备而异,并且可以直接访问传感器以轻松获取数据。

此 Codelab 将向您介绍如何在空白画布上构建 Web 应用,并重新创建 Google 的热门“会学习的机器”网站。通过该网站,您可以创建可正常运行的 Web 应用,任何用户只需使用来自摄像头的几个示例图片即可识别某个自定义对象。我们对网站进行了有针对性的精简,以便您可以专注于此 Codelab 的机器学习方面。不过,与最初的 Teachable Machine 网站一样,您可以通过很多方式运用现有的 Web 开发者体验来提升用户体验。

前提条件

此 Codelab 的编写者是对 TensorFlow.js 预制模型和基本 API 用法比较熟悉并且希望开始使用 TensorFlow.js 的迁移学习的 Web 开发者。

  • 本实验假定您已基本熟悉 TensorFlow.js、HTML5、CSS 和 JavaScript。

如果您刚开始接触 Tensorflow.js,请先考虑学习这款免费的主工具课程,该课程假定您未掌握机器学习或 TensorFlow.js 的背景信息,并通过较小的步骤向您介绍需要了解的所有内容。

学习内容

  • 什么是 TensorFlow.js,以及为什么您可以在下一个 Web 应用中使用 TensorFlow.js。
  • 如何构建复制了 Teachable Machine 用户体验的简化 HTML/CSS /JS 网页。
  • 如何使用 TensorFlow.js 加载预训练的基本模型(特别是 MobileNet),以生成可用于迁移学习的图像特征。
  • 如何从用户的摄像头为要识别的多类数据收集数据。
  • 如何创建和定义多层感知,利用图像特征并学习使用新特征对其进行分类。

让我们开始吧...

所需条件

  • 您最好关注 Glitch.com 帐号,您也可以使用自己能轻松编辑和运行的网络服务环境。

2. 什么是 TensorFlow.js?

54e81d02971f53e8.png

TensorFlow.js 是一个开源机器学习库,可运行在 JavaScript 可以运行的任何位置。它基于使用 Python 编写的原始 TensorFlow 库,旨在为 JavaScript 生态系统打造这种开发者体验和一系列 API。

可以在哪里使用?

鉴于 JavaScript 的可移植性,您现在可以使用 1 种语言编写,并在以下所有平台上轻松执行机器学习:

  • 使用原版 JavaScript 的网络浏览器中的客户端
  • 使用 Node.js 进行服务器端,甚至是 Raspberry Pi 等 IoT 设备
  • 使用 Electron 的桌面应用
  • 使用 React Native 的原生移动应用

TensorFlow.js 还支持在每个环境中使用多个后端(例如,可以在 CPU 或 WebGL 中执行的实际基于硬件的环境)。此处的“后端”并不意味着服务器端环境(例如,执行后端可以是 WebGL 中的客户端),以确保兼容性并保持运行速度。目前,TensorFlow.js 支持:

  • 在设备的显卡上执行 WebGL - 这是使用 GPU 加速执行大型模型(大小超过 3MB)的最快方法。
  • Web Assembly (WASM) 在 CPU 上执行 - 旨在改进各种设备的 CPU 性能,例如旧款手机。这更适合较小模型(小于 3MB),在具有 WASM 的 CPU 上,由于将内容上传到图形处理器的开销,其实际运行速度比使用 WebGL 更快。
  • CPU 执行 - 在其他环境均不可用的情况下进行回退。这是三者中最慢的一个,但将始终存在。

注意:如果您知道要在哪个设备上执行,可以选择强制执行这些后端之一,也可以直接让 TensorFlow.js 来决定未指定。

客户端超能力

在客户端计算机上的网络浏览器中运行 TensorFlow.js 有几个值得考虑的优势。

隐私权

您可以在客户端计算机上训练和分类数据,而无需将数据发送到第三方网络服务器。有时,根据这些规定,您可能需要遵守 GDPR 等当地法律,或者在处理用户可能想要保留但不希望将其发送给第三方的任何数据时。

速度

由于您不必向远程服务器发送数据,因此推断速度(数据分类操作)速度更快。更棒的是,如果用户授予访问权限,您可以直接访问设备的传感器,例如相机、麦克风、GPS、加速度计等。

扩大覆盖面并扩大覆盖面

只需点击一下,世界上的任何用户都可以点击您向其发送的链接,在其浏览器中打开网页,并利用您创建的内容。您无需使用 CUDA 驱动程序完成复杂的服务器端 Linux 设置,只需要使用机器学习系统即可。

费用

没有服务器意味着您只需支付 CDN 即可托管 HTML、CSS、JS 和模型文件。与保持服务器全天候运行(可能安装有显卡)相比,CDN 的费用要低得多。

服务器端功能

利用 TensorFlow.js 的 Node.js 实现可以启用以下功能。

全面的 CUDA 支持

对于服务器端的图形卡加速,您必须安装 NVIDIA CUDA 驱动程序,才能让 TensorFlow 与显卡配合使用(这与使用 WebGL 的浏览器不同 - 无需安装)。不过,借助全面的 CUDA 支持,您可以充分利用显卡的较低级别功能,从而加快训练和推断速度。性能与 Python TensorFlow 实现等效,因为它们都共用同一个 C++ 后端。

模型大小

对于研究中的先进模型,您使用的可能是超大模型,可能是 GB。由于每个浏览器标签页的内存用量限制,这些模型目前无法在网络浏览器中运行。如需运行这些更大的模型,您可以在自己的服务器上使用 Node.js,并高效地运行此类模型所需的硬件规格。

广告订单

Raspberry Pi 等常用的单板计算机也支持 Node.js,这也意味着您也可以在此类设备上执行 TensorFlow.js 模型。

速度

Node.js 是使用 JavaScript 编写的,这意味着,它仅在即时编译时就能受益。这意味着,使用 Node.js 时,其性能通常会提高,因为运行时会进行优化,尤其是在执行任何预处理时。一个很好的例子是这篇案例研究,其中展示了 Hugging Face 如何利用 Node.js 将自然语言处理模型的性能提升了 2 倍。

现在,您已经了解了 TensorFlow.js 的基础知识、运行位置以及它的一些优势,下面让我们开始对它进行实用的操作吧!

3.迁移学习

到底什么是迁移学习?

迁移学习包括获取已学过的知识,以帮助学习不同但相似的事物。

我们一直都在这样做。您的大脑中有一生经验,可用于帮助识别之前未见过的新内容。以柳树为例:

e28070392cd4afb9.png

根据您之前身处的世界,您可能以前从未见过这种树。

不过,如果你们想告诉我新图片中是否有柳树,你就可以快速确定它们的角度,即使它们的角度不同,并且与我向你显示的原始树略有不同。

d9073a0d5df27222.png

您的大脑中已经有了多个能够识别树状物体的神经元,以及那些擅长寻找长直线的神经元。您可以重复使用这些知识来快速对柳树进行分类,这种树是一种具有许多长直直分支的树状对象。

同样,如果您有已经在某个网域上训练过的机器学习模型(例如识别图片),则可以重复使用该模型以执行其他但相关的任务。

您还可以使用 MobileNet 之类的高级模型来达到同样的目的。MobileNet 是一种非常流行的研究模型,可以对 1000 种不同的对象类型执行图片识别。从狗狗到汽车,它在名为 ImageNet 的大型数据集上训练,该数据集包含数百万张已加标签的图片。

在此动画中,您可以看到此 MobileNet V1 模型中的大量层:

7d4e1e35c1a89715.gif

在训练期间,此模型学习了如何提取对所有这 1000 个对象都很重要的常用特征,并且它用于识别此类对象的许多较低级别的特征也可用于检测之前从未见过的新对象。毕竟,一切都只是线条、纹理和形状的组合。

我们来看看传统的卷积神经网络 (CNN) 架构(类似于 MobileNet)以及迁移学习如何利用此训练的网络学习新知识。下图显示了 CNN 的典型模型架构,在本例中,CNN 已经过训练可识别 0 到 9 的手写数字:

baf4e3d434576106.png

如果您可以将如下所示的现有训练模型的预训练低层层与右边所示模型末端附近的分类层(有时称为模型的分类头部)分开,您可以使用较低级别的层,根据训练时所用的原始数据为任何给定图像生成输出特征。以下是移除了分类头部的同一网络:

369a8a9041c6917d.png

假设您尝试识别的新事物也可以利用之前模型学到的此类输出特征,那么很有可能重复这些特征以用于新的目的。

在上图中,此假设模型是用数字训练的,因此也许有关数字的信息也可以应用于 a、b 和 c 等字母。

因此,您现在可以添加一个尝试预测 a、b 或 c 的新分类头,如下所示:

db97e5e60ae73bbd.png

在这里,较低级别的层被冻结且未训练,只有新的分类头会自行更新,以从左侧的预训练断层模型提供的特征中学习。

这种操作称为“迁移学习”,Teachable Machine 在后台进行。

也可以看到,只需要在网络一端训练多层感知器,其训练速度要比必须从头开始训练整个网络快得多。

但是,如何获得模型的子集?请转到下一部分了解详情。

4.TensorFlow Hub - 基础模型

寻找合适的基础模型

对于更高级、更热门的研究模型(如 MobileNet),您可以转到 TensorFlow Hub,然后过滤出适合使用 MobileNet v3 架构的 TensorFlow.js 模型查找结果如下所示:

c5dc1420c6238c14.png

请注意,这些结果中有些是“图片分类”(每种模型卡片结果的左上角详细显示),有些则是“图片特征向量”。

这些图片特征向量结果本质上是预裁剪的 MobileNet 版本,可用于获取图片特征向量(而不是最终分类)。

此类模型通常称为“基本模型”,您随后可以按照与上一部分相同的方式,通过添加新的分类头并使用自己的数据对其进行训练,来执行迁移学习。

接下来,我们要检查的是,对于给定感兴趣的基本模型,应以哪种模型格式发布 TensorFlow.js。如果您打开其中一个特征向量 MobileNet v3 模型的页面,在 JS 文档中,您会看到它根据使用 tf.loadGraphModel() 的文档中的示例代码段的图表模型表示。

f97d903d2e46924b.png

另请注意,如果您找到的是层格式(而不是图表格式)的模型,则可以选择要冻结的层以及要冻结的层以用于训练。在为新任务(通常称为“转移模型”)创建模型时,此功能非常有用。 但是,您暂时在本教程中使用默认的图模型类型,大多数 TF Hub 模型都将作为模型类型进行部署。如需详细了解如何使用层模型,请参阅从零到核心的 TensorFlow.js 课程。

迁移学习的优势

使用迁移学习而不是从头开始训练整个模型架构有什么优势?

首先,训练时间是使用迁移学习方法的一个主要优势,因为您已经有了训练基于的基础模型。

其次,由于已经进行了训练,您正尝试进行的新分类的示例会少很多,这能够省去很多麻烦。

如果您没有足够的时间和资源来收集要分类的内容的示例数据,并且需要快速制作原型,然后再收集更多训练数据以使其更强大,那么这将非常有用。

由于需要更少的数据,并且训练较小的网络的速度较快,迁移学习对资源占用较少。这使得它非常适合浏览器环境,在现代计算机上只需数十秒即可完成,而无需花费数小时、数天或数周便可进行完整的模型训练。

好的!现在,您已了解迁移学习的本质,接下来我们应该创建您自己的 Teachable Machine 版本了。立即开始吧!

5. 开始设置代码

所需条件

  • 现代网络浏览器。
  • 了解 HTML、CSS、JavaScript 和 Chrome 开发者工具(查看控制台输出)。

让我们开始编码吧

我们已为 Glitch.comCodepen.io 创建了一些样板模板。只需点击一下,即可克隆任一模板作为此 Codelab 的基础状态。

在 Glitch 上,点击 remix this(重新混音)按钮,创建分支并生成一组新的文件供您编辑。

或者,在 Codepen 上,点击屏幕右下角的“分支”(fork)

这个非常简单的框架为您提供了以下文件:

  • HTML 网页 (index.html)
  • 样式表 (style.css)
  • 用于编写 JavaScript 代码 (script.js) 的文件

为方便起见,TensorFlow.js 库的 HTML 文件中添加了额外的导入操作。该架构如下所示:

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

替代方案:使用您首选的网络编辑器或在本地工作

如果要下载代码并在本地工作,或者使用其他在线编辑器工作,只需在同一目录中创建上述 3 个文件,然后将 Glitch 样板代码复制并粘贴到每个文件中即可。

6.应用 HTML 样板

从何处着手?

所有原型都需要一些基本的 HTML 架构,以便呈现您的发现结果。立即进行设置。您将添加:

  • 网页的标题。
  • 一些描述性文字。
  • 状态段落。
  • 准备好网络摄像头信息流的视频。
  • 多个用于启动摄像头、收集数据或重置体验的按钮。
  • 可导入稍后会编码的 TensorFlow.js 和 JS 文件。

打开 index.html 并粘贴以下代码,以设置上述功能:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>

    <p id="status">Awaiting TF.js load</p>

    <video id="webcam" autoplay muted></video>

    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

细分维度

我们来分析一下上面的一些 HTML 代码,以突出您添加的一些关键内容。

  • 您为页面标题添加了一个 <h1> 标记以及一个 ID 为“status”的 <p> 标记,因为在使用系统的不同部分查看输出时,您将通过该标记输出信息。
  • 您添加了一个 ID 为“网络摄像头”的 <video> 元素,您稍后将在其中渲染您的摄像头直播。
  • 您添加了 5 个 <button> 元素。第一项 ID 为“enableCam”,用于启用相机。接下来的两个按钮具有“dataCollector”类,您可以使用该类来为要识别的对象收集示例图片。您稍后编写的代码已经过设计,因此可以添加任意数量的此类按钮,它们将按预期自动工作。

请注意,这些按钮还有一个特殊的用户指定属性,名为 data-1hot,第一个类的整数值从 0 开始。这是用于表示特定类的数据的数字索引。该索引将用数字表示法代替字符串来对输出类进行正确编码,因为机器学习模型只能用于数字。

还有一个 data-name 属性,其中包含您要用于此类的直观易懂的名称,可让您向用户提供一个更有意义的名称,而不是来自 1 个热编码的数值索引值。

最后,您可以使用训练和重置按钮,在收集数据后启动训练过程,或分别重置应用。

  • 您还添加了 2 个 <script> 导入作业。一个用于 TensorFlow.js,另一个用于您很快定义的 script.js。

7. 添加样式

元素默认值

为刚刚添加的 HTML 元素添加样式,以确保这些元素能够正确呈现。下面是正确添加到位置和尺寸元素中的一些样式。没什么特别的。您稍后可以再添加以下代码,以打造更出色的用户体验,就像您在可教机器视频中看到的那样。

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}

video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

太好了!这就是您所需的一切。如果立即预览输出,结果应如下所示:

81909685d7566dcb.png

8. JavaScript:键常量和监听器

定义键常量

首先,添加将在整个应用中使用的一些关键常量。首先,将 script.js 的内容替换为以下常量:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

我们来深入了解一下这些数据:

  • STATUS 仅存储您将写入状态更新的段落标记的引用。
  • VIDEO 用于存储对将呈现网络摄像头 Feed 的 HTML 视频元素的引用。
  • ENABLE_CAM_BUTTONRESET_BUTTONTRAIN_BUTTON 可从 HTML 页面获取对所有按键按钮的 DOM 引用。
  • MOBILE_NET_INPUT_WIDTHMOBILE_NET_INPUT_HEIGHT 分别定义了 MobileNet 模型的预期输入宽度和高度。通过像这样将文件存储在靠近文件顶部的常量中,这样在日后您决定使用其他版本时,更新一次值会更加容易,而无需在许多不同的位置替换。
  • STOP_DATA_GATHER 设置为 - 1。这会存储一个状态值,以便您知道用户何时停止点击某个按钮,以便从摄像头 Feed 收集数据。通过为这个数字指定更有意义的名称,可提高代码的可读性。
  • CLASS_NAMES 可充当查询,并保留可能的类别预测的易懂名称。此数组将在稍后填充。

现在,您已经有了对关键元素的引用,现在可以将某些事件监听器与其关联了。

添加按键事件监听器

首先,向按键按钮添加点击事件处理脚本,如下所示:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);

function enableCam() {
  // TODO: Fill this out later in the codelab!
}

function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}

function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - 点击后会调用 enableCam 函数。

TRAIN_BUTTON - 点击后会调用 TrainAndPredict。

RESET_BUTTON - 点击时重置通话。

最后,在本节中,您可以使用 document.querySelectorAll() 找到类“dataCollector”的所有按钮。这将返回一个从文档中找到的与以下元素匹配的元素数组:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}

function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

代码说明

然后,您可以遍历找到的按钮,并为每个按钮关联 2 个事件监听器。一个用于“mousedown”,另一个用于“mouseup”。这样,只要按下按钮,您就可以继续录制样本,这对于数据收集非常有用。

这两个事件都会调用 gatherDataForClass 函数,您稍后会定义该函数。

此时,您还可以将找到的人类可读类名称从 HTML 按钮属性名称 data-name 推送到 CLASS_NAMES 数组。

接下来,添加一些变量以存储稍后使用的关键元素。

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

我们来演示一下。

首先,您有一个变量 mobilenet 来存储加载的 mobilenet 模型。在初始状态下,此属性为“未定义”。

接下来,您有一个名为 gatherDataState 的变量。如果按下“dataCollector”按钮,此按钮便会改为该按钮的 1 个常用 ID(如 HTML 中所定义),以便您了解您当前收集的是哪类数据。此值最初设为 STOP_DATA_GATHER,当您稍后未按下任何按钮时,您稍后写入的数据收集循环将不会收集任何数据。

videoPlaying 可跟踪网络摄像头信息流是否已成功加载和播放且可供使用。初始设置是“false”,因为在您按下 ENABLE_CAM_BUTTON. 后,摄像头不会开启

接下来,定义 2 个数组:trainingDataInputstrainingDataOutputs。当您点击 MobileData 基本模型生成的输入特征的“dataCollector”按钮以及分别对输出类别进行采样时,这些值会存储收集的训练数据值。

然后,再定义一个最终数组 examplesCount,,用于跟踪开始添加各个类时包含的样本数量。

最后,您有一个名为 predict 的变量来控制预测循环。此属性最初设置为 false。在此之后将它设置为 true 之前,无法进行预测。

现在,所有关键变量都已定义完毕,接下来我们加载加载预先经过过滤的 MobileNet v3 基础模型,该模型提供图片特征向量(而非分类)。

9. 加载 MobileNet 基础模型

首先,定义一个名为 loadMobileNetFeatureModel 的新函数,如下所示。这必须是异步函数,因为加载模型的行为是异步的:

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL =
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';

  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';

  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

在此代码中,您将定义 URL,其中要加载的模型位于 TFHub 文档中。

然后,您可以使用 await tf.loadGraphModel() 加载该模型,并在从此 Google 网站加载模型时将特殊属性 fromTFHub 设置为 true。此特殊情况仅适用于使用 TF Hub 托管的模型(在该模型中,必须设置此额外的属性)。

加载完成后,您可以设置包含消息的 STATUS 元素的 innerText,以便直观地看到它已正确加载,并且您可以开始收集数据。

现在要做的就是将模型预热。对于此类较大的模型,首次使用该模型时,可能需要一些时间才能完成所有设置。因此,通过模型传递 0 有助于避免将来可能更关键的计时问题。

您可以使用封装在 tf.tidy() 中的 tf.zeros() 来确保正确处理张量,批次大小为 1,以及在开始时为常量定义的正确高度和宽度。最后,您还需要指定颜色通道,在本例中为 3,因为模型需要 RGB 图像。

接下来,使用 answer.shape() 记录所返回张量的形状,以帮助您了解此模型生成的图片特征的大小。

定义此函数后,您可以立即调用它,以在网页加载时启动模型下载。

如果您现在查看实时预览,片刻之后就会发现状态文字从“正在等待 TF.js 加载”变为“已成功加载 MobileNet v3!”如下所示。请先确认这一点,然后再继续操作。

a28b734e190afff.png

您还可以检查控制台输出,以查看此模型生成的输出特征的输出大小。通过 MobileNet 模型运行零之后,您会看到形状为 [1, 1024]。第一项只是批次大小为 1,您可以看到它实际返回了 1024 个特征,这些特征可用于帮助您对新对象进行分类。

10. 定义新模型头部

现在该定义模型头了,它本质上是一种非常最小的多层感知器。

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']
});

我们来了解一下此代码。首先,定义一个 tf.Sequence 模型以添加模型层。

接下来,添加一个密集层作为此模型的输入层。此输入形状为 1024,因为 MobileNet v3 特征的输出大小为 。在通过模型传递模型后,您在上一步中发现了这种情况。该层有 128 个使用 ReLU 激活函数的神经元。

如果您不熟悉激活函数和模型层,请考虑参加本研讨会开始时详细介绍的课程,了解这些属性在后台的工作原理。

要添加的下一个层是输出层。神经元的数量应该等于您要预测的类别数量。为此,您可以使用 CLASS_NAMES.length 查看您计划分类的类别数量,数量等于界面中发现的数据收集按钮的数量。由于这是一个分类问题,因此您可以在此输出层使用 softmax 激活,在尝试创建模型解决分类问题而非回归时,必须使用该激活层。

现在,输出一个 model.summary(),以将新定义的模型的概览输出到控制台。

最后,编译模型,使其可以训练模型。在此示例中,优化器设置为 adam;如果 CLASS_NAMES.length 等于 2,损失将是 binaryCrossentropy;如果有多个类,则会使用 categoricalCrossentropy分类。系统还会请求准确率指标,以便稍后在日志中对其进行监控以进行调试。

在控制台中,您应该会看到类似下图的内容:

22eaf32286fea4bb.png

请注意,这包含超过 13 万个可训练参数。但由于这是一个简单的常规神经元层,因此训练速度会非常快。

在项目完成后,作为一项活动,你可以尝试更改第一层中的神经元数量,看看在数量足够低的情况下,仍然可以取得良好性能。通常,机器学习会进行一定程度的试验和错误,以找到最佳参数值,以便在资源用量和速度之间进行最佳权衡。

11. 启用网络摄像头

现在,您可以对您之前定义的 enableCam() 函数进行充实。 如下所示,添加一个名为 hasGetUserMedia() 的新函数,然后将之前定义的 enableCam() 函数的内容替换为以下代码。

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640,
      height: 480
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

首先,创建一个名为 hasGetUserMedia() 的函数,以通过检查是否存在关键浏览器 API 属性来检查浏览器是否支持 getUserMedia()

enableCam() 函数中,使用您在上面定义的 hasGetUserMedia() 函数来检查是否支持该函数。否则,向控制台输出警告。

如果支持,请为您的 getUserMedia() 调用定义一些限制条件,例如您希望只使用视频流,并且希望视频的 width 尺寸为 640 像素,并且 height480 像素。为什么呢?大于此值的视频并没有多大的意义,因为需要将其调整为 224 x 224 像素才能馈送到 MobileNet 模型中。您也可以通过请求较小的分辨率来节省一些计算资源。大多数相机都支持此大小的分辨率。

接下来,使用上文详述的 constraints 调用 navigator.mediaDevices.getUserMedia(),然后等待 stream 返回。返回 stream 后,您可以将 VIDEO 元素设置为其 srcObject 值,从而使其播放 stream

您还应在 VIDEO 元素上添加一个 eventListener,以了解 stream 何时加载完毕并成功播放。

蒸汽加载后,您可以将 videoPlaying 设置为 true 并移除 ENABLE_CAM_BUTTON,以防止将其类别设置为“removed”。

现在运行您的代码,点击启用摄像头按钮,然后授予摄像头的使用权限。如果您是首次执行此操作,您应该会看到自己在网页上呈现给视频元素,如下所示:

b378eb1affa9b883.png

现在可以添加函数来处理 dataCollector 按钮点击了。

12. 数据收集按钮事件处理脚本

现在,您可以填写当前名为 gatherDataForClass(). 的空函数。这是您在 Codelab 开始时为 dataCollector 按钮分配的事件处理程序函数。

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

首先,请检查data-1hot调用当前点击的按钮的this.getAttribute()替换为属性名称(在此示例中为data-1hot作为参数。由于这是一个字符串,因此您可以使用 parseInt() 将其转换为整数,并将此结果分配给名为 classNumber. 的变量

接下来,相应地设置 gatherDataState 变量。如果当前的 gatherDataState 等于 STOP_DATA_GATHER(您设置为 -1),这意味着您当前不会收集任何数据,并且已触发 mousedown 事件。将 gatherDataState 设为您刚找到的 classNumber

否则,这表明您当前正在收集数据,且触发的事件是 mouseup 事件,而您现在希望停止收集该类的数据。只需将它重新设置为 STOP_DATA_GATHER 状态,即可结束您稍后定义的数据收集循环。

最后,启动对实际执行类数据的录制的 dataGatherLoop(), 调用。

13. 数据收集

现在,定义 dataGatherLoop() 函数。该函数负责从网络摄像头视频中节录图像,通过 MobileNet 模型传递图像,并捕获该模型的输出(1024 特征向量)。

然后,代码会连同当前按下的按钮的 gatherDataState ID 一起存储这些按钮,以便您了解此数据代表的类。

我们详细了解一下:

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);

    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

仅当 videoPlaying 为 true 时(即网络摄像头处于活动状态,且 gatherDataState 不等于 STOP_DATA_GATHER,且当前正在按类数据收集按钮),才会继续执行此函数。

接下来,将代码封装在 tf.tidy() 中,以处理后续代码中已创建的张量。此 tf.tidy() 代码执行的结果会存储在名为 imageFeatures 的变量中。

您现在可以使用 tf.browser.fromPixels() 获取网络摄像头 VIDEO 的一帧。包含图片数据的张量存储在名为 videoFrameAsTensor 的变量中。

接下来,将 videoFrameAsTensor 变量的大小调整为适合 MobileNet 模型输入的形状。使用 tf.image.resizeBilinear() 调用,将要重塑的张量用作第一个参数,然后使用定义新高度和宽度(由您之前创建的常量定义)的形状。最后,通过传递第三个参数将对齐角设置为 true,以避免在调整大小时出现任何对齐问题。调整后的结果存储在名为 resizedTensorFrame 的变量中。

请注意,此初始大小调整会拉伸图片,因为网络摄像头图片的大小为 640 x 480 像素,而模型需要 224 x 224 像素的方形图片。

在本演示中,代码应该会正常运行。但是,在完成此 Codelab 后,您可能想尝试从此图片中剪裁出一个方形,以便日后创建的任何生产系统都有更好的结果。

接下来,归一化图片数据。使用 tf.browser.frompixels() 时,图片数据始终在 0 到 255 的范围内,因此只需将调整过大小的 TensorFrame 除以 255,即可确保所有值都介于 0 和 1 之间,而这是 MobileNet 模型所预期的值。

最后,在代码的 tf.tidy() 部分中,通过调用 mobilenet.predict() 将此标准化张量推送到加载的模型,您可以使用 expandDims() 向其传递扩展版本的 normalizedTensorFrame,以便其一批输入 1,因为模型需要一批输入进行处理。

结果出来后,您可以立即调用squeeze()将其缩小为一维张量,然后返回并分配给imageFeatures变量,用于从tf.tidy()

现在,您已获得了 MobileNet 模型的 imageFeatures,接下来可以通过将其推送到您之前定义的 trainingDataInputs 数组来记录这些记录。

您还可以通过将当前 gatherDataState 推送到 trainingDataOutputs 数组来记录此输入表示的内容。

请注意,在之前定义的 gatherDataForClass() 函数中点击按钮时,gatherDataState 变量会设置为当前类的数字 ID,您将记录相关数据。

此时,您还可以增加给定类的样本数。为此,请先检查 examplesCount 数组中的索引是否之前已初始化。如果未定义,则将其设置为 0 以初始化指定类的数字 ID 的计数器,然后递增当前 gatherDataStateexamplesCount

现在,在网页上更新 STATUS 元素的文本,以便捕获每个类的当前计数。为此,请循环遍历 CLASS_NAMES 数组,并输出简单易懂的名称以及 examplesCount 中同一索引处的数据计数。

最后,调用 window.requestAnimationFrame() 并将 dataGatherLoop 作为参数传递,以便再次以递归方式调用此函数。这会继续从视频中对帧进行采样,直到检测到按钮的 mouseup 并且 gatherDataState 设置为 STOP_DATA_GATHER,,此时数据收集循环将结束。

如果您现在运行代码,应该可以点击启用相机按钮,等待网络摄像头加载,然后点击并按住各个数据收集按钮来收集各类数据的示例。在这里,我分别看到有关手机和我的手机的数据。

541051644a45131f.gif

您将看到状态文本已更新,因为它会将所有张量存储在内存中,如上面的屏幕截图所示。

14. 训练和预测

下一步是为当前为空的 trainAndPredict() 函数实现代码,转移学习正是在此处进行。我们来看一下代码:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);

  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
      callbacks: {onEpochEnd: logProgress} });

  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

首先,请确保将 predict 设置为 false,以停止任何当前预测。

接下来,使用 tf.util.shuffleCombo() 打乱输入和输出数组的顺序,以确保顺序不会在训练中引发问题。

将输出数组 trainingDataOutputs, 转换为 int32 类型的 Tensor1d,这样它就可以在独热编码中使用。此文件存储在名为 outputsAsTensor 的变量中。

tf.oneHot() 函数与此 outputsAsTensor 变量搭配使用,以及要编码的类数上限(即 CLASS_NAMES.length)。您的一个热编码输出现在存储在名为 oneHotOutputs 的新张量中。

请注意,当前 trainingDataInputs 是记录的张量的数组。为了将其用于训练,您需要将张量数组转换为常规的 2D 张量。

为此,您可以在 TensorFlow.js 库中使用一个名为 tf.stack() 的强大函数。

该函数接受一系列张量,并将它们堆叠在一起以生成更高维的张量作为输出。在这种情况下,系统会返回张量 2D,即一批 1 维输入,每组长度为 1024 个字符,其中包含记录的特征,这正是训练所需的内容。

接下来,await model.fit() 来训练自定义模型头部。在这里,您将 inputsAsTensor 变量与 oneHotOutputs 一起传递,以分别表示要用于示例输入和目标输出的训练数据。在第三个参数的配置对象中,将shuffletrue,请使用batchSize/5 ,以及epochs设为10,然后指定一个callback适用于onEpochEnd添加到logProgress函数。

最后,您可以在训练模型时处理创建的张量。然后,您可以将 predict 重新设置为 true 以允许再次执行预测,然后调用 predictLoop() 函数以开始预测实时网络摄像头图片。

您还可以定义 logProcess() 函数以记录训练状态,训练状态在上面的 model.fit() 中使用,会在每轮训练后将结果输出到控制台。

即将大功告成!是时候添加 predictLoop() 函数进行预测了。

核心预测循环

在这里,您将实现主预测循环,通过浏览器进行实时结果,从网络摄像头采样帧并持续预测每个帧中的内容。

我们检查一下代码:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

首先,检查 predict 是否为 true,以便仅在模型训练完毕且可供使用后才进行预测。

接下来,您可以获取当前图片的图片特征,就像在 dataGatherLoop() 函数中一样。实际上,您应使用 tf.browser.from pixels() 从网络摄像头获取帧,对其进行归一化,将其大小调整为 224 x 224 像素,然后通过 MobileNet 模型传递这些数据,以获取最终的图像特征。

不过,您现在可以使用新训练的模型头部实际执行预测,方法是传递刚刚通过训练模型的 predict() 函数找到的 imageFeatures。然后,您可以挤压生成的张量,使其再次变为一维,然后分配给名为 prediction 的变量。

通过此 prediction,您可以使用 argMax() 找到具有最高值的索引,然后使用 arraySync() 将此生成的张量转换为一个数组,以获取 JavaScript 中的基础数据,以发现值最高的元素。此值存储在名为 highestIndex 的变量中。

您也可以通过直接对 prediction 张量调用 arraySync() 来采用相同的方式获取实际预测置信度分数。

现在,您已具备使用 prediction 数据更新 STATUS 文本所需的一切。要获取该类的直观易懂的字符串,您可以在 CLASS_NAMES 数组中查找 highestIndex,然后从 predictionArray 获取置信度值。如需提高可读性,以百分比的形式表示,只需用 100 乘以 math.floor() 得到结果即可。

最后,您可以在准备好后使用 window.requestAnimationFrame() 再次调用 predictionLoop(),对视频流进行实时分类。如果您选择使用新数据训练新模型,此过程将持续到 predict 设置为 false

这就涉及到最后一块拼图。实现重置按钮。

15. 实现重置按钮

即将完成!最后的解决方法就是实现重置按钮以重新开始。当前空 reset() 函数的代码如下。接下来,按如下所示进行更新:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded
 * MobileNet model and MLP head tensors as you will need to reuse
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';

  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

首先,将 predict 设置为 false 以停止任何正在运行的预测循环。接下来,删除 examplesCount 数组中的所有内容,具体方法是将其长度设为 0,这是清除数组中的所有内容的便捷方式。

现在请查看当前记录的所有 trainingDataInputs,并确保其中所含的每个张量 dispose() 再次释放内存,因为 JavaScript 垃圾回收器不会清理张量。

完成后,您现在可以在 trainingDataInputstrainingDataOutputs 数组上将数组长度安全地设为 0,以清除这些数组。

最后,将 STATUS 文本设置为合理的内容,并将健全性检查输出的张量输出为健全性检查。

请注意,内存中仍然有数百个张量,因为系统不会处理 MobileNet 模型和您定义的多层感知器。如果在重置之后再次进行训练,您需要重新使用训练数据。

16. 我们来试试吧

现在,您可以测试一下自己的 Teachable Machine 版本了!

转到实时预览,启用网络摄像头,为会议室中某个对象的至少 1 个类别收集 30 个样本,然后为第 2 类为其他对象执行相同的操作,点击“训练”,然后查看控制台日志以查看进度。训练速度应该很快:

bf1ac3cc5b15740.gif

训练完成后,将对象显示给相机,以获取将要输出到网页顶部附近状态文本区域中的实时预测结果。如果遇到问题,请查看我已完成的代码,看看是否遗漏了任何复制的内容。

17. 恭喜

恭喜!您刚刚在浏览器中完成了使用 TensorFlow.js 的首个迁移学习示例。

不妨试试,在各种对象上进行测试,您可能会发现有些东西比其他对象更难识别,特别是当它们与其他对象相似时。您可能需要添加更多类或训练数据来区分它们。

回顾

在此 Codelab 中,您学习了以下内容:

  1. 什么是迁移学习,以及相对于训练完整模型有何优势。
  2. 如何从 TensorFlow Hub 获取可重复使用的模型。
  3. 如何设置适合迁移学习的 Web 应用。
  4. 如何加载和使用基础模型以生成图片特征。
  5. 如何训练新的预测头部,使其能够从摄像头图像中识别自定义对象。
  6. 如何使用生成的模型实时对数据进行分类。

接下来做什么?

现在,您已拥有一个良好的起点,您可以想出什么创意来扩展此机器学习模型样板?也许您可以彻底改变目前所在的行业,帮助公司员工训练模型,从而对他们的日常工作中重要的事务进行分类?无限可能。

要进一步学习,请考虑免费学习此完整课程,其中介绍了如何将此 Codelab 中现有的 2 个模型合并为 1 个模型以提高效率。

此外,如果您想详细了解原始 Teachable Machine 应用背后的理论,请参阅本教程

与我们分享您的成果

您也可以轻松地将今天创作的内容用于其他创意用例,并鼓励您跳出思维定式并继续进行创作。

记得使用 #MadeWithTFJS 标签在社交媒体上关注我们,争取让您的项目有机会在我们的 TensorFlow 博客甚至 将来举办的活动。我们期待看到您的作品。

要结算的网站