TensorFlow.js:打造自己的個人風格 "訓練學習機器";透過 TensorFlow.js 使用遷移學習

1. 事前準備

過去幾年來,TensorFlow.js 的模型用量大幅增加,許多 JavaScript 開發人員現在尋找現有的最先進模型,並重新訓練這些模型,以處理所屬產業特有的自訂資料。將現有模型 (通常稱為基礎模型) 並用於相似但不同領域,稱為「遷移學習」。

遷移學習比從完全空白的模型開始,有許多優點。您可以重複使用透過先前訓練模型學到的知識,也可以減少想要分類的新項目樣本。此外,由於只需要重新訓練模型架構的最後幾個層,而非整個網路,因此訓練速度通常會大幅加快。因此,如果網路瀏覽器環境的資源可能會因執行裝置而異,但也能直接存取感應器,方便取得資料,因此非常適合遷移學習。

本程式碼研究室將說明如何從空白的畫佈建構網頁應用程式,進而重現 Google 的熱門Teachable Machine網站。您可以透過網站建立功能齊全的網頁應用程式,使用者只需從網路攝影機拍攝幾張圖片,就能辨識自訂物件。這個網站刻意最小化,讓您專注於本程式碼研究室的機器學習層面。不過,與原始 Teachable Machine 網站一樣,您現有的網頁程式開發人員服務仍有很大的應用範圍,以提升使用者體驗。

必要條件

本程式碼研究室是為已十分熟悉 TensorFlow.js 預建模型和基本 API 使用方式的網頁開發人員所編寫,且是想開始使用 TensorFlow.js 中的遷移學習功能。

  • 本研究室假設您對 TensorFlow.js、HTML5、CSS 和 JavaScript 有基本瞭解。

如果您是 Tensorflow.js 新手,建議先參加這堂免費的零為主要課程,課程內容假定機器學習或 TensorFlow.js 沒有背景運作,並且會透過較小的步驟說明所有必學知識。

課程內容

  • 什麼是 TensorFlow.js,以及為何應在下一個網頁應用程式中使用 TensorFlow.js。
  • 如何建構簡化的 HTML/CSS /JS 網頁,重現 Teachable Machine 的使用者體驗。
  • 如何使用 TensorFlow.js 載入預先訓練的基本模型 (尤其是 MobileNet),產生可用於遷移學習的圖片特徵。
  • 如何從使用者的網路攝影機,針對您想辨識的多個資料類別收集資料。
  • 如何建立及定義多層感知,以便取得圖片特徵,並學習使用這些圖片特徵來分類新物件。

開始入侵吧...

軟硬體需求

  • 我們建議使用 Glitch.com 帳戶跟上腳步,或者使用您熟悉的網站放送環境自行編輯與執行。

2. 什麼是 TensorFlow.js?

54e81d02971f53e8.png

TensorFlow.js開放原始碼的機器學習程式庫,可以在 JavaScript 中的任何位置執行。這個程式庫是以以 Python 編寫的原始 TensorFlow 程式庫為基礎,旨在重新建立 JavaScript 生態系統適用的一組開發人員體驗和 API。

哪些地方可以使用?

由於 JavaScript 具備可攜性,現在您可以使用 1 種語言來編寫程式碼,並在下列所有平台上輕鬆執行機器學習:

  • 運用基本 JavaScript,在網路瀏覽器的用戶端
  • 使用 Node.js 的伺服器端,甚至是 Raspberry Pi 等 IoT 裝置
  • 使用 Electron 的電腦版應用程式
  • 使用 React Native 的原生行動應用程式

TensorFlow.js 也支援這些環境中的多個後端,也就是可在 CPU 或 WebGL 等內部執行的實際硬體型環境。「後端」在這個情況下,並不表示伺服器端環境。舉例來說,執行作業的後端可能是 WebGL 中的用戶端,可確保相容性,同時讓執行速度飛快。目前 TensorFlow.js 支援:

  • 在裝置的顯示卡上執行 WebGL (GPU):這是在 GPU 加速功能下執行大型模型 (大小超過 3 MB) 的最快方法。
  • 在 CPU 上執行網路組合 (WASM):提升各裝置的 CPU 效能 (例如較舊的手機)。這種級別更適用於小型模型 (大小小於 3 MB),比起使用 WebGL,實際執行 WASM 的 CPU 執行速度較 WebGL,因為將內容上傳至圖形處理器時須耗用大量資源。
  • CPU 執行:若其他環境都無法使用,備用機制就應該沒有問題。這是最慢的圖示,但系統能隨時為您效勞。

注意:如果你知道要在哪個裝置上執行,可以選擇強制執行其中一個後端;如未指定,也可以直接讓 TensorFlow.js 決定。

用戶端超能力

在用戶端電腦的網路瀏覽器中執行 TensorFlow.js,能夠帶來許多值得評估的好處。

隱私權

您可以在用戶端電腦上訓練及分類資料,而不必將資料傳送至第三方網路伺服器。在某些情況下,可能需要遵守 GDPR 等當地法律,或者在處理使用者可能想保留在電腦上的資料,而不傳送至第三方時。

速度

由於您不需要將資料傳送至遠端伺服器,推論 (資料分類) 可以更快。更棒的是,在使用者授予您存取權的情況下,您可以直接存取裝置的感應器,例如相機、麥克風、GPS、加速計等。

觸及範圍廣大

全球使用者只要按一下您傳送給自己的連結,就能在瀏覽器中開啟網頁,並且使用您製作的內容。您不必透過 CUDA 驅動程式進行複雜的伺服器端 Linux 設定,而不只是透過機器學習系統完成。

費用

不需要伺服器,只要付費使用 CDN 代管 HTML、CSS、JS 和模型檔案即可。CDN 的費用比讓全天候運作的伺服器 (可能附有顯示卡) 便宜許多。

伺服器端功能

運用 TensorFlow.js 的 Node.js 實作可啟用下列功能。

完整的 CUDA 支援

在伺服器端,如要加速顯示顯示卡,您必須安裝 NVIDIA CUDA 驅動程式,才能使用 TensorFlow 搭配顯示卡 (這點與使用 WebGL 的瀏覽器不同,不需要安裝)。不過,享有完整的 CUDA 支援,可充分運用顯示卡層級較低的功能,加快訓練和推論速度。效能與 Python TensorFlow 實作不相上下,因為兩者都共用相同的 C++ 後端。

型號大小

針對研究提供的先進模型,您可能需要處理非常大型的模型 (可能多達 GB)。由於每個瀏覽器分頁的記憶體用量限制,這些模型目前無法在網路瀏覽器中執行。如要執行這類大型模型,您可以在自己的伺服器上使用 Node.js,同時將必要的硬體規格設為有效率地執行這類模型。

IOT

Node.js 支援熱門的單板電腦 (例如 Raspberry Pi),因此您也可以在這類裝置上執行 TensorFlow.js 模型。

速度

Node.js 是以 JavaScript 編寫,因此只需進行時間編譯即可。這表示您在使用 Node.js 時,通常可以發現效能提升,因為 Node.js 會在執行階段進行最佳化,特別是您執行的任何預先處理作業。透過這份個案研究,您可以看出 Hugging Face 如何運用 Node.js,將自然語言處理模型的效能提升 2 倍。

現在您已瞭解 TensorFlow.js 的基本功能、可在何處執行,以及一些優點,讓我們開始運用 TensorFlow.js 有效執行各項工作了!

3. 遷移學習

什麼是遷移學習?

遷移學習是指要習得已學到的知識,幫助學習截然不同的事物。

我們是人類一直以來的努力。腦力激盪不已,你可以協助辨識前所未有的新事物。以這棵柳樹為例:

e28070392cd4afb9.png

視您所在的世界而定,您可能從未看過這種類型的樹。

不過,如果希望你跟我說明新的圖片中是否有柳樹,即使它們處處不同,也和我剛才看到的線有點不同,你可能會很快發現它們。

d9073a0d5df27222.png

腦中已經有許多神經元,知道如何識別樹類類似的物體,以及其他適合尋找長直線的神經元。您可以重複使用這些知識,快速分類柳樹。樹葉是一種像樹木,且有許多長直的垂直分支。

同樣地,如果您的機器學習模型已在某個領域上完成訓練 (例如辨識圖片),則可重複使用該模型執行不同但相關的工作。

MobileNet 是一種相當熱門的研究模型,可在 1000 種不同的物件類型上執行圖片辨識,而 MobileNet 這類進階模型也能讓您執行類似操作。無論是狗還是汽車,都能透過名為 ImageNet 的大型資料集完成訓練,該資料集含有數百萬張已加上標籤的圖片。

在這段動畫中,您可以看到這個 MobileNet V1 模型中的大量圖層:

(7d4e1e35c1a89715.gif)

在訓練期間,此模型學會如何從這 1000 個物件中擷取共同的通用特徵,以及許多用來找出這類物件的低階特徵,進而偵測從未見過的新物件。畢竟,項目最終只是線條、紋理和形狀的組合。

以下將介紹傳統的捲積類神經網路 (CNN) 架構 (類似 MobileNet),說明遷移學習如何運用這個經過訓練的網路來學習新事物。下圖顯示 CNN 的一般模型架構。在此例中,我們訓練了辨識從 0 到 9 的手寫數字:

baf4e3d434576106.png

如果您能夠將現有訓練模型 (如左圖所示) 的預先訓練較低層級,與右側顯示的模型結尾附近的分類層分開 (有時稱為模型的分類頭),可以使用較低層級的圖層,根據訓練時的原始資料產生任何特定圖片的輸出特徵。以下是同一個網路已移除分類標頭的網路:

369a8a9041c6917d.png

假設要辨識的新項目也可能使用先前模型學到的這類輸出特徵,那麼這些輸出內容很有可能重複用於新用途。

在上圖中,這個假設模型是以數字訓練而成,因此所學到的數字可能也適用於 a、b 和 c 等字母。

因此,您現在可以新增嘗試預測 a、b 或 c 的新分類標頭,如下所示:

db97e5e60ae73bbd.png

這裡的較低層級圖層會凍結且未經過訓練,只有新的分類標題會自行更新,以從左側預先訓練的切割模型中學習。

這種做法稱為遷移學習,也是 Teachable Machine 在幕後執行的作業。

您也會發現,只需要在網路的最端訓練多層感知體,訓練速度會比從頭開始訓練整個網路要快得多。

但要怎麼做才能完全掌控模型的組成部分呢?詳情請見下一節。

4. TensorFlow Hub - 基礎模型

尋找合適的基礎模型

如需更進階且熱門的研究模型 (例如 MobileNet),請前往 TensorFlow Hub,然後篩選使用 MobileNet v3 架構適用的 TensorFlow.js 模型,如下所示:

c5dc1420c6238c14.png

請注意,部分結果屬於「圖片分類」類型(詳見每個模型資訊卡結果的左上角),其他類型則是「圖片特徵向量」。

這些「圖片功能向量」結果基本上是預先剪輯的 MobileNet 版本,可用來取得圖片特徵向量,而非最終分類。

這類模型通常稱為「基礎模型」您接著可以按照上一節所示的方式進行遷移學習,方法是新增一個分類標題,並使用您自己的資料進行訓練。

接下來,我們要確認特定基礎模型是以 TensorFlow.js 格式發布時使用的基本模型。如果您開啟其中一個特徵向量 MobileNet v3 模型的網頁,可以在 JS 說明文件中查看該頁面是根據使用 tf.loadGraphModel() 的程式碼片段範例所呈現的圖形模型形式。

f97d903d2e46924b.png

另請注意,如果發現模型是以圖層格式 (而非圖表格式) 呈現,可以選擇要在訓練時凍結哪些圖層,以及要取消凍結的部分。為新工作建立模型時,這項功能非常實用,這項作業通常稱為「轉移模式」。不過,目前您將使用本教學課程的預設圖形模型類型,而大部分的 TF Hub 模型會部署為這種類型。如要進一步瞭解如何使用圖層模型,請參閱 0ro to hero TensorFlow.js 課程。

遷移學習的優點

相較於從頭開始訓練整個模型架構,採用遷移學習有什麼好處?

首先,由於您已訓練好用來建構的基礎模型,因此使用遷移學習方法的關鍵優勢是進行訓練。

其次,由於已完成訓練,您可以減少要分類的新事物樣本,數量會大幅減少。

如果時間和資源有限,可用來收集要分類的範例資料,且需要先快速製作原型,才能收集更多訓練資料來讓模型更完善,這個方法就很實用。

由於需要減少資料用量及訓練規模較小的網路,因此遷移學習所需的資源較少。因此特別適合瀏覽器環境,花了數十秒即可在現代化機器上完成模型訓練,無須耗費數小時、數天或數週的時間。

好的!現在您已瞭解遷移學習的基本知識,接著就來打造自己專屬的訓練學習機器版本。立即開始!

5. 設定程式碼

軟硬體需求

  • 新型網路瀏覽器。
  • 具備 HTML、CSS、JavaScript 和 Chrome 開發人員工具 (查看控制台輸出內容) 的基本知識。

開始編寫程式碼

已為 Glitch.comCodepen.io 建立範本範本。只要按一下滑鼠,即可複製任一範本做為本程式碼研究室的基礎狀態。

在 Glitch 上按一下「重混這個」按鈕,即可建立一組分支,並建立一組可以編輯的檔案。

或者,在 Codepen 上,按一下fork"

這個簡易架構提供下列檔案:

  • HTML 網頁 (index.html)
  • 樣式表 (style.css)
  • 要編寫 JavaScript 程式碼 (script.js) 的檔案

為方便起見,我們在 TensorFlow.js 程式庫的 HTML 檔案中新增了匯入作業。這是訂閱按鈕的圖示:

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

替代做法:使用偏好的網頁編輯器或在本機環境中運作

如想下載程式碼並透過本機或其他線上編輯器執行工作,只要在同一目錄中建立上述 3 個檔案,然後複製 Glitch 樣板的程式碼,然後貼到各個檔案中即可。

6. 應用程式 HTML 樣板

如何踏出第一步?

所有原型都需要一些基本的 HTML 鷹架,以呈現您的發現。現在就進行設定吧!您將新增:

  • 頁面標題。
  • 一些說明文字。
  • 狀態段落。
  • 影片準備就緒後,應保留網路攝影機畫面。
  • 使用多個按鈕啟動相機、收集資料或重設體驗。
  • 匯入 TensorFlow.js 和 JS 檔案,稍後需編寫程式碼。

開啟 index.html,並貼上現有的程式碼來設定上述功能:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

解析

接下來,讓我們針對上述的部分 HTML 程式碼進行細分,以列出您新增的一些重要內容。

  • 你為網頁標題新增了 <h1> 標記,以及一個 ID 為「status」的 <p> 標記。當您使用系統的不同部分檢視輸出內容時,請在此列印資訊。
  • 你新增了 ID 為「網路攝影機」的 <video> 元素稍後將顯示網路攝影機串流
  • 您已新增 5 個 <button> 元素。第一個是 ID 為「enableCam」會啟用相機。下一個兩個按鈕是 dataCollector 類別;讓您為想識別的物件收集範例圖片。您稍後撰寫的程式碼將設計成您可以加入任意數量的按鈕,且按鈕會正常運作。

請注意,這些按鈕也有稱為 data-1hot 的特殊使用者定義屬性,並針對第一個類別使用從 0 開始的整數值。這是指數字索引,用來代表特定類別的資料。由於機器學習模型只能使用數字,因此這個索引會以數值表示 (而非字串) 正確為輸出類別編碼。

還有一個 data-name 屬性,其中包含您要用於此類別的人類可讀名稱,您就能為使用者提供更有意義的名稱,而不是 1 熱門編碼的數值索引值。

最後,您可以使用訓練和重設按鈕在開始收集資料後開始訓練程序,或分別重設應用程式。

  • 您也新增了 2 項 <script> 匯入作業。一個用於 TensorFlow.js,另一種則適用於您稍後定義的 Script.js。

7. 新增樣式

元素預設值

為剛新增的 HTML 元素新增樣式,確保元素正確顯示。以下是為位置和大小元素新增正確的樣式。什麼都不重要。當然,您可以稍後加入這些內容以提供更優質的使用者體驗,就像你在訓練用機器的影片一樣。

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

太好了!這樣就大功告成了。如果您現在預覽輸出內容,看起來應像這樣:

81909685d7566dcb.png

8. JavaScript:索引鍵常數和事件監聽器

定義索引鍵常數

首先,請新增一些您需要在應用程式中使用的關鍵常數。首先,請將 script.js 的內容替換為這些常數:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

以下將詳細說明這些目標:

  • STATUS 只會保留您要寫入狀態更新的段落標記參照。
  • VIDEO 會保留 HTML 影片元素的參照,以轉譯網路攝影機動態饋給。
  • ENABLE_CAM_BUTTONRESET_BUTTONTRAIN_BUTTON 會擷取 HTML 頁面中所有鍵按鈕的 DOM 參照。
  • MOBILE_NET_INPUT_WIDTHMOBILE_NET_INPUT_HEIGHT 可分別定義 MobileNet 模型的預期輸入內容寬度和高度。透過將此字串儲存在接近此檔案頂端的常數中,如果您稍後決定使用其他版本,就可以更輕鬆地更新一次值,而不必在許多不同位置替換值。
  • STOP_DATA_GATHER 設為 - 1。這會儲存狀態值,讓您能夠瞭解使用者是否停止點選按鈕,透過網路攝影機視訊收集資料。為這個數字取一個更有意義的名稱,日後程式碼會更易於閱讀。
  • CLASS_NAMES 會執行查詢,並使用人類可讀的名稱保留可能的類別預測結果。這個陣列稍後會填入資料。

既然您已經參照到主要元素,現在該是將某些事件監聽器與其關聯。

新增重要事件事件監聽器

請先將點擊事件處理常式新增至按鍵按鈕,如下所示:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - 獲得點擊時呼叫 enableCam 函式。

TRAIN_BUTTON - 按下時呼叫 trainAndPredict。

RESET_BUTTON - 按下時重設通話。

最後,這個章節會顯示所有擁有「dataCollector」類別的按鈕。使用 document.querySelectorAll()。這會傳回在文件中所找到符合的元素陣列:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

程式碼說明:

接著,您可以反覆查看找到的按鈕,並將 2 個事件監聽器與各按鈕建立關聯。一個是「mousedown」和「mouseup」的指令。如此一來,您就可以在按下按鈕的情況下保留記錄樣本,這有助於收集資料。

這兩個事件都會呼叫您稍後定義的 gatherDataForClass 函式。

此時,您也可以將找到的人類可讀類別名稱,從「HTML」按鈕屬性 data-name 推送至 CLASS_NAMES 陣列。

接著,新增一些變數來儲存稍後會用到的重要項目。

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

現在就來逐一說明

首先,您要使用變數 mobilenet 儲存已載入的 mobilenet 模型。一開始先將這個選項設為未定義。

接下來,有一個名為 gatherDataState 的變數。如果 DataCollector按下按鈕後,這會改為該按鈕的 1 個熱門 ID (如 HTML 中的定義),方便您瞭解目前正在收集的資料類別。一開始,此選項會設為 STOP_DATA_GATHER,這樣您之後撰寫的資料收集迴圈就不會在按下任何按鈕時收集任何資料。

videoPlaying」會追蹤網路攝影機串流是否成功載入及播放,以及是否可供使用。一開始會設為 false,因為在你按下 ENABLE_CAM_BUTTON. 之前,網路攝影機不會開啟

接著,請定義 2 個陣列:trainingDataInputstrainingDataOutputs。當您點選「dataCollector」時,這些欄位會儲存收集到的訓練資料值顯示 MobileNet 基礎模型產生的輸入特徵按鈕,以及分別取樣的輸出內容類別按鈕。

系統接著會定義一個最終陣列 examplesCount,,以便在您新增類別後追蹤每個類別包含的範例數量。

最後,使用名為 predict 的變數控制預測迴圈。這個引數一開始設為 false。系統稍後會將其設為 true,才能進行預測。

所有鍵變數都已定義完畢,讓我們開始載入預先剪輯的 MobileNet v3 基本模型,這個模型提供圖片特徵向量而非分類。

9. 載入 MobileNet 基礎模型

首先,定義名為 loadMobileNetFeatureModel 的新函式,如下所示。這必須是非同步函式,因為載入模型的方式為非同步:

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

在這個程式碼中,您將定義要載入模型的 URL,該模型位於 TFHub 說明文件中。

接著,您可以使用 await tf.loadGraphModel() 載入模型,請記得在從這個 Google 網站載入模型時,將特殊屬性 fromTFHub 設為 true。只有在使用 TF Hub 上託管的模型時,才需要設定這項額外屬性,這是特例情況。

載入完成後,您可以使用訊息設定 STATUS 元素的 innerText,如此一來,您就可以看見元素已正確載入並且可以開始收集資料了。

現在,唯一剩下的就是為模型暖身。如果是這類大型模型,首次使用時,系統可能需要一些時間才能完成設定。因此,建議在模型中略過零,避免將來的時間安排較為重要。

您可以使用在 tf.tidy() 中納入的 tf.zeros(),確保張量正確處理 (批量為 1),且在開始時在常數中定義的正確高度和寬度。最後,您也可以指定色彩頻道,在本例中為 3,因為模型預期的 RGB 圖片。

接著,使用 answer.shape() 記錄產生的張量形狀,幫助您瞭解這個模型產生的圖片特徵大小。

定義這個函式後,您便可立即呼叫該函式,在載入網頁時啟動模型下載作業。

如果您現在查看即時預覽,片刻後會顯示狀態文字從「正在等待 TF.js 載入」才變成「成功載入 MobileNet v3!」如下所示。請先確認這項功能可以正常運作,再繼續操作。

a28b734e190afff.png

您也可以查看控制台輸出內容,瞭解這個模型產生的輸出特徵的列印大小。透過 MobileNet 模型執行零後,您會看到 [1, 1024] 的形狀。第一個項目只是 1 的批量,實際上會傳回 1024 個特徵,方便您分類新物件。

10. 定義新的模型標頭

現在,您要定義模型頭部,基本上是非常最小的多層式感知器。

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

現在就來逐步瞭解這個程式碼首先定義您要新增模型層的 tf.serial 模型。

接下來,新增稠密層做為這個模型的輸入層。這項輸入形狀是 1024,因為 MobileNet v3 功能的輸出大小為這個大小。您在透過模型傳遞模型後,到上一個步驟中發現了這個現象。這個層有 128 個使用 ReLU 活化函式的神經元。

如果您是剛接觸活化函式和模型層的新手,建議參加本研討會開頭的詳細說明課程,瞭解這些屬性背後的用途。

要新增的下一個層是輸出層。神經元的數量應等於您想預測的類別數量。方法是使用 CLASS_NAMES.length 來尋找您要分類的類別,也就是使用者介面中的資料收集按鈕數量。由於這是分類問題,您可以在這個輸出層使用 softmax 啟用,嘗試建立模型來解決分類問題,而非迴歸。

現在輸出 model.summary(),以在主控台顯示新定義模型的總覽。

最後是編譯模型,以便隨時進行訓練。此處的最佳化工具已設為 adam,如果 CLASS_NAMES.length 等於 2,損失會是 binaryCrossentropy;如果有 3 個以上的類別進行分類,則會使用 categoricalCrossentropy。系統也會要求準確度指標,以便日後監控記錄中進行偵錯。

您應該會在控制台中看到如下的內容:

22eaf32286fea4bb.png

請注意,這個可訓練參數超過 13 萬個,但由於這是簡單的標準神經元層,因此訓練速度相當快。

完成專案後,您可以嘗試變更第一層神經元的數量,看看在取得平衡的同時,其他神經元的數量可以降低。一般來說,機器學習技術需要一定程度的試驗和錯誤,才能找出最佳參數值,協助您在資源用量和速度之間取得最佳平衡。

11. 啟用網路攝影機

現在,您可以排出您先前定義的 enableCam() 函式。新增名為 hasGetUserMedia() 的函式 (如下所示),然後將先前定義 enableCam() 函式的內容替換為以下對應程式碼。

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

首先,請建立名為 hasGetUserMedia() 的函式,藉由檢查主要瀏覽器 API 屬性是否存在,檢查瀏覽器是否支援 getUserMedia()

enableCam() 函式中,使用您在上方定義的 hasGetUserMedia() 函式確認是否受到支援。如果沒有,則在控制台顯示警告。

如果可支援,請定義 getUserMedia() 呼叫的一些限制 (例如只想觀看影片串流),且您希望影片 width 的大小為 640 像素,而 height 設定為 480 像素。這是因為影片大於 340 其實沒有太大意義,因為您必須將大小調整為 224 x 224 像素,才能加入 MobileNet 模型。此外,您也可以要求較低解析度,節省一些運算資源。多數相機支援這種解析度。

接下來,請使用上述的 constraints 呼叫 navigator.mediaDevices.getUserMedia(),然後等待傳回 stream。傳回 stream 後,您可以將 VIDEO 元素設為其 srcObject 值,藉此播放 stream

您也應在 VIDEO 元素中新增事件監聽器,得知 stream 何時載入並順利播放。

Steam 載入後,您可以將 videoPlaying 設為 true,並移除 ENABLE_CAM_BUTTON,方法是將其類別設為「removed」,避免再次點選該類別。

現在請執行程式碼,按一下啟用相機按鈕,然後允許存取網路攝影機。如果是你第一次執行這項作業,應該會在網頁上的影片元素中看到你的畫面,如下所示:

b378eb1affa9b883.png

好,現在要新增函式,處理 dataCollector 按鈕點擊的動作。

12. 資料收集按鈕事件處理常式

現在要填寫名為 gatherDataForClass(). 的現有空白函式。這是您在程式碼研究室一開始時指派給 dataCollector 按鈕的事件處理常式函式。

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

首先,使用屬性名稱呼叫 this.getAttribute() (在本例中為 data-1hot 做為參數),藉此檢查目前點選按鈕上的 data-1hot 屬性。由於這是字串,所以您可以使用 parseInt() 將結果轉換為整數,並將這個結果指派給名為 classNumber. 的變數

接著,請據此設定 gatherDataState 變數。如果目前的 gatherDataState 等於 STOP_DATA_GATHER (設為 -1),表示您目前未收集任何資料,而是觸發的 mousedown 事件。將 gatherDataState 設為您剛找到的 classNumber

否則,表示您目前正在收集資料,而觸發的事件為 mouseup 事件,而現在您想要停止收集該類別的資料。只要將設定回 STOP_DATA_GATHER 狀態,即可結束您稍後定義的資料收集迴圈。

最後,啟動對 dataGatherLoop(), 的呼叫,以便實際記錄類別資料。

13. 資料收集

現在請定義 dataGatherLoop() 函數。這個函式負責從網路攝影機影片取樣圖片,並透過 MobileNet 模型傳遞圖片,然後擷取該模型的輸出內容 (1024 年特徵向量)。

然後儲存這些 ID 以及目前按下按鈕的 gatherDataState ID,以便瞭解此資料代表的類別。

以下將詳細說明:

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

只有在 videoPlaying 為「是」的情況下 (即網路攝影機已啟用),且 gatherDataState 不等於 STOP_DATA_GATHER,且目前按下收集類別資料的按鈕時,系統才會繼續執行這個函式。

接著,將程式碼納入 tf.tidy() 中,以在後續的程式碼中處置任何已建立的張量。這項 tf.tidy() 程式碼執行的結果會儲存在名為 imageFeatures 的變數中。

你現在可以使用 tf.browser.fromPixels() 擷取網路攝影機的 VIDEO 畫面。含有圖片資料產生的張量會儲存在名為 videoFrameAsTensor 的變數中。

接下來,將 videoFrameAsTensor 變數調整為 MobileNet 模型輸入內容的正確形狀。使用 tf.image.resizeBilinear() 呼叫,將您要重塑為第一個參數的張量,然後建立一個形狀,以根據您先前建立的常數定義定義新的高度和寬度。最後,傳遞第三個參數,避免調整大小時出現對齊問題,將對齊邊角設為 true。此調整大小的結果會儲存在名為 resizedTensorFrame 的變數中。

請注意,這種原始調整大小可延展影像,因為網路攝影機影像的大小為 640 x 480 像素,而模型需要 224 x 224 像素的正方形圖片。

為方便示範,這個範例應能正常運作。不過,完成本程式碼研究室後,建議您嘗試裁剪這張圖片中的正方形,這樣日後建立的任何正式系統都能獲得更優質的結果。

接著,將圖片資料正規化。使用 tf.browser.frompixels() 時,圖片資料一律會在 0 到 255 的範圍內,因此您只需將調整大小的 TensorFrame 除以 255,即可確保所有值都介於 0 到 1 之間,這也是 MobileNet 模型預期做為輸入內容的值。

最後,在程式碼的 tf.tidy() 區段,呼叫 mobilenet.predict() 以透過已載入的模型推送這個正規化張量,接著使用 expandDims() 傳遞 normalizedTensorFrame 的擴充版本,因此該批次為 1 的批次,因為模型預期有批次輸入來處理。

傳回結果後,您可以立即針對傳回結果呼叫 squeeze(),將其壓縮回 1D 張量,接著您會傳回該變數,並指派給從 tf.tidy() 擷取結果的 imageFeatures 變數。

現在您已取得 MobileNet 模型的 imageFeatures,因此可以將其推送至先前定義的 trainingDataInputs 陣列,以記錄資料。

您也可以將目前的 gatherDataState 推送到 trainingDataOutputs 陣列,以記錄這項輸入內容所代表的意義。

請注意,gatherDataState 變數會設為目前類別的數值 ID,也就是您在先前定義的 gatherDataForClass() 函式中點選按鈕時記錄的資料。

在此階段,您也可以依特定類別增加樣本數。為此,請先檢查 examplesCount 陣列中的索引是否已初始化。如果未定義,請設為 0 針對特定類別的數字 ID 初始化計數器,接著即可依目前 gatherDataState 遞增 examplesCount

現在,更新網頁上的 STATUS 元素文字,讓系統在擷取各類別時顯示目前各類別的次數。方法是透過 CLASS_NAMES 陣列迴圈,輸出人類可讀的名稱,以及 examplesCount 中相同索引的資料計數。

最後,請使用 dataGatherLoop 做為參數來呼叫 window.requestAnimationFrame(),以遞迴方式再次呼叫這個函式。這樣會繼續取樣影片影格,直到偵測到按鈕的 mouseup,且 gatherDataState 設為 STOP_DATA_GATHER, 時,資料收集迴圈就會結束。

如果您現在執行程式碼,應該可以點選「啟用相機」按鈕,等待網路攝影機載入,然後按住每個資料收集按鈕,收集每種資料類別的範例。各位可以看到我分別收集手機和我手的資料。

541051644a45131f.gif

您應該會看到狀態文字更新,因為狀態文字儲存了記憶體中的所有張量,如上方螢幕截圖所示。

14. 訓練與預測

下一步是為目前空白的 trainAndPredict() 函式導入程式碼,也就是進行遷移學習的地方。我們來看看程式碼:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

首先,請將 predict 設為 false,確保停止目前的預測作業。

接下來,請使用 tf.util.shuffleCombo() 重組輸入和輸出陣列,確保順序不會導致訓練作業發生問題。

將輸出陣列 trainingDataOutputs, 轉換為 int32 類型的 tensor1d,以便用於一種熱編碼。內容儲存在名為 outputsAsTensor 的變數中。

搭配這個 outputsAsTensor 變數使用 tf.oneHot() 函式,以及要編碼的類別數量上限 (也就是 CLASS_NAMES.length)。您的一個熱編碼輸出現在會儲存在名為 oneHotOutputs 的新張量中。

請注意,目前 trainingDataInputs 是已記錄張量的陣列。為了使用這些資料進行訓練,您需要將張量陣列轉換為一般 2D 張量。

為此,TensorFlow.js 程式庫中有一個名為 tf.stack() 的絕佳函式:

這會擷取張量陣列並相互堆疊,以產生更高維度張量做為輸出內容。此時會傳回張量 2D,也就是 1 個維度輸入的批次,每 1024 年包含已記錄的特徵,這就是必要的訓練。

接下來,await model.fit(),訓練自訂模型標頭。在這裡,您傳遞 inputsAsTensor 變數和 oneHotOutputs 來代表要分別用於輸入和目標輸出範例的訓練資料。在第 3 個參數的設定物件中,將 shuffle 設為 true,使用 5batchSizeepochs 設為 10,然後將 onEpochEndcallback 指定為即將定義的 logProgress 函式。

最後,您可以丟棄已建立的張量,因為模型現已訓練完成。然後將 predict 設為 true 以允許系統重新進行預測,然後呼叫 predictLoop() 函式即可開始預測即時網路攝影機拍攝的圖片。

您也可以定義 logProcess() 函式來記錄訓練狀態,這會在上述的 model.fit() 中使用,並在每次訓練後將結果輸出至主控台。

就快大功告成了!現在可以新增 predictLoop() 函式來進行預測。

核心預測迴圈

在這個階段中,您將實作主要預測迴圈,從網路攝影機取樣畫面,並持續透過瀏覽器即時預測畫面內容。

請檢查程式碼:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

首先,請確認 predict 為 true,這樣只有在模型經過訓練且可供使用時,才能進行預測。

接著,您可以像在 dataGatherLoop() 函式中一樣,取得目前圖片的圖片功能。基本上,您可以使用 tf.browser.from pixels() 從網路攝影機擷取畫面,並將其正規化為 224 x 224 像素的大小,然後透過 MobileNet 模型傳遞資料,取得產生的圖片功能。

不過,您現在可以使用新訓練的模型標頭,傳送剛透過訓練模型的 predict() 函式找到的結果 imageFeatures,以實際執行預測結果。接著,您可以握壓產生的張量,使其再次變成 1 維度,並指派給名為 prediction 的變數。

透過這個 prediction,您可以使用 argMax() 尋找值最高的索引,然後使用 arraySync() 將這個產生的張量轉換為陣列,以透過 JavaScript 取得基礎資料,找出最高值元素的位置。這個值儲存在名為 highestIndex 的變數中。

您也可以直接對 prediction 張量呼叫 arraySync(),以相同的方式取得實際預測可信度分數。

您現在已經擁有使用 prediction 資料更新 STATUS 文字所需的一切內容。方法是在 CLASS_NAMES 陣列中查詢 highestIndex,然後從 predictionArray 擷取信心值,才能取得類別的人類可讀字串。如要以百分比顯示更清楚易懂,請將結果乘以 100 及 math.floor()

最後,您可以使用 window.requestAnimationFrame() 再次呼叫 predictionLoop(),取得影片串流的即時分類功能。如果您選擇使用新資料訓練新模型,此步驟將持續到 predict 設為 false 為止。

這把你帶到最後一個謎題了。實作「重設」按鈕。

15. 實作重設按鈕

即將完成!謎題的最後一個部分,就是實作重設按鈕以重新開始。目前空白 reset() 函式的程式碼如下。請繼續並依照下列方式更新:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

首先,將 predict 設為 false,以停止任何執行中的預測迴圈。接著,將 examplesCount 陣列的內容長度設為 0,即可刪除陣列中的所有內容,這是清除陣列中所有內容的便利方法。

現在瀏覽目前記錄到的所有 trainingDataInputs,並確保其中每個張量的 dispose() 都能再次釋放記憶體,因為 JavaScript 垃圾收集器不會清理 Tensor。

完成後,您現在可以安全地將 trainingDataInputstrainingDataOutputs 陣列的陣列長度設為 0,以便清除這些長度。

最後,請將 STATUS 文字設為適當內容,然後輸出記憶體中保留的張量做為例行檢查。

請注意,由於 MobileNet 模型和您定義的多層感知不會被處理,因此記憶體中仍會有數百個張量。重設後,如要重新訓練,就必須將這些資料與新的訓練資料重複使用。

16. 立即試用

現在來測試您專屬的學習版本吧!

前往即時預覽,啟用網路攝影機,為會議室中的某些物件收集至少 30 個範例 (至少 30 個),接著為類別 2 的其他物件完成相同步驟,按一下「火車」,然後查看控制台記錄,瞭解進度。訓練速度應該相當快:

bf1ac3cc5b15740.gif

訓練完成後,請向相機展示物體,讓系統即時預測並顯示在網頁上方的狀態文字區域。如果遇到問題,請檢查我已完成的工作程式碼,看看您是否遺漏了任何內容。

17. 恭喜

恭喜!你剛剛在瀏覽器中實際使用 TensorFlow.js 完成了第一個遷移學習範例。

試試看、用各種物件進行測試,你可能會發現某些事物比其他物件更難辨識,尤其是與其他事物類似的情況。您可能需要加入更多課程或訓練資料,這樣才能清楚辨別。

重點回顧

在本程式碼研究室中,您已瞭解以下內容:

  1. 什麼是遷移學習?相較於完整訓練模型,遷移學習的優勢。
  2. 如何從 TensorFlow Hub 取得可重複使用的模型。
  3. 如何設定適合遷移學習的網頁應用程式。
  4. 如何載入及使用基礎模型來產生圖片特徵。
  5. 如何訓練新的預測頭像,該頭可以辨識網路攝影機圖像中的自訂物件。
  6. 如何使用產生的模型即時分類資料。

後續步驟

有現成的入門階段後,您可以思考哪些創意構想,讓這個機器學習模型範本,以做為日後開發的應用實例。您或許可以為目前從事的產業革新,讓公司內的員工訓練模型,將日常工作中最重要的事物分類?一切都有無限的可能。

也可以考慮免費進行這整堂課程,瞭解如何把目前在本程式碼研究室中的 2 個模型合併成 1 個模型,藉此提高效率。

此外,如要進一步瞭解原始可訓練機器應用程式背後的理論,請參閱這個教學課程

分享您的成果

此外,你可以輕鬆將今日成果拓展至其他創意用途,建議跳脫傳統思維並持續嘗試入侵。

請記得在社群媒體上標記我們,並加上 #MadeWithTFJS 主題標記,你的專案就有機會登上我們的 TensorFlow 網誌,甚至是日後的活動。我們很期待看到你的作品。

建議結帳網站