1. 准备工作
在过去的几年里,TensorFlow.js 模型的使用率呈指数级增长,许多 JavaScript 开发者现在希望利用现有的先进模型,重新训练它们处理其行业独有的自定义数据。采用现有模型(通常称为“基本模型”)并将其用于相似但不同的领域这一行为称为“迁移学习”。
与从完全空白的模型开始学习相比,迁移学习具有诸多优势。您可以重复使用从先前训练的模型中学到的知识,并且需要较少的待分类新项示例。此外,由于只需重新训练模型架构的最后几层而不是整个网络,训练速度通常要快得多。因此,迁移学习非常适合网络浏览器环境,在这种环境中,资源可能会因执行设备而异,但也可以直接访问传感器,以便轻松获取数据。
此 Codelab 将向您展示如何在空白画布上构建 Web 应用,重现 Google 热门“Teachable Machine”网站。通过该网站,您可以创建一个可正常运行的 Web 应用,任何用户只需使用来自他们的网络摄像头的一些示例图片,即可使用该应用来识别自定义对象。我们特意将网站设计得非常精简,以便您专注于此 Codelab 的机器学习方面的内容。不过,与原来的 Teachable Machine 网站一样,利用您现有的 Web 开发者经验来改善用户体验,仍有很大的余地。
前提条件
此 Codelab 面向以下 Web 开发者:在一定程度上熟悉 TensorFlow.js 预制模型和基本 API 用法,并希望开始在 TensorFlow.js 中开始使用迁移学习的 Web 开发者。
- 本实验假设您对 TensorFlow.js、HTML5、CSS 和 JavaScript 有基本的了解。
如果您刚开始接触 Tensorflow.js,不妨先学习这个免费的零入门课程,该课程假定您没有机器学习或 TensorFlow.js 的背景,会以较小的步骤教您您需要了解的所有内容。
学习内容
- 什么是 TensorFlow.js,以及为什么您应该在下一个 Web 应用中使用 TensorFlow.js。
- 如何构建简化的 HTML/CSS /JS 网页,以重现 Teachable Machine 用户体验。
- 如何使用 TensorFlow.js 加载预训练的基本模型(特别是 MobileNet),以生成可用于迁移学习的图像特征。
- 如何通过用户的摄像头收集您想要识别的多类数据的数据。
- 如何创建和定义多层感知器,以接受图像特征并学习如何使用特征对新对象进行分类。
开始入侵...
所需条件
- 您最好使用 Glitch.com 账号,或者使用您方便自行修改和运行的 Web 服务环境。
2. 什么是 TensorFlow.js?
TensorFlow.js 是一个开源机器学习库,可在任何 JavaScript 环境下运行。它基于以 Python 编写的原始 TensorFlow 库,旨在为 JavaScript 生态系统重新创建这种开发者体验和一组 API。
它的使用范围如何?
鉴于 JavaScript 的可移植性,您现在可以使用 1 种语言编写,并轻松地在以下所有平台上执行机器学习:
- 网络浏览器中的客户端(使用原始 JavaScript)
- 服务器端,甚至是 Raspberry Pi 等 IoT 设备,使用 Node.js
- 使用 Electron 的桌面应用
- 使用 React Native 的原生移动应用
TensorFlow.js 还支持在这些环境(例如,它可以在 CPU 或 WebGL 等实际基于硬件的环境)中执行多个后端。“后端”这里的环境并不意味着服务器端环境(例如,在 WebGL 中用于执行的后端可以是客户端),以便确保兼容性并保持快速运行。TensorFlow.js 目前支持:
- 在设备的显卡 (GPU) 上执行 WebGL - 这是通过 GPU 加速执行较大模型(大小超过 3MB)的最快方式。
- 在 CPU 上执行 Web Assembly (WASM) - 提高各种设备的 CPU 性能,包括老一代手机。这更适合较小的模型(小于 3MB),因为将内容上传到图形处理器的开销实际上比使用 WebGL 时在 CPU 上执行得更快。
- CPU 执行 - 如果其他环境均不可用,则备用环境。这是三个速度中速度最慢的,但是您可以一直保持下去。
注意:如果您知道要在哪种设备上执行,则可以选择强制使用其中一个后端;如果您不指定,也可以直接让 TensorFlow.js 为您决定。
客户端超能力
在客户端计算机的网络浏览器中运行 TensorFlow.js 可带来诸多好处,值得考虑。
隐私权
您可以在客户端计算机上训练和对数据进行分类,而无需将数据发送到第三方网络服务器。有时,您可能需要遵守当地法律(例如 GDPR),或者在处理用户可能希望保留在其计算机上而不发送给第三方的任何数据时。
速度
由于您不必将数据发送到远程服务器,因此推断(数据分类)可以更快完成。更棒的是,您还可以直接访问设备的传感器,例如摄像头、麦克风、GPS、加速度计等(如果用户授予您访问权限)。
覆盖面和规模
世界上的任何人都可以点击您发送的链接,在浏览器中打开网页,然后利用您的成果。无需使用 CUDA 驱动程序进行复杂的服务器端 Linux 设置,只需使用机器学习系统即可。
费用
没有服务器意味着您只需为托管 HTML、CSS、JS 和模型文件的 CDN 付费。CDN 的成本比保持服务器(可能连接了显卡)全天候运行更便宜。
服务器端功能
利用 TensorFlow.js 的 Node.js 实现可启用以下功能。
全面支持 CUDA
在服务器端,对于显卡加速,您必须安装 NVIDIA CUDA 驱动程序,以使 TensorFlow 能够与显卡搭配使用(这与使用 WebGL 的浏览器不同,无需安装)。但是,有了全面的 CUDA 支持,您可以充分利用显卡较低级别的功能,从而缩短训练和推理时间。性能与 Python TensorFlow 实现相当,因为它们共享同一个 C++ 后端。
模型大小
对于研究中的尖端模型,您使用的模型可能非常大,也可能高达 GB。由于每个浏览器标签页的内存用量限制,这些模型目前无法在网络浏览器中运行。要运行这些较大的模型,您可以在自己的服务器上使用 Node.js,并满足高效运行此类模型所需的硬件规格。
IOT
Raspberry Pi 等热门单板计算机支持 Node.js,这也意味着您也可以在此类设备上执行 TensorFlow.js 模型。
速度
Node.js 是使用 JavaScript 编写的,这意味着它受益于即时编译。这意味着,在使用 Node.js 时,您通常可能会看到性能提升,因为它将在运行时进行优化,尤其是对于您可能正在进行的任何预处理。此案例研究展示了 Hugging Face 如何使用 Node.js 将自然语言处理模型的性能提升至原来的 2 倍。
现在,您已了解 TensorFlow.js 的基础知识、其运行位置以及一些优势,下面我们开始使用 TensorFlow.js 做一些有用的事吧!
3. 迁移学习
迁移学习究竟是什么?
迁移学习涉及利用已学到的知识,帮助学习不同但相似的东西。
我们人类一直都在这样做。您的大脑中蕴含了终生难忘的经历,您可以利用这些经历来帮助识别以前从未见过的新事物。以下面的柳树为例:
根据您在世界上的哪个地方,您以前可能没见过这种树。
但是,如果我请您告诉我以下新图片中是否有柳树,那么您可能很快就能发现它们,即使它们的角度与我向您展示的原始图片略有不同。
你的大脑中已经有许多神经元知道如何识别树状物体,以及其他擅长寻找长直线的神经元。您可以重复使用这些知识来快速对柳树进行分类,柳树是一种具有许多长直垂直分支的树状对象。
同样,如果您的机器学习模型已在某个领域中训练(例如识别图片),则可以重复使用该模型执行其他但相关的任务。
您可以使用 MobileNet 这样的高级模型执行相同的操作,这是一个非常受欢迎的研究模型,可以对 1000 种不同的对象类型执行图像识别。从狗到汽车,它都基于一个名为 ImageNet 的大型数据集进行训练,该数据集包含数百万张已加标签的图片。
在此动画中,您可以看到它在此 MobileNet V1 模型中的大量层:
在训练期间,该模型学会了如何提取对这 1,000 个对象很重要的共同特征,并且它用于识别此类对象的许多较低级别的特征在检测之前从未见过的新对象方面也很有用。毕竟,一切最终只是线条、纹理和形状的组合。
我们来看一个传统的卷积神经网络 (CNN) 架构(类似于 MobileNet),并了解迁移学习如何利用以下经过训练的网络学习新知识。下图显示了 CNN 的典型模型架构,在本例中,该模型架构经过训练,能够识别 0 到 9 的手写数字:
如果您可以将现有训练模型的预训练较低级别层(如左图所示)与右侧所示模型末尾附近的分类层(有时称为模型的分类头)分开,则可以使用较低级别层,根据训练所依据的原始数据为任何指定图片生成输出特征。以下是移除了分类头的同一网络:
假设您尝试识别的新事物也可以利用之前模型学到的此类输出特征,那么很有可能将它们重复用于新目的。
在上图中,这个假设模型是用数字进行训练的,所以我们学到的关于数字的知识可能也适用于字母 a、b 和 c。
现在,您可以添加一个尝试预测 a、b 或 c 的新分类标头,如下所示:
在这里,较低级别的层被冻结且不进行训练,只有新的分类头会自行更新,以从左侧预训练的截断模型提供的特征中学习。
这种行为称为迁移学习,是 Teachable Machine 在后台执行的操作。
您还可以看到,由于只需在网络的最后训练多层感知机,它的训练速度比从头开始训练整个网络快得多。
但是,如何掌握模型的子部分呢?请参阅下一部分了解详情。
4. TensorFlow Hub - 基础模型
查找要使用的合适基本模型
如需像 MobileNet 这样更高级和受欢迎的研究模型,您可以访问 TensorFlow Hub,然后过滤出适合 TensorFlow.js 且使用 MobileNet v3 架构的模型,以找到如下所示的结果:
请注意,其中一些结果属于“图片分类”类型(详细说明见每个模型卡片结果的左上角),其他模型则属于“图片特征向量”。
这些图像特征向量结果本质上是预切掉的 MobileNet 版本,可用于获取图像特征向量而不是最终分类。
此类模型通常称为“基本模型”,然后,你可以使用它添加新的分类头并使用自己的数据训练分类头,采用与上一部分相同的方式执行迁移学习。
接下来要检查的是给定的基本模型,即该模型发布时采用的 TensorFlow.js 格式。如果您打开其中一个特征向量 MobileNet v3 模型的页面,则可以从 JS 文档中看到,它采用图表模型的形式,基于文档中使用 tf.loadGraphModel()
的示例代码段。
还需要注意的是,如果您发现模型采用的是层格式而不是图表格式,则可以选择要冻结哪些层,以及将哪些层取消冻结以进行训练。这在为新任务创建模型(通常称为“转移模型”)时非常有用。不过,目前您将使用本教程的默认图模型类型,大多数 TF Hub 模型在部署时就是采用此类型。如需详细了解如何使用 Layers 模型,请查看从新手到高手的 TensorFlow.js 课程。
迁移学习的优势
相比从头开始训练整个模型架构,使用迁移学习有哪些优势?
首先,训练时间是使用迁移学习方法的关键优势,因为您已经拥有经过训练的基础模型。
其次,由于已经进行过训练,您可以减少显示尝试分类的新事物的样本。
如果您的时间和资源有限,无法收集待分类对象的示例数据,并且需要在收集更多训练数据之前快速构建原型以提高其稳健性,那么这样做真的非常棒。
鉴于需要更少的数据和训练较小网络的速度,迁移学习的资源密集型更少。这使得它非常适合浏览器环境,在现代机器上只需几十秒,而不是数小时、数天或数周进行完整的模型训练。
好的!现在,您已了解迁移学习的精髓,是时候创建您自己的 Teachable Machine 版本了。让我们开始吧!
5. 开始编码
所需条件
- 一款现代网络浏览器。
- 具备 HTML、CSS、JavaScript 和 Chrome 开发者工具的基础知识(查看控制台输出)。
开始编码
我们已为 Glitch.com 或 Codepen.io 创建了样板模板,以供您参考。您只需点击一下,即可将任一模板克隆为此 Codelab 的基础状态。
在 Glitch 上,点击 remix this 按钮,将该文件创建分支,然后创建一组可供编辑的新文件。
或者,在 Codepen 上,点击fork"。
这个非常简单的框架可为您提供以下文件:
- HTML 网页 (index.html)
- 样式表 (style.css)
- 用于编写 JavaScript 代码 (script.js) 的文件
为方便起见,我们在 HTML 文件中为 TensorFlow.js 库添加了一项导入。如下所示:
index.html
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
替代方案:使用您偏好的网页编辑器或在本地处理文件
如果您想下载代码并在本地运行,或使用其他在线编辑器运行,只需在同一目录中创建上面命名的 3 个文件,然后将 Glitch 样板中的代码复制并粘贴到每个文件中即可。
6. 应用 HTML 样板
如何入手?
所有原型都需要一些基本的 HTML 基架,您可在上面呈现自己的发现结果。请立即设置。您将添加:
- 网页的标题。
- 一些描述性文字。
- 状态段落。
- 视频就绪后,用于保存摄像头画面的视频。
- 多个用于启动相机、收集数据或重置体验的按钮。
- 适用于 TensorFlow.js 和 JS 文件的导入,您稍后将编写代码。
打开 index.html
,然后粘贴包含以下内容的现有代码以设置上述功能:
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Transfer Learning - TensorFlow.js</title>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Import the webpage's stylesheet -->
<link rel="stylesheet" href="/style.css">
</head>
<body>
<h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
<p id="status">Awaiting TF.js load</p>
<video id="webcam" autoplay muted></video>
<button id="enableCam">Enable Webcam</button>
<button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
<button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
<button id="train">Train & Predict!</button>
<button id="reset">Reset</button>
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>
<!-- Import the page's JavaScript to do some stuff -->
<script type="module" src="/script.js"></script>
</body>
</html>
分解
让我们对上述一些 HTML 代码进行分解,突出显示您添加的一些关键内容。
- 您为网页标题添加了
<h1>
标记以及 ID 为“status”的<p>
标记。您可以使用系统的不同部分来查看输出,将信息输出到该位置。 - 您添加了一个 ID 为“摄像头”的
<video>
元素,你稍后将在其中呈现你的摄像头视频流 - 您添加了 5 个
<button>
元素。第一个是“enableCam”ID启用相机。接下来的两个按钮具有“dataCollector”类这让您可以收集想要识别的对象的示例图像。您稍后编写的代码会经过精心设计,让您可以添加任意数量的此类按钮,并且它们会自动按预期运行。
请注意,这些按钮还有一个名为 data-1hot 的特殊用户定义属性,其第一个类的整数值从 0 开始。这是您将用来表示特定类的数据的数字索引。索引将用于使用数字表示法(而非字符串)对输出类进行正确编码,因为机器学习模型只能处理数字。
此外,还有一个 data-name 属性,其中包含您要用于该类的人类可读名称,可让您为用户提供更有意义的名称,而不是 1 热编码中的数字索引值。
最后,您会获得一个训练按钮和重置按钮,用于在收集数据后启动训练过程,或分别重置应用。
- 您还添加了 2 个
<script>
导入项。一个用于 TensorFlow.js,另一个用于脚本.js(稍后会定义)。
7. 添加样式
元素默认值
为刚刚添加的 HTML 元素添加样式,以确保其正确呈现。以下是正确添加到位置和大小元素的一些样式。没什么特别的招数。您稍后当然可以对此进行补充,以提供更好的用户体验,就像您在教学机器视频中看到的那样。
style.css
body {
font-family: helvetica, arial, sans-serif;
margin: 2em;
}
h1 {
font-style: italic;
color: #FF6F00;
}
video {
clear: both;
display: block;
margin: 10px;
background: #000000;
width: 640px;
height: 480px;
}
button {
padding: 10px;
float: left;
margin: 5px 3px 5px 10px;
}
.removed {
display: none;
}
#status {
font-size:150%;
}
太棒了!这就是您所需的一切。如果您现在预览输出,它应如下所示:
8. JavaScript:关键常量和监听器
定义键常量
首先,添加您将在整个应用中使用的一些关键常量。首先,将 script.js
的内容替换为以下常量:
script.js
const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];
我们来详细了解一下它们的用途:
STATUS
只是保存对您要向其写入状态更新的段落标记的引用。VIDEO
存储对将呈现摄像头画面的 HTML 视频元素的引用。ENABLE_CAM_BUTTON
、RESET_BUTTON
和TRAIN_BUTTON
用于从 HTML 页面中获取所有按键按钮的 DOM 引用。MOBILE_NET_INPUT_WIDTH
和MOBILE_NET_INPUT_HEIGHT
分别定义了 MobileNet 模型的预期输入宽度和高度。将此值存储在靠近文件顶部的常量中,如果您稍后决定使用其他版本,就可以更轻松地更新一次值,而不必在许多不同的地方进行替换。STOP_DATA_GATHER
设置为 - 1。这会存储一个状态值,以便您了解用户何时停止点击用于从摄像头 Feed 收集数据的按钮。为此编号指定更有意义的名称,可以提高代码的可读性。CLASS_NAMES
用作查找,并保存可能的类别预测的人类可读名称。稍后将填充此数组。
现在,您已经有了对关键元素的引用,是时候将一些事件监听器与这些引用相关联了。
添加关键事件监听器
首先,向按键按钮添加点击事件处理脚本,如下所示:
script.js
ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);
function enableCam() {
// TODO: Fill this out later in the codelab!
}
function trainAndPredict() {
// TODO: Fill this out later in the codelab!
}
function reset() {
// TODO: Fill this out later in the codelab!
}
ENABLE_CAM_BUTTON
- 点击后会调用 enableCam 函数。
TRAIN_BUTTON
- 点击后会调用 trainAndPredict。
RESET_BUTTON
- 点击后通话会被重置。
最后,在此部分中,您可以找到具有“dataCollector”类的所有按钮使用 document.querySelectorAll()
。这将返回一个从文档中找到的匹配的元素数组:
script.js
let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
// Populate the human readable names for classes.
CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}
function gatherDataForClass() {
// TODO: Fill this out later in the codelab!
}
代码说明:
然后,您可以遍历找到的按钮,并将 2 个事件监听器与每个按钮相关联。一个用于“mousedown”,一个用于“mouseup”。这样一来,只要按下按钮,您就可以一直录制样本,这对于数据收集非常有用。
这两个事件都会调用您稍后定义的 gatherDataForClass
函数。
此时,您还可以将找到的人类可读类名称从 HTML 按钮属性 data-name 推送到 CLASS_NAMES
数组。
接下来,添加一些变量,用于存储稍后将会用到的关键内容。
script.js
let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;
下面我们将逐一进行介绍。
首先,您需要一个变量 mobilenet
来存储加载的 Mobilenet 模型。最初将其设为“undefined”。
接下来,您有一个名为 gatherDataState
的变量。如果“dataCollector”按钮时,该 ID 会变为该按钮的常用 ID(如 HTML 中所定义),这样您就知道此时您正在收集的数据类。最初,此属性设置为 STOP_DATA_GATHER
,以便您稍后编写的数据收集循环不会在用户未按下任何按钮时收集任何数据。
videoPlaying
会跟踪摄像头直播是否已成功加载、播放以及是否可供使用。初始状态下,此政策设为 false
,因为摄像头在您按 ENABLE_CAM_BUTTON.
之前不会开启
接下来,定义 2 个数组:trainingDataInputs
和 trainingDataOutputs
。当您点击“dataCollector”时按钮,分别对输出类进行采样。
然后定义最后一个数组 examplesCount,
,以跟踪您开始添加每个类后包含的样本数。
最后,您有一个名为 predict
的变量来控制预测循环。此属性最初设置为 false
。除非稍后将此值设置为 true
,否则无法进行预测。
现在,所有关键变量均已定义,接下来我们加载预先切好的 MobileNet v3 基本模型,该模型提供图像特征向量,而不是分类。
9. 加载 MobileNet 基本模型
首先,定义一个名为 loadMobileNetFeatureModel
的新函数,如下所示。这必须是一个异步函数,因为加载模型的操作是异步的:
script.js
/**
* Loads the MobileNet model and warms it up so ready for use.
**/
async function loadMobileNetFeatureModel() {
const URL =
'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
STATUS.innerText = 'MobileNet v3 loaded successfully!';
// Warm up the model by passing zeros through it once.
tf.tidy(function () {
let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
console.log(answer.shape);
});
}
// Call the function immediately to start loading.
loadMobileNetFeatureModel();
在此代码中,您将定义 URL
,其中要加载的模型位于 TFHub 文档中。
然后,您可以使用 await tf.loadGraphModel()
加载模型。从这个 Google 网站加载模型时,请记得将特殊属性 fromTFHub
设置为 true
。这是一种特殊情况,仅适用于使用托管在 TF Hub 上的模型,在这种情况下,必须设置此额外属性。
加载完成后,您可以使用消息设置 STATUS
元素的 innerText
,以便直观地看到它已正确加载,并可以开始收集数据。
现在唯一要做的就是预热模型。对于像这样的较大模型,首次使用时,可能需要一些时间完成所有设置。因此,它有助于在模型中传递零,以避免将来时间可能更为重要的等待。
您可以使用封装在 tf.tidy()
中的 tf.zeros()
来确保正确处置张量,批次大小为 1,且高度和宽度是您一开始就在常量中定义的正确宽度和宽度。最后,您还需要指定颜色通道,在本例中为 3,因为模型需要 RGB 图片。
接下来,记录使用 answer.shape()
返回的张量的生成形状,以帮助您了解此模型生成的图像特征的大小。
定义此函数后,您可以立即调用该函数,以在页面加载时启动模型下载。
如果您现在查看实时预览,片刻之后,您会看到状态文本从“正在等待 TF.js 加载”变为如下所示。请先确保此操作可正常运行,然后再继续。
您还可以查看控制台输出,了解此模型生成的输出特征的打印大小。通过 MobileNet 模型运行零之后,您会看到输出的形状为 [1, 1024]
。第一项的批次大小为 1,您可以看到它实际返回了 1024 个特征,这些特征随后可用于帮助您对新对象进行分类。
10. 定义新模型头部
现在,需要定义模型头部,它本质上是一个非常简约的多层感知机。
script.js
let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));
model.summary();
// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
// Adam changes the learning rate over time which is useful.
optimizer: 'adam',
// Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
// Else categoricalCrossentropy is used if more than 2 classes.
loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
// As this is a classification problem you can record accuracy in the logs too!
metrics: ['accuracy']
});
我们来了解一下此代码。首先,定义一个要添加模型层的 tf.series 模型。
接下来,添加一个密集层作为此模型的输入层。它的输入形状为 1024
,因为 MobileNet v3 功能的输出大小达到此大小。您通过模型传递数据后,在上一步中发现了这一点。该层有 128 个神经元,这些神经元使用 ReLU 激活函数。
如果您刚开始接触激活函数和模型层,请考虑学习本研讨会开头详述的课程,了解这些属性的幕后用途。
要添加的下一层是输出层。神经元的数量应与您尝试预测的类别的数量相等。为此,您可以使用 CLASS_NAMES.length
查找要分类的类别数量,这等于界面中的数据收集按钮数量。由于这是一个分类问题,因此您可以在此输出层上使用 softmax
激活,而在尝试创建模型(而不是回归)时,必须使用该激活。
现在,输出 model.summary()
,以将新定义的模型的概览输出到控制台。
最后,编译模型,使其准备好接受训练。这里的优化器设置为 adam
,如果 CLASS_NAMES.length
等于 2
,则损失将为 binaryCrossentropy
,或者如果有 3 个或更多要分类的类别,则损失将使用 categoricalCrossentropy
。系统会请求准确性指标,以便稍后可以在日志中监控这些指标以进行调试。
在控制台中,您应该会看到如下内容:
请注意,该模型有超过 13 万个可训练参数。但是,由于这是一个由常规神经元组成的简单密集层,因此它会训练非常快。
作为项目完成后要做的活动,您可以尝试更改第一层中的神经元数量,看看可以达到多低,同时仍能获得良好的性能。使用机器学习时,通常需要进行一定程度的试错,才能找到最优参数值,以便在资源用量和速度之间实现最佳平衡。
11. 启用摄像头
现在是时候充实您之前定义的 enableCam()
函数了。添加一个名为 hasGetUserMedia()
的新函数(如下所示),然后将之前定义的 enableCam()
函数的内容替换为下面的相应代码。
script.js
function hasGetUserMedia() {
return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}
function enableCam() {
if (hasGetUserMedia()) {
// getUsermedia parameters.
const constraints = {
video: true,
width: 640,
height: 480
};
// Activate the webcam stream.
navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
VIDEO.srcObject = stream;
VIDEO.addEventListener('loadeddata', function() {
videoPlaying = true;
ENABLE_CAM_BUTTON.classList.add('removed');
});
});
} else {
console.warn('getUserMedia() is not supported by your browser');
}
}
首先,创建一个名为 hasGetUserMedia()
的函数,通过检查是否存在关键的浏览器 API 属性来检查浏览器是否支持 getUserMedia()
。
在 enableCam()
函数中,使用您刚刚定义的 hasGetUserMedia()
函数来检查其是否受支持。如果不是,请向控制台输出警告。
如果它支持,请为 getUserMedia()
调用定义一些约束条件,例如,您只想使用视频流,并且希望视频的 width
大小为 640
像素,height
为 480
像素。为什么呢?让视频变大没有什么意义,因为需要将它调整到 224 x 224 像素才能馈送到 MobileNet 模型。不妨请求较低的分辨率,以节省一些计算资源。大多数相机都支持此大小的分辨率。
接下来,使用上面详述的 constraints
调用 navigator.mediaDevices.getUserMedia()
,然后等待 stream
返回。返回 stream
后,您可以将 VIDEO
元素设置为其 srcObject
值,从而播放 stream
。
您还应在 VIDEO
元素上添加 eventListener,以了解 stream
何时已加载并成功播放。
steam 加载后,您可以将 videoPlaying
设置为 true 并删除 ENABLE_CAM_BUTTON
,通过将其类设置为 "removed
来防止其再次点击。
现在,运行您的代码,点击“启用摄像头”按钮,然后授予摄像头使用权限。如果这是您第一次执行此操作,您应该会看到自己呈现给页面上的视频元素,如下所示:
好的,现在该添加一个函数来处理 dataCollector
按钮点击了。
12. 数据收集按钮事件处理脚本
现在,需要填充当前名为 gatherDataForClass().
的空函数。这是您在本 Codelab 开始时为 dataCollector
按钮指定的事件处理函数函数。
script.js
/**
* Handle Data Gather for button mouseup/mousedown.
**/
function gatherDataForClass() {
let classNumber = parseInt(this.getAttribute('data-1hot'));
gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
dataGatherLoop();
}
首先,使用属性名称(在本例中为 data-1hot
作为参数)调用 this.getAttribute()
,以检查当前点击的按钮上的 data-1hot
属性。由于这是一个字符串,因此您可以使用 parseInt()
将其转换为整数,并将此结果赋值给名为 classNumber.
的变量
接下来,相应地设置 gatherDataState
变量。如果当前的 gatherDataState
等于 STOP_DATA_GATHER
(您将其设为 -1),则表示您目前未收集任何数据,并且是触发的 mousedown
事件。将 gatherDataState
设置为您刚刚找到的 classNumber
。
否则,这意味着您目前在收集数据,而触发的事件是 mouseup
事件,而您现在想要停止收集该类的数据。只需将其设置回 STOP_DATA_GATHER
状态,即可结束您稍后要定义的数据收集循环。
最后,启动对 dataGatherLoop(),
的调用,该调用实际执行类数据的记录。
13. 数据收集
现在,定义 dataGatherLoop()
函数。此函数负责对网络摄像头视频中的图像进行采样,通过 MobileNet 模型传递图像,并捕获该模型的输出(1024 个特征向量)。
然后,它会存储这些信息以及当前按下按钮的 gatherDataState
ID,以便您了解这些数据表示什么类。
我们详细了解一下:
script.js
function dataGatherLoop() {
if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
let imageFeatures = tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let normalizedTensorFrame = resizedTensorFrame.div(255);
return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
});
trainingDataInputs.push(imageFeatures);
trainingDataOutputs.push(gatherDataState);
// Intialize array index element if currently undefined.
if (examplesCount[gatherDataState] === undefined) {
examplesCount[gatherDataState] = 0;
}
examplesCount[gatherDataState]++;
STATUS.innerText = '';
for (let n = 0; n < CLASS_NAMES.length; n++) {
STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
}
window.requestAnimationFrame(dataGatherLoop);
}
}
仅当 videoPlaying
为 true(即,摄像头处于活动状态,且 gatherDataState
不等于 STOP_DATA_GATHER
)且当前已按下用于收集课程数据的按钮时,您才会继续执行此函数。
接下来,将您的代码封装在 tf.tidy()
中,以处置后续代码中已创建的任何张量。此 tf.tidy()
代码执行的结果存储在名为 imageFeatures
的变量中。
您现在可以使用 tf.browser.fromPixels()
抓取摄像头 VIDEO
的帧。包含图片数据的结果张量存储在名为 videoFrameAsTensor
的变量中。
接下来,调整 videoFrameAsTensor
变量的大小,使其具有适合 MobileNet 模型输入的正确形状。使用 tf.image.resizeBilinear()
调用,将要调整形状的张量作为第一个参数,然后使用一个形状来定义新的高度和宽度(如之前创建的常量所定义)。最后,通过传递第三个参数将对齐角设置为 true,以避免在调整大小时出现任何对齐问题。此大小调整的结果存储在名为 resizedTensorFrame
的变量中。
请注意,此基元调整大小会拉伸图片,因为摄像头图片的大小为 640 x 480 像素,而模型需要 224 x 224 像素的方形图片。
就本演示而言,这应当能够正常工作。不过,完成此 Codelab 后,您可能需要尝试从此图片中剪裁一个方形,以便之后创建的任何生产系统都能获得更好的效果。
接下来,标准化图片数据。使用 tf.browser.frompixels()
时,图片数据始终位于 0 到 255 的范围内,因此只需将尺寸调整大小 TensorFrame 除以 255,即可确保所有值都介于 0 和 1 之间,这正是 MobileNet 模型的输入值。
最后,在代码的 tf.tidy()
部分中,通过调用 mobilenet.predict()
向加载的模型推送此归一化张量,您将使用 expandDims()
向该模型传递 normalizedTensorFrame
的扩展版本,使其是一个 1 的批次,因为该模型需要一批输入进行处理。
结果返回后,您可以立即对返回的结果调用 squeeze()
,以将其压缩为一维张量,然后返回该张量并将其赋值给从 tf.tidy()
捕获结果的 imageFeatures
变量。
现在,您已获得 MobileNet 模型中的 imageFeatures
,可以将它们推送到您之前定义的 trainingDataInputs
数组上,以记录这些输入。
您还可以通过将当前的 gatherDataState
也推送到 trainingDataOutputs
数组来记录此输入表示的内容。
请注意,当您在先前定义的 gatherDataForClass()
函数中点击相应按钮时,gatherDataState
变量会设为您要记录数据的当前类的数字 ID。
此时,您还可以增加给定类的样本数量。为此,请先检查 examplesCount
数组中的索引之前是否已初始化。如果未定义,请将其设置为 0 以初始化给定类的数字 ID 的计数器,然后为当前 gatherDataState
递增 examplesCount
。
现在,更新网页上 STATUS
元素的文本,以便在捕获每个类时显示每个类的当前计数。为此,请循环遍历 CLASS_NAMES
数组,然后输出直观易懂的名称以及 examplesCount
中相同索引下的数据量。
最后,使用 dataGatherLoop
作为参数传递来调用 window.requestAnimationFrame()
,以递归方式再次调用此函数。这将继续对视频中的帧进行采样,直到检测到按钮的 mouseup
且 gatherDataState
设置为 STOP_DATA_GATHER,
,此时数据收集循环将结束。
如果您现在运行代码,您应该能够点击“启用相机”按钮,等待网络摄像头加载,然后点击并按住每个数据收集按钮,以收集每类数据的样本。在这里,你可以看到我分别收集了我的手机和我的手的数据。
您应该会看到状态文本已更新,因为它将所有张量存储在内存中,如上面的屏幕截图所示。
14. 训练和预测
下一步是为当前为空的 trainAndPredict()
函数实现代码,这就是迁移学习的发生位置。我们来看一下代码:
script.js
async function trainAndPredict() {
predict = false;
tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
let inputsAsTensor = tf.stack(trainingDataInputs);
let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
callbacks: {onEpochEnd: logProgress} });
outputsAsTensor.dispose();
oneHotOutputs.dispose();
inputsAsTensor.dispose();
predict = true;
predictLoop();
}
function logProgress(epoch, logs) {
console.log('Data for epoch ' + epoch, logs);
}
首先,请务必将 predict
设置为 false
,以确保停止执行任何当前预测。
接下来,使用 tf.util.shuffleCombo()
重排输入和输出数组,以确保顺序不会导致训练出现问题。
将输出数组 trainingDataOutputs,
转换为 int32 类型的 tensor1d,以便随时在独热编码中使用它。它存储在名为 outputsAsTensor
的变量中。
将 tf.oneHot()
函数与此 outputsAsTensor
变量以及要编码的类数量上限(即 CLASS_NAMES.length
)搭配使用。您的独热编码输出现在存储在名为 oneHotOutputs
的新张量中。
请注意,目前 trainingDataInputs
是记录的张量数组。为了将这些张量用于训练,您需要将张量数组转换为常规的 2D 张量。
为此,TensorFlow.js 库中有一个名为 tf.stack()
的出色函数,
该方法接受一个张量数组,然后将它们堆叠在一起,以生成更高维度的张量作为输出。在本例中,系统返回了一个 2D 张量,这是一批一维输入,每个输入的长度为 1024 个,包含记录的特征,这就是训练所需的内容。
接下来,await model.fit()
以训练自定义模型头部。在这里,您将传递 inputsAsTensor
变量和 oneHotOutputs
,以分别表示要用于示例输入和目标输出的训练数据。在第 3 个参数的配置对象中,将 shuffle
设为 true
,使用 5
的 batchSize
,将 epochs
设为 10
,然后为稍后定义的 logProgress
函数指定 onEpochEnd
的 callback
。
最后,在模型训练完成时,您可以处置创建的张量。然后,您可以将 predict
设置回 true
以允许再次进行预测,然后调用 predictLoop()
函数以开始预测实时摄像头图片。
您还可以定义 logProcess()
函数来记录训练状态,该函数在上面的 model.fit()
中使用,会在每轮训练后将结果输出到控制台。
即将大功告成!现在该添加 predictLoop()
函数以进行预测了。
核心预测循环
在此示例中,您将实现主预测循环,该循环会对摄像头的帧进行采样,并在浏览器中通过实时结果持续预测每一帧中的内容。
我们来检查一下代码:
script.js
function predictLoop() {
if (predict) {
tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
let prediction = model.predict(imageFeatures).squeeze();
let highestIndex = prediction.argMax().arraySync();
let predictionArray = prediction.arraySync();
STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
});
window.requestAnimationFrame(predictLoop);
}
}
首先,检查 predict
是否为 true,以便仅在模型训练完成后进行预测,并可供使用。
接下来,您可以获取当前图片的图片特征,就像在 dataGatherLoop()
函数中进行操作一样。从本质上讲,您需要使用 tf.browser.from pixels()
从摄像头抓取一帧,对其进行归一化,将其大小调整为 224 x 224 像素,然后通过 MobileNet 模型传递该数据,以获取生成的图片特征。
不过,您现在可以使用新训练的模型头部来实际执行预测,方法是通过经过训练的模型的 predict()
函数传递刚刚找到的 imageFeatures
。然后,您可以压缩生成的张量,使其再次变为一维状态,并将其赋值给一个名为 prediction
的变量。
借助此 prediction
,您可以使用 argMax()
查找具有最高值的索引,然后使用 arraySync()
将生成的张量转换为数组,以获取 JavaScript 中的底层数据,从而发现最高值元素的位置。此值存储在名为 highestIndex
的变量中。
您也可以直接对 prediction
张量调用 arraySync()
,以同样的方式获得实际预测置信度分数。
您现在已经有了使用 prediction
数据更新 STATUS
文本所需的一切。如需获取类的直观易懂的字符串,您可以在 CLASS_NAMES
数组中查找 highestIndex
,然后从 predictionArray
获取置信度值。若要使其以百分比的形式更易读懂,只需乘以 100 并对结果使用 math.floor()
即可。
最后,在准备就绪后,您可以使用 window.requestAnimationFrame()
再次调用 predictionLoop()
,从而对视频流进行实时分类。如果您选择使用新数据训练新模型,此过程会一直持续到 predict
设为 false
。
这就是谜题的最后一块。实现重置按钮。
15. 实现重置按钮
即将完成!最后一点是,实现重置按钮以重新开始。您当前为空的 reset()
函数的代码如下所示。接下来,按如下所示进行更新:
script.js
/**
* Purge data and start over. Note this does not dispose of the loaded
* MobileNet model and MLP head tensors as you will need to reuse
* them to train a new model.
**/
function reset() {
predict = false;
examplesCount.length = 0;
for (let i = 0; i < trainingDataInputs.length; i++) {
trainingDataInputs[i].dispose();
}
trainingDataInputs.length = 0;
trainingDataOutputs.length = 0;
STATUS.innerText = 'No data collected';
console.log('Tensors in memory: ' + tf.memory().numTensors);
}
首先,将 predict
设置为 false
,以停止任何正在运行的预测循环。接下来,将 examplesCount
数组的长度设置为 0 以删除其中的所有内容,这是一种清除数组中所有内容的便捷方式。
现在,请查看当前记录的所有 trainingDataInputs
,并确保对其中包含的每个张量执行 dispose()
操作,以再次释放内存,因为 JavaScript 垃圾回收器不会清理张量。
完成此操作后,您现在可以安全地将 trainingDataInputs
和 trainingDataOutputs
数组上的数组长度设置为 0,以清除这些内容。
最后,将 STATUS
文本设置为合理的内容,并输出留在内存中的张量作为健全性检查。
请注意,由于您定义的 MobileNet 模型和多层感知机都未被处理,因此内存中仍有数百个张量。如果您决定在重置后重新训练,则需要将新训练集重复用于新的训练数据。
16. 我们来试试看
是时候测试你自己的 Teachable Machine 版本了!
前往实时预览页面,启用摄像头,针对房间中的某个类别收集至少 30 个类别 1 的样本,然后对类别 2 中的其他对象执行相同的操作,点击“训练”,并检查控制台日志以查看进度。它的训练应该非常快:
训练完成后,将对象显示给相机以获取实时预测,这些预测将输出到网页上顶部附近的状态文本区域。如果您遇到问题,请检查我已完成的工作代码,看看是否遗漏了任何复制内容。
17. 恭喜
恭喜!您刚刚在浏览器中使用 TensorFlow.js 完成了第一个迁移学习示例。
试试看,用各种对象测试一下,您可能会注意到有些元素比其他对象更难识别,尤其是当它们与其他事物相似时。您可能需要添加更多类别或训练数据,以便进行区分。
回顾
在此 Codelab 中,您学习了以下内容:
- 什么是迁移学习,以及它与训练完整模型相比有哪些优势。
- 如何从 TensorFlow Hub 获取可重复使用的模型。
- 如何设置适合迁移学习的 Web 应用。
- 如何加载和使用基本模型来生成图像特征。
- 如何训练能够识别摄像头图像中的自定义对象的新预测头。
- 如何使用生成的模型实时对数据进行分类。
后续操作
现在,您已经有了工作基础,接下来您能想出哪些创意来扩展此机器学习模型样板文件,以将其用于您可能正在处理的实际用例中?或许你可以彻底改变目前所在的行业,帮助公司员工训练出模型,对日常工作中的重要事项进行分类。有无限可能。
要想更进一步,不妨考虑免费学习此完整课程,其中介绍了如何将此 Codelab 中现有的 2 个模型组合成 1 个模型,以提高效率。
此外,如果您想进一步了解原始可训练机器应用背后的理论,请参阅此教程。
与我们分享您的成果
你还可以轻松地将你当前的应用扩展到其他创意应用场景,我们建议你跳出常规,不断创新。
请记得在社交媒体上使用 #MadeWithTFJS # 标签给我们加上标签,这样您的项目就有机会在我们的 TensorFlow 博客甚至将来的活动上获得特别推介。我们很期待看到你的成果。
可以结账的网站
- TensorFlow.js 官方网站
- TensorFlow.js 预制模型
- TensorFlow.js API
- TensorFlow.js 展示与讲述 — 汲取灵感,看看别人做了什么。