TensorFlow.js:打造您自己的“可学习的机器”通过 TensorFlow.js 使用迁移学习

1. 准备工作

在过去几年中,TensorFlow.js 模型的使用量呈指数级增长,许多 JavaScript 开发者现在都希望使用现有的先进模型,并对其进行重新训练,以处理其行业特有的自定义数据。采用现有模型(通常称为基础模型)并将其用于类似但不同的领域,这种做法称为迁移学习。

与从完全空白的模型开始训练相比,迁移学习具有许多优势。您可以重复使用之前从已训练的模型中学习到的知识,并且只需要较少的要分类的新商品示例。此外,由于只需重新训练模型架构的最后几层,而无需重新训练整个网络,因此训练速度通常会显著提高。因此,迁移学习非常适合 Web 浏览器环境,因为该环境中的资源可能会因执行设备而异,但也可以直接访问传感器,从而轻松获取数据。

本 Codelab 将向您展示如何从零开始构建一个 Web 应用,重现 Google 广受欢迎的“会学习的机器”网站。该网站可让您创建一个功能性 Web 应用,任何用户都可以使用该应用通过网络摄像头拍摄的几个示例图片来识别自定义对象。为了让您专注于此 Codelab 的机器学习方面,我们特意将网站保持在最低限度。不过,与原始的会学习的机器网站一样,您完全可以利用现有的 Web 开发者经验来改善用户体验。

前提条件

本 Codelab 专为熟悉 TensorFlow.js 预建模型和基本 API 用法,并且希望开始使用 TensorFlow.js 中的迁移学习的 Web 开发者而设计。

  • 学习本实验的前提是要对 TensorFlow.js、HTML5、CSS 和 JavaScript 有基本的了解。

如果您是 TensorFlow.js 新手,不妨先学习这门免费的“从新手到高手”课程。该课程不要求您具备机器学习或 TensorFlow.js 方面的背景知识,会以小步骤的方式教您需要了解的所有内容。

学习内容

  • 什么是 TensorFlow.js,以及为什么您应该在下一个 Web 应用中使用它。
  • 如何构建一个简化的 HTML/CSS /JS 网页,以复制 会学习的机器 用户体验。
  • 如何使用 TensorFlow.js 加载预训练的基础模型(具体而言是 MobileNet),以生成可用于迁移学习的图片特征。
  • 如何从用户的摄像头收集数据,以识别多种类型的数据。
  • 如何创建和定义一个多层感知机,该感知机可获取图像特征并学习使用这些特征对新对象进行分类。

我们开始行动吧...

所需条件

  • 构建过程中最好使用 Glitch.com 账号,或者您也可以使用您能轻松进行修改和运行操作的 Web 服务环境。

2. 什么是 TensorFlow.js?

54e81d02971f53e8.png

TensorFlow.js 是一个开源机器学习库,可在任何能够运行 JavaScript 的环境中运行。它以用 Python 编写的原始 TensorFlow 库为基础,旨在面向 JavaScript 生态系统打造全新的开发者体验以及 API 体验。

它的应用场合如何?

鉴于 JavaScript 的可移植性,您现在可以使用一种语言编写机器学习代码,然后轻松地在下列所有平台上运行此代码:

  • 使用 Vanilla JavaScript 的客户端网络浏览器
  • 使用 Node.js 的服务器端甚至是 Raspberry Pi 等 IoT 设备
  • 使用 Electron 的桌面应用
  • 使用 React Native 的原生移动应用

TensorFlow.js 还支持在这些环境中使用多个后端(这些环境是它可以在其中运行的实际硬件环境,例如 CPU 或 WebGL 等。此处的“后端”并不是指服务器端环境,例如,用于执行的后端可能是 WebGL 中的客户端),以确保兼容性,同时保持快速运行。TensorFlow.js 目前支持:

  • 在设备显卡 (GPU) 上运行 WebGL - 这是使用 GPU 加速运行较大模型(大小超过 3MB)的最快方式。
  • 在 CPU 上执行 Web Assembly (WASM) - 旨在提升包括老款手机在内的设备的 CPU 性能。这种方式更适合较小的模型(大小在 3MB 以下),由于将内容上传到图形处理器存在开销,因此在 CPU 上运行 WASM 可能实际上比运行 WebGL 要快。
  • CPU 执行 - 如果其他环境都不可用,则回退至此方式。这是三种方式中运行最慢的一种,但始终可以使用。

注意:如果您知道要在哪个设备上运行,则可以选择强制调用这些后端之一。或者,如果您没有指定,可以直接让 TensorFlow.js 为您做决定。

客户端的强大功能

在客户端计算机的网络浏览器中运行 TensorFlow.js 具有诸多值得关注的优势。

隐私权

您可以在客户端计算机上对数据进行训练和分类,而不将数据发送到第三方网络服务器。有时候,这可能是一项强制要求,例如 GDPR 等地方法律要求这么做,或者对于您所处理的数据,用户可能想要将其保留在自己的计算机上,而不发送给第三方。

速度

由于您无需向远程服务器发送数据,因此推理(一种对数据进行分类的操作)可能会更快。更棒的是,如果用户向您授予权限,您可以直接访问设备的传感器,比如相机、麦克风、GPS、加速度计等。

覆盖面和规模

只要点击一下您发送的链接,世界上的任何人就可以在他们的浏览器中打开相应网页,并利用您创作的内容。无需使用 CUDA 驱动程序等进行复杂的服务器端 Linux 设置,即可使用该机器学习系统。

费用

没有服务器意味着您只需为托管 HTML、CSS、JS 和模型文件的内容分发网络 (CDN) 付费。CDN 的费用要比全天候运行服务器(可能搭载了显卡)便宜得多。

服务器端功能

利用 TensorFlow.js 的 Node.js 实现可获得以下特性。

全面的 CUDA 支持

在服务器端,要实现显卡加速,必须安装 NVIDIA CUDA 驱动程序,以便 TensorFlow 能够使用显卡(与使用 WebGL 的浏览器不同,这种情况下无需安装该驱动程序)。不过,借助全面的 CUDA 支持,您可以充分利用显卡更精细的功能,从而缩短训练和推理时间。性能与 Python TensorFlow 实现相当,因为它们具有相同的 C++ 后端。

模型大小

对于一些研究中的先进模型,您可能需要使用非常大的模型(大小可能达到 GB 级)。由于每个浏览器标签页的内存用量有限,这些模型目前无法在网络浏览器中运行。如需运行这些较大的模型,您可以使用 Node.js 在自己的服务器上高效运行此类模型,并且服务器的硬件达到自己的需求。

IoT

Raspberry Pi 等热门单板计算机均支持 Node.js,这意味着您也可以在此类设备上执行 TensorFlow.js 模型。

速度

Node.js 采用 JavaScript 语言编写,这意味着您可以从即时编译中受益。也就是说,使用 Node.js 时,您可能会经常发现性能有所提升,因为系统会在运行时对之进行优化,特别是对于您可能正在执行的任何预处理操作。此案例研究中就有一个很好的例子,展示了 Hugging Face 如何使用 Node.js 使其自然语言处理模型的性能提升了两倍。

现在,您已了解 TensorFlow.js 的基础知识,它可以运行的环境,以及它具备的一些优势。接下来,我们来运用一些实用功能。

3. 迁移学习

什么是迁移学习?

迁移学习是指利用已学到的知识来帮助学习其他类似的事物。

我们人类一直在这样做。您的大脑中存储着一生的经验,可帮助您识别以前从未见过的新事物。以这棵柳树为例:

e28070392cd4afb9.png

根据您所处的地理位置,您可能以前从未见过这种树。

不过,如果我让您告诉我下面的新图片中是否有柳树,您可能很快就能找到它们,即使它们与我之前向您展示的原始图片中的柳树角度不同,也略有不同。

d9073a0d5df27222.png

您的大脑中已经有许多神经元知道如何识别树状物体,还有一些神经元擅长寻找长直线。您可以利用这些知识快速对柳树进行分类,柳树是一种具有许多长而直的垂直树枝的树状物体。

同样,如果您有一个已在某个领域(例如图片识别)中训练过的机器学习模型,您可以重复使用该模型来执行另一项相关任务。

您也可以使用 MobileNet 等高级模型执行相同的操作。MobileNet 是一种非常热门的研究模型,可对 1000 种不同的对象类型执行图像识别。从狗到汽车,它基于一个名为 ImageNet 的庞大数据集进行训练,该数据集包含数百万张带标签的图片。

在此动画中,您可以看到此 MobileNet V1 模型中包含的层数量非常庞大:

7d4e1e35c1a89715.gif

在训练过程中,该模型学会了如何提取对所有这 1,000 个对象都至关重要的常见特征,并且它用于识别这些对象的许多较低级别特征也可用于检测之前从未见过的新对象。毕竟,一切最终都只是线条、纹理和形状的组合。

我们先来看看传统的卷积神经网络 (CNN) 架构(类似于 MobileNet),然后了解迁移学习如何利用以下已训练的网络来学习新知识。下图展示了 CNN 的典型模型架构,在本例中,该模型经过训练,可识别 0 到 9 的手写数字:

baf4e3d434576106.png

如果您能将现有已训练模型的预训练低级层(如左侧所示)与模型末尾附近的分类层(如右侧所示,有时称为模型的分类头)分开,则可以使用低级层根据其训练所用的原始数据为任何给定图片生成输出特征。以下是移除了分类头的同一网络:

369a8a9041c6917d.png

假设您尝试识别的新事物也可以利用之前模型学到的此类输出特征,那么这些特征很有可能可以重新用于新用途。

在上图中,这个假设的模型是基于数字训练的,因此可能学到的数字知识也可以应用于字母(例如 a、b 和 c)。

因此,现在您可以添加一个新的分类头,尝试预测 a、b 或 c,如下所示:

db97e5e60ae73bbd.png

在此示例中,较低级别的层被冻结且未经过训练,只有新的分类头会自行更新,以便从左侧预训练的截断模型提供的特征中学习。

这种做法称为迁移学习,也是 Teachable Machine 在幕后所做的事情。

您还可以看到,由于只需在网络末端训练多层感知机,因此训练速度比从头开始训练整个网络快得多。

但如何获取模型的部分子部分呢?请前往下一部分了解详情。

4. TensorFlow Hub - 基础模型

找到合适的基础模型以供使用

对于 MobileNet 等更高级且更热门的研究模型,您可以前往 TensorFlow Hub,然后过滤出适合使用 MobileNet v3 架构的 TensorFlow.js 模型,以找到如下所示的结果:

c5dc1420c6238c14.png

请注意,其中一些结果属于“图片分类”类型(每张模型卡片结果的左上角会详细说明),另一些结果属于“图片特征向量”类型。

这些图片特征向量结果实际上是 MobileNet 的预切分版本,您可以使用它们来获取图片特征向量,而不是最终的分类结果。

此类模型通常称为“基础模型”,您可以像上一部分中所示那样,通过添加新的分类头并使用自己的数据对其进行训练,来使用这些模型执行迁移学习。

接下来要检查的是,对于给定的感兴趣的基础模型,该模型以何种 TensorFlow.js 格式发布。如果您打开其中一个特征向量 MobileNet v3 模型的网页,则可以从 JS 文档中看到,它采用的是基于文档中使用了 tf.loadGraphModel() 的示例代码段的图模型形式。

f97d903d2e46924b.png

另请注意,如果您找到的是层格式而非图格式的模型,您可以选择冻结哪些层以及解冻哪些层以进行训练。在为新任务创建模型时,这非常有用,此类模型通常称为“迁移模型”。不过,在本教程中,您将使用默认的图模型类型,大多数 TF Hub 模型都是以这种类型部署的。如需详细了解如何使用 Layers 模型,请参阅 TensorFlow.js 从入门到精通课程。

迁移学习的优势

与从头开始训练整个模型架构相比,使用迁移学习有哪些优势?

首先,训练时间是使用迁移学习方法的一项关键优势,因为您已经有一个经过训练的基础模型可以作为基础。

其次,由于已经进行了训练,因此您只需展示极少的要分类的新事物示例即可。

如果您在收集要分类的事物的示例数据方面的时间和资源有限,并且需要在收集更多训练数据以使其更可靠之前快速制作原型,那么这种方法非常有用。

由于需要的数据更少,并且训练较小网络的速度更快,因此迁移学习的资源密集度较低。这使得它非常适合浏览器环境,在现代机器上只需几十秒即可完成,而完整模型训练则需要数小时、数天或数周。

好的!现在,您已经了解迁移学习的本质,接下来可以创建自己的会学习的机器版本了。让我们开始吧!

5. 设置代码环境

所需条件

  • 新式网络浏览器。
  • 具备 HTML、CSS、JavaScript 和 Chrome 开发者工具(查看控制台输出)的基础知识。

开始编码

我们已为 Glitch.comCodepen.io 创建了可供您开始使用的样板模板。您只需点击一下即可克隆任一模板,作为本 Codelab 的基态。

在 Glitch 上,点击“remix this”按钮即可创建分支,并创建一组可修改的新文件。

或者,在 Codepen 上,点击屏幕右下角的“fork”

这个框架非常简单,为您提供了以下文件:

  • HTML 网页 (index.html)
  • 样式表 (style.css)
  • 用于编写 JavaScript 代码的文件 (script.js)

为方便起见,我们还在 HTML 文件中添加了对 TensorFlow.js 库的导入。如下所示:

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

替代方案:使用您偏好的网页编辑器或在本地工作

如果您想下载代码并在本地或在其他在线编辑器中操作,只需在同一目录中创建上述 3 个文件,然后将 Glitch 样板中的代码复制并粘贴到每个文件中即可。

6. 应用 HTML 样板

从何处着手?

所有原型都需要一些基本的 HTML 基架,您可以基于这些基架来呈现您的发现结果。立即对此进行设置。您需要添加以下内容:

  • 网页的标题。
  • 一些说明性文字。
  • 状态段落。
  • 一个视频,用于在准备就绪后显示摄像头画面。
  • 用于启动相机、收集数据或重置体验的多个按钮。
  • 导入 TensorFlow.js 和您稍后将编写的 JS 文件。

打开 index.html,将现有代码替换为以下代码,以设置上述功能:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

分解

下面,我们来详细介绍上面的一些 HTML 代码,重点说明您添加的一些关键内容。

  • 您为网页标题添加了 <h1> 标记,并添加了 ID 为“status”的 <p> 标记,您将在此处打印信息,因为您会使用系统的不同部分来查看输出。
  • 您添加了一个 ID 为“webcam”的 <video> 元素,稍后您将在此元素中呈现摄像头直播。
  • 您添加了 5 个 <button> 元素。第一个(ID 为“enableCam”)用于启用摄像头。接下来的两个按钮的类为“dataCollector”,可用于收集您要识别的对象的示例图片。您稍后编写的代码将经过精心设计,以便您可以添加任意数量的此类按钮,并且这些按钮会自动按预期运行。

请注意,这些按钮还具有一个名为 data-1hot 的特殊用户定义属性,其整数值从第一个类的 0 开始。这是您将用于表示特定类别数据的数值索引。该索引将用于使用数值表示法(而非字符串)正确编码输出类,因为机器学习模型只能处理数字。

还有一个 data-name 属性,其中包含您要为此类使用的人类可读名称,这样您就可以向用户提供更有意义的名称,而不是 1-hot 编码中的数值索引值。

最后,您会看到一个“训练”按钮和一个“重置”按钮,前者用于在收集到数据后启动训练流程,后者用于重置应用。

  • 您还添加了 2 个 <script> 导入。一个用于 TensorFlow.js,另一个用于您稍后将定义的 script.js。

7. 添加样式

元素默认值

为您刚刚添加的 HTML 元素添加样式,以确保它们能够正确呈现。以下是一些用于正确定位和调整元素大小的样式。没什么特别的。您当然可以在稍后添加更多内容,以打造更出色的用户体验,就像您在 Teachable Machine 视频中看到的那样。

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

太棒了!这就是您所需的一切。如果您现在预览输出,应如下所示:

81909685d7566dcb.png

8. JavaScript:键常量和监听器

定义关键常量

首先,添加一些您将在整个应用中使用的关键常量。首先,将 script.js 的内容替换为以下常量:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

下面我们来了解一下这些参数的用途:

  • STATUS 只是保存对您将写入状态更新的段落标记的引用。
  • VIDEO 存储着对将呈现网络摄像头画面的 HTML 视频元素的引用。
  • ENABLE_CAM_BUTTONRESET_BUTTONTRAIN_BUTTON 从 HTML 网页中获取所有关键按钮的 DOM 引用。
  • MOBILE_NET_INPUT_WIDTHMOBILE_NET_INPUT_HEIGHT 分别定义了 MobileNet 模型的预期输入宽度和高度。通过将此信息存储在文件顶部附近的常量中,如果您日后决定使用其他版本,只需更新一次值,而不必在许多不同的位置进行替换。
  • STOP_DATA_GATHER 设置为 -1。此变量用于存储状态值,以便您了解用户何时停止点击按钮来从网络摄像头 Feed 中收集数据。通过为该数字指定一个更有意义的名称,可以提高代码的后续可读性。
  • CLASS_NAMES 用作查找表,用于存储可能类别预测的直观易懂的名称。此数组将在稍后填充。

好了,现在您已经获得了对关键元素的引用,接下来需要将一些事件监听器与这些元素相关联。

添加关键事件监听器

首先,为关键按钮添加点击事件处理脚本,如下所示:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - 点击时调用 enableCam 函数。

TRAIN_BUTTON - 点击时调用 trainAndPredict。

RESET_BUTTON - 点击时调用 reset。

最后,在此部分中,您可以使用 document.querySelectorAll() 找到所有具有“dataCollector”类的按钮。此方法会返回从文档中找到的匹配元素数组:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

代码说明

然后,您遍历找到的按钮,并为每个按钮关联 2 个事件监听器。一个用于“mousedown”,一个用于“mouseup”。这样一来,您就可以在按住按钮时一直录制样本,这对于数据收集非常有用。

这两个事件都会调用您稍后将定义的 gatherDataForClass 函数。

此时,您还可以将从 HTML 按钮属性 data-name 中找到的直观易懂的类名称推送到 CLASS_NAMES 数组。

接下来,添加一些变量来存储稍后将使用的关键信息。

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

下面我们来了解一下这些功能。

首先,您有一个变量 mobilenet 来存储已加载的 mobilenet 模型。最初将其设置为未定义。

接下来,您会看到一个名为 gatherDataState 的变量。如果按下“dataCollector”按钮,此值会更改为该按钮的 1 热 ID(如 HTML 中所定义),以便您了解当时正在收集哪类数据。最初,此变量设置为 STOP_DATA_GATHER,这样一来,您稍后编写的数据收集循环就不会在未按下任何按钮时收集任何数据。

videoPlaying 用于跟踪网络摄像头视频流是否已成功加载并播放,以及是否可供使用。最初,此值设置为 false,因为在您按下 ENABLE_CAM_BUTTON. 之前,网络摄像头处于关闭状态

接下来,定义两个数组 trainingDataInputstrainingDataOutputs。这些变量用于存储收集的训练数据值,因为您会点击“dataCollector”按钮,分别收集由 MobileNet 基本模型生成的输入特征和抽样的输出类。

然后,定义一个最终数组 examplesCount,,用于跟踪开始添加示例后每个类包含的示例数量。

最后,您有一个名为 predict 的变量,用于控制预测循环。初始设置为 false。在将此值设置为 true 之前,无法进行任何预测。

现在,所有关键变量都已定义完毕,接下来我们加载预先切分的 MobileNet v3 基本模型,该模型可提供图片特征向量,而不是分类。

9. 加载 MobileNet 基本模型

首先,定义一个名为 loadMobileNetFeatureModel 的新函数,如下所示。这必须是一个异步函数,因为加载模型的行为是异步的:

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

在此代码中,您定义了 URL,用于指定要加载的模型在 TFHub 文档中的位置。

然后,您可以使用 await tf.loadGraphModel() 加载模型,但请注意,由于您是从此 Google 网站加载模型,因此需要将特殊属性 fromTFHub 设置为 true。这是一种特殊情况,仅适用于使用托管在 TF Hub 上的模型,在这种情况下必须设置此额外属性。

加载完成后,您可以为 STATUS 元素的 innerText 设置一条消息,以便直观地看到该元素已正确加载,并且您已准备好开始收集数据。

现在,唯一要做的就是预热模型。对于此类大型模型,首次使用时可能需要片刻时间来设置所有内容。因此,通过模型传递零有助于避免将来在时间可能更关键的情况下出现任何等待。

您可以使用封装在 tf.tidy() 中的 tf.zeros() 来确保正确处置张量,并使用在开头常量中定义的批次大小 1 以及正确的高度和宽度。最后,您还需要指定颜色通道,在本例中为 3,因为模型需要 RGB 图片。

接下来,使用 answer.shape() 记录返回的张量的结果形状,以帮助您了解此模型生成的图片特征的大小。

定义此函数后,您可以立即调用它,以便在网页加载时启动模型下载。

如果您现在查看实时预览,过一会儿,您会看到状态文本从“等待加载 TF.js”变为“MobileNet v3 已成功加载!”(如下所示)。请确保此功能正常运行,然后再继续。

a28b734e190afff.png

您还可以查看控制台输出,了解此模型生成的输出特征的打印大小。在 MobileNet 模型中运行零后,您会看到打印出的形状为 [1, 1024]。第一个项的批次大小为 1,您可以看到它实际上返回了 1024 个特征,这些特征可用于帮助您对新对象进行分类。

10. 定义新模型头

现在,您可以定义模型头了,它本质上是一个非常简单的多层感知机。

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

我们来了解一下此代码。首先,定义一个 tf.sequential 模型,然后向其中添加模型层。

接下来,添加一个密集层作为此模型的输入层。它的输入形状为 1024,因为 MobileNet v3 特征的输出就是这种大小。您在上一步中通过模型传递 1 后发现了这一点。此层有 128 个使用 ReLU 激活函数的神经元。

如果您不熟悉激活函数和模型层,不妨先学习本讲座开头详细介绍的课程,了解这些属性在幕后发挥的作用。

接下来要添加的是输出层。神经元数量应等于您要预测的类别数量。为此,您可以使用 CLASS_NAMES.length 查找您计划分类的类别数量,该数量等于界面中找到的数据收集按钮的数量。由于这是一个分类问题,因此您在此输出层上使用 softmax 激活函数,在尝试创建模型来解决分类问题(而非回归问题)时,必须使用此激活函数。

现在,打印 model.summary() 以将新定义模型的概览打印到控制台。

最后,编译模型,使其准备好接受训练。此处优化器设置为 adam,如果 CLASS_NAMES.length 等于 2,损失将为 binaryCrossentropy;如果有 3 个或更多个要分类的类别,则损失将为 categoricalCrossentropy。系统还会请求准确率指标,以便稍后在日志中监控这些指标,从而进行调试。

您应该会在控制台中看到类似以下的内容:

22eaf32286fea4bb.png

请注意,该模型有超过 13 万个可训练的参数。不过,由于这是一个由常规神经元组成的简单密集层,因此训练速度会非常快。

作为项目完成后要执行的活动,您可以尝试更改第一层的神经元数量,看看在仍能获得不错性能的情况下,可以将神经元数量降到多低。通常,在机器学习中,需要进行一定程度的试错,才能找到最佳参数值,从而在资源用量和速度之间实现最佳平衡。

11. 启用摄像头

现在,您可以完善之前定义的 enableCam() 函数了。添加一个名为 hasGetUserMedia() 的新函数(如下所示),然后将之前定义的 enableCam() 函数的内容替换为下面的相应代码。

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

首先,创建一个名为 hasGetUserMedia() 的函数,通过检查关键浏览器 API 属性是否存在来检查浏览器是否支持 getUserMedia()

enableCam() 函数中,使用您刚刚在上面定义的 hasGetUserMedia() 函数来检查是否支持。如果不是,则向控制台输出警告。

如果支持,请为 getUserMedia() 调用定义一些限制条件,例如您只想要视频流,并且希望视频的 width640 像素,height480 像素。为什么呢?不过,获取大于此尺寸的视频意义不大,因为需要将视频调整为 224x224 像素才能将其输入到 MobileNet 模型中。您还可以通过请求较低的分辨率来节省一些计算资源。大多数相机都支持此分辨率。

接下来,使用上述 constraints 详细信息调用 navigator.mediaDevices.getUserMedia(),然后等待返回 stream。返回 stream 后,您可以获取 VIDEO 元素,并通过将其设置为 srcObject 值来播放 stream

您还应在 VIDEO 元素上添加一个 eventListener,以了解 stream 何时已加载并成功播放。

加载完 steam 后,您可以将 videoPlaying 设置为 true 并移除 ENABLE_CAM_BUTTON,方法是将 ENABLE_CAM_BUTTON 的类设置为“removed”,以防止用户再次点击它。

现在,运行您的代码,点击“启用摄像头”按钮,然后允许访问网络摄像头。如果您是首次执行此操作,则应看到自己渲染到网页上的视频元素中,如下所示:

b378eb1affa9b883.png

好了,现在该添加一个函数来处理 dataCollector 按钮点击事件了。

12. 数据收集按钮事件处理脚本

现在,您可以开始填写当前为空的函数 gatherDataForClass(). 了。这是您在 Codelab 开始时为 dataCollector 按钮分配的事件处理脚本。

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

首先,通过调用 this.getAttribute() 并将属性名称(在本例中为 data-1hot)作为参数传递给该函数,检查当前点击的按钮上的 data-1hot 属性。由于这是一个字符串,因此您可以使用 parseInt() 将其强制转换为整数,并将此结果分配给名为 classNumber. 的变量

接下来,相应地设置 gatherDataState 变量。如果当前 gatherDataState 等于 STOP_DATA_GATHER(您将其设置为 -1),则表示您目前未收集任何数据,并且触发了 mousedown 事件。将 gatherDataState 设置为刚刚找到的 classNumber

否则,这意味着您目前正在收集数据,触发的事件是 mouseup 事件,而您现在想要停止收集该类的数据。只需将其重新设置为 STOP_DATA_GATHER 状态,即可结束您稍后将定义的数据收集循环。

最后,启动对 dataGatherLoop(), 的调用,该调用实际上会执行类数据的记录。

13. 数据收集

现在,定义 dataGatherLoop() 函数。此函数负责从网络摄像头视频中抽样图片,将其传递给 MobileNet 模型,并捕获该模型的输出(1024 个特征向量)。

然后,它会将这些数据与当前按下的按钮的 gatherDataState ID 一起存储,以便您了解这些数据代表哪个类。

我们详细了解一下:

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n< = 0; n  CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

只有当 videoPlaying 为 true(表示摄像头处于活动状态)且 gatherDataState 不等于 STOP_DATA_GATHER 并且当前正在按下用于收集课堂数据的按钮时,您才会继续执行此函数。

接下来,将您的代码封装在 tf.tidy() 中,以处置后续代码中创建的所有张量。此 tf.tidy() 代码执行的结果存储在名为 imageFeatures 的变量中。

您现在可以使用 tf.browser.fromPixels() 抓取网络摄像头 VIDEO 的帧。包含图片数据的所得张量存储在名为 videoFrameAsTensor 的变量中。

接下来,调整 videoFrameAsTensor 变量的大小,使其成为适合 MobileNet 模型输入的正确形状。使用 tf.image.resizeBilinear() 调用,将要重塑的张量作为第一个参数,然后使用一个形状来定义新的高度和宽度,该形状由您之前创建的常量定义。最后,通过传递第三个参数将 align corners 设置为 true,以避免在调整大小时出现任何对齐问题。调整大小后的结果存储在名为 resizedTensorFrame 的变量中。

请注意,此原始调整大小操作会拉伸图片,因为您的摄像头图片大小为 640x480 像素,而模型需要 224x224 像素的方形图片。

对于本次演示,这应该没问题。不过,完成本 Codelab 后,您不妨尝试从这张图片中裁剪出一个正方形,以便在日后创建的任何生产系统中获得更好的效果。

接下来,对图片数据进行归一化处理。使用 tf.browser.frompixels() 时,图片数据始终在 0 到 255 的范围内,因此您只需将 resizedTensorFrame 除以 255,即可确保所有值都在 0 到 1 之间,而这正是 MobileNet 模型期望的输入。

最后,在代码的 tf.tidy() 部分,通过调用 mobilenet.predict() 将此归一化张量推送到已加载的模型中,并使用 expandDims()normalizedTensorFrame 的扩展版本传递给 mobilenet.predict(),使其成为大小为 1 的批次,因为模型需要一批输入进行处理。

结果返回后,您可以立即对该返回的结果调用 squeeze(),将其压缩回一维张量,然后返回该张量并将其分配给捕获 tf.tidy() 结果的 imageFeatures 变量。

现在,您已经获得了 MobileNet 模型的 imageFeatures,可以通过将这些数据推送到之前定义的 trainingDataInputs 数组来记录它们。

您还可以通过将当前 gatherDataState 推送到 trainingDataOutputs 数组来记录此输入所代表的内容。

请注意,在之前定义的 gatherDataForClass() 函数中点击按钮时,gatherDataState 变量已设置为您正在记录数据的当前类的数值 ID。

此时,您还可以增加特定类别的示例数量。为此,请先检查 examplesCount 数组中的索引是否已初始化。如果未定义,请将其设置为 0 以初始化给定类的数值 ID 的计数器,然后您可以针对当前 gatherDataState 递增 examplesCount

现在,更新网页上的 STATUS 元素文本,以显示捕获到的每个类的当前数量。为此,请遍历 CLASS_NAMES 数组,并输出人类可读的名称,同时输出 examplesCount 中相同索引处的数据计数。

最后,调用 window.requestAnimationFrame() 并将 dataGatherLoop 作为参数传递,以递归方式再次调用此函数。系统将继续从视频中抽样帧,直到检测到按钮的 mouseup,并将 gatherDataState 设置为 STOP_DATA_GATHER,,此时数据收集循环将结束。

如果您现在运行代码,应该能够点击“启用摄像头”按钮,等待网络摄像头加载,然后点击并按住每个数据收集按钮,以收集每个数据类别的示例。在这里,您可以看到我分别为手机和手收集数据。

541051644a45131f.gif

您应该会看到状态文本已更新,因为系统会将所有张量存储在内存中,如上面的屏幕截图所示。

14. 训练和预测

下一步是为当前为空的 trainAndPredict() 函数实现代码,这是进行迁移学习的地方。我们来看一下代码:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

首先,请将 predict 设置为 false,确保停止当前正在进行的所有预测。

接下来,使用 tf.util.shuffleCombo() 对输入和输出数组进行混排,以确保顺序不会导致训练出现问题。

将输出数组 trainingDataOutputs, 转换为 int32 类型的 tensor1d,以便在 one hot 编码中使用。此值存储在名为 outputsAsTensor 的变量中。

tf.oneHot() 函数与此 outputsAsTensor 变量以及要编码的最大类别数(即 CLASS_NAMES.length)一起使用。经过 one-hot 编码的输出现在存储在一个名为 oneHotOutputs 的新张量中。

请注意,目前 trainingDataInputs 是记录的张量数组。为了将这些张量用于训练,您需要将张量数组转换为常规的二维张量。

为此,TensorFlow.js 库中提供了一个名为 tf.stack() 的出色函数,

该函数接受一个张量数组,并将它们堆叠在一起,以生成更高维度的张量作为输出。在这种情况下,系统会返回一个 2D 张量,这是一个包含所记录特征的 1 维输入批次,每个输入的长度为 1024,这正是训练所需的。

接下来,await model.fit() 训练自定义模型头。在此处,您需要传递 inputsAsTensor 变量以及 oneHotOutputs,分别表示要用作示例输入和目标输出的训练数据。在第 3 个参数的配置对象中,将 shuffle 设置为 true,使用 5batchSize,并将 epochs 设置为 10,然后为 onEpochEnd 指定一个 callback 给您稍后将定义的 logProgress 函数。

最后,由于模型现在已训练完毕,您可以处置创建的张量。然后,您可以将 predict 重新设置为 true,以允许再次进行预测,然后调用 predictLoop() 函数开始预测实时摄像头图像。

您还可以定义 logProcess() 函数来记录训练状态,该函数在上面的 model.fit() 中使用,并在每轮训练后将结果输出到控制台。

即将大功告成!现在可以添加 predictLoop() 函数来进行预测了。

核心预测循环

在此处,您将实现主预测循环,该循环会从摄像头中抽样帧,并持续预测每个帧中的内容,同时在浏览器中实时显示结果。

我们来检查一下代码:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

首先,检查 predict 是否为 true,以便仅在模型经过训练并可供使用后才进行预测。

接下来,您可以像在 dataGatherLoop() 函数中一样,获取当前图片的图片特征。从本质上讲,您可以使用 tf.browser.from pixels() 从网络摄像头中抓取一帧,对其进行归一化处理,将其大小调整为 224x224 像素,然后将该数据传递给 MobileNet 模型,以获得生成的图片特征。

不过,现在您可以使用新训练的模型头,通过训练后模型的 predict() 函数传递刚刚找到的 imageFeatures 结果,来实际执行预测。然后,您可以压缩生成的张量,使其再次变为 1 维,并将其分配给名为 prediction 的变量。

借助此 prediction,您可以使用 argMax() 找到具有最高值的索引,然后使用 arraySync() 将生成的张量转换为数组,以获取 JavaScript 中的底层数据,从而发现具有最高值的元素的位置。此值存储在名为 highestIndex 的变量中。

您还可以通过直接对 prediction 张量调用 arraySync(),以相同的方式获取实际的预测置信度得分。

现在,您已拥有使用 prediction 数据更新 STATUS 文本所需的一切。如需获取类的直观易懂的字符串,您只需在 CLASS_NAMES 数组中查找 highestIndex,然后从 predictionArray 中获取置信度值。为了让结果更易于理解(以百分比形式显示),只需将结果乘以 100 并math.floor()

最后,您可以在准备就绪后再次使用 window.requestAnimationFrame() 调用 predictionLoop(),以获取视频流的实时分类结果。如果您选择使用新数据训练新模型,此过程会一直持续到 predict 设置为 false

这样,您就完成了拼图的最后一块。实现重置按钮。

15. 实现重置按钮

即将完成!最后一块拼图是实现一个重置按钮,以便重新开始。您目前为空的 reset() 函数的代码如下所示。请继续按如下方式更新:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

首先,将 predict 设置为 false,以停止所有正在运行的预测循环。接下来,将 examplesCount 数组的长度设置为 0,以删除其中的所有内容。这是一种从数组中清除所有内容的便捷方法。

现在,遍历所有当前记录的 trainingDataInputs,并确保您对其中包含的每个张量进行 dispose(),以再次释放内存,因为 JavaScript 垃圾回收器不会清理张量。

完成此操作后,您现在可以安全地将 trainingDataInputstrainingDataOutputs 数组的数组长度设置为 0,以清除这些数组。

最后,将 STATUS 文本设置为有意义的内容,并输出内存中剩余的张量以进行健全性检查。

请注意,由于 MobileNet 模型和您定义的多层感知器均未处置,因此内存中仍会存在数百个张量。如果您决定在重置后重新训练,则需要将这些模型与新的训练数据一起重新使用。

16. 我们来试用一下

现在,您可以测试自己的 Teachable Machine 版本了!

前往实时预览,启用网络摄像头,为房间中的某个对象收集至少 30 个类别 1 的样本,然后为另一个对象收集相同数量的类别 2 的样本,点击“训练”,然后查看控制台日志以了解进度。训练速度应该很快:

bf1ac3cc5b15740.gif

训练完成后,将物体对准摄像头,即可获得实时预测结果,这些结果将显示在网页顶部附近的状态文本区域中。如果您遇到问题,请查看我已完成的有效代码,看看您是否遗漏了任何复制内容。

17. 恭喜

恭喜!您刚刚在浏览器中实时完成了第一个使用 TensorFlow.js 的迁移学习示例。

试一试各种各样的对象,您可能会注意到,有些对象比其他对象更难识别,尤其是当它们与某些其他对象相似时。您可能需要添加更多类别或训练数据,才能区分它们。

回顾

在此 Codelab 中,您学习了以下内容:

  1. 什么是迁移学习,以及它相对于训练完整模型的优势。
  2. 如何从 TensorFlow Hub 获取可重复使用的模型。
  3. 如何设置适合迁移学习的 Web 应用。
  4. 如何加载和使用基础模型来生成图片特征。
  5. 如何训练新的预测头,使其能够识别网络摄像头图像中的自定义对象。
  6. 如何使用生成的模型实时对数据进行分类。

后续操作

您已经有了一个基准工作版本,那么,您可以想出哪些创意来扩展此机器学习模型的样板,以处理现实中的用例?也许您可以彻底改变您目前所在的行业,帮助您公司的员工训练模型来对日常工作中重要的事物进行分类?这里将拥有无限可能

如需进一步了解相关知识,不妨免费学习这门完整课程,该课程将向您展示如何将您目前在本 Codelab 中使用的 2 个模型合并为一个模型,以提高效率。

如果您想详细了解原始的可学习的机器应用背后的理论,请查看此教程

与我们分享您的成果

您也可以轻松将今天的成果扩展到其他创意用例中,建议您跳出思维定式,做到持续改进。

请记得使用 #MadeWithTFJS 标签在社交媒体上标记我们,这样就有机会在我们的 TensorFlow 博客甚至未来的活动中展示您的项目。我们很期待看到您的成果。

参考网站