TensorFlow.js: 独自の「Teachable Machine」を作成するTensorFlow.js で転移学習を使用する

1. 始める前に

TensorFlow.js のモデルの使用は、ここ数年で飛躍的に増加しました。現在、多くの JavaScript デベロッパーは、既存の最先端のモデルを利用して、業界固有のカスタムデータで処理できるように再トレーニングすることを検討しています。既存のモデル(ベースモデルと呼ばれることもあります)を、似通っているが別の領域で使用する行為を転移学習と呼びます。

転移学習には、完全に空のモデルから始めるよりも多くの利点があります。事前トレーニング済みモデルからすでに学習した知識を再利用でき、分類する新しいアイテムの例を少なくできます。また、ネットワーク全体ではなく、モデル アーキテクチャの最後の数レイヤを再トレーニングするだけで済むため、トレーニングが大幅に高速化されることがよくあります。このため、転移学習は、実行デバイスによってリソースが異なる可能性がある一方で、センサーに直接アクセスしてデータを簡単に取得できるウェブブラウザ環境に適しています。

この Codelab では、空白のキャンバスからウェブアプリを作成し、Google の人気の「Teachable Machine」確認できますこのウェブサイトでは、ウェブカメラのサンプル画像をいくつか使用して、すべてのユーザーがカスタム オブジェクトを認識できる、正常に機能するウェブアプリを作成できます。この Codelab の機械学習の側面に焦点を当てられるように、ウェブサイトは意図的に最小限になっています。ただし、当初の Teachable Machine ウェブサイトと同様に、既存のウェブ デベロッパー エクスペリエンスを適用して UX を改善できる余地はたくさんあります。

前提条件

この Codelab は、TensorFlow.js の既製のモデルと基本的な API の使用方法にある程度精通しており、TensorFlow.js で転移学習を始めたいウェブ デベロッパーを対象としています。

  • このラボでは、TensorFlow.js、HTML5、CSS、JavaScript に関する基本的な知識があることを前提としています。

Tensorflow.js を初めて使用する場合は、まず無料の「ゼロからヒーロー」コースを受講してください。このコースでは、ML や TensorFlow.js の予備知識がない方を前提としており、知っておくべきポイントをすべて詳しく学ぶことができます。

学習内容

  • TensorFlow.js の概要と、次にウェブアプリで TensorFlow.js を使用すべき理由
  • Teachable Machine のユーザー エクスペリエンスを複製した、シンプルな HTML/CSS /JS ウェブページを作成する方法。
  • TensorFlow.js を使用して事前トレーニング済みのベースモデル(特に MobileNet)を読み込み、転移学習に使用できる画像の特徴を生成する方法。
  • 認識したい複数のクラスのデータをユーザーのウェブカメラから収集する方法。
  • 画像の特徴を取得し、それらを使用して新しいオブジェクトを分類することを学習するマルチレイヤ パーセプトロンを作成および定義する方法。

さあ、ハッキングをしよう...

必要なもの

  • Glitch.com のアカウントをお使いいただくのがおすすめです。または、ご自身で編集して運用できるウェブサービス環境を使用することもできます。

2. TensorFlow.js とは

54e81d02971f53e8.png

TensorFlow.js は、JavaScript が実行可能な場所であればどこでも実行できるオープンソースの ML ライブラリです。Python で記述された元の TensorFlow ライブラリをベースとし、このデベロッパー エクスペリエンスと JavaScript エコシステム用の API セットを再構築することを目的としています。

使用できる場所

JavaScript のポータビリティにより、1 つの言語で記述して、以下のすべてのプラットフォームで簡単に ML を実行できるようになりました。

  • 標準の JavaScript を使用するウェブブラウザのクライアントサイド
  • Node.js を使用したサーバーサイドのデバイスや IoT デバイス(Raspberry Pi など)
  • Electron を使用したデスクトップ アプリ
  • React Native を使用したネイティブ モバイルアプリ

TensorFlow.js は、これらの各環境(CPU や WebGL など、内部で実行できる実際のハードウェアベースの環境)内の複数のバックエンドもサポートしています。「バックエンド」このコンテキストでは、サーバー側の環境を意味するわけではありません。たとえば、実行のバックエンドは WebGL のクライアント側の場合があります)。これは、互換性を確保し、動作の高速化を維持するためです。現在、TensorFlow.js は次のものをサポートしています。

  • デバイスのグラフィック カード(GPU)での WebGL の実行 - GPU アクセラレーションを利用して、より大規模なモデル(サイズが 3 MB 超)を実行する最も高速な方法です。
  • CPU でのウェブ アセンブリ(WASM)の実行 - 古い世代のスマートフォンなど、さまざまなデバイスの CPU パフォーマンスを向上させます。これは小さいモデル(サイズが 3 MB 未満)に適しています。グラフィック プロセッサにコンテンツをアップロードするオーバーヘッドがあるため、WASM では WebGL よりも CPU での実行速度が速くなります。
  • CPU の実行 - 他のどの環境も使用できない場合にフォールバックします。これは 3 つのうち最も遅い方法ですが、常に役立ちます。

注: どのデバイスで実行するかがわかっている場合は、これらのバックエンドのいずれかを強制的に適用するか、指定しない場合は単に TensorFlow.js に処理を任せることもできます。

クライアントサイドの強み

クライアント マシンのウェブブラウザで TensorFlow.js を実行すると、検討に値するいくつかのメリットがあります。

プライバシー

サードパーティのウェブサーバーにデータを送信することなく、クライアント マシン上でデータのトレーニングと分類を行うことができます。GDPR など地域の法律への準拠が求められる場合や、ユーザーが自分のマシンに保存し、第三者に送信してはならないデータを処理する際に必要になることがあります。

速度

データをリモート サーバーに送信する必要がないため、推論(データを分類する作業)を高速化できます。さらに、ユーザーがアクセスを許可すれば、デバイスのセンサー(カメラ、マイク、GPS、加速度計など)に直接アクセスできるようになります。

リーチと規模

世界中の誰もが、送信したリンクをワンクリックでクリックし、ブラウザでウェブページを開き、あなたが作成したものを利用できます。CUDA ドライバを使用した複雑なサーバーサイドの Linux セットアップは不要で、ML システムを使用するだけで十分です。

費用

サーバーがないということは、HTML、CSS、JS、モデルファイルをホストするための CDN さえあれば有料です。CDN の費用は、サーバー(場合によってはグラフィック カードが接続されている)を 24 時間 365 日稼働させるよりもはるかに安価です。

サーバーサイドの機能

TensorFlow.js の Node.js 実装を利用すると、次の機能が有効になります。

CUDA のフルサポート

サーバー側でグラフィック カード アクセラレーションを使用するには、NVIDIA CUDA ドライバをインストールして、TensorFlow がグラフィック カードと連携できるようにする必要があります(WebGL を使用するブラウザとは異なり、インストールは不要です)。ただし、CUDA を完全にサポートすることで、グラフィック カードの下位レベルの機能をフルに活用でき、トレーニングと推論にかかる時間を短縮できます。両方とも同じ C++ バックエンドを共有しているため、パフォーマンスは Python TensorFlow 実装と同等です。

モデルサイズ

研究段階の最先端モデルの場合、ギガバイト規模の非常に大規模なモデルを扱っている可能性があります。ブラウザのタブごとのメモリ使用量の制限により、現在、これらのモデルをウェブブラウザで実行することはできません。このような大規模なモデルを実行するには、このようなモデルを効率的に実行するために必要なハードウェア仕様を備えた独自のサーバーで Node.js を使用できます。

IoT

Node.js は Raspberry Pi などの一般的なシングルボード コンピュータでサポートされているため、このようなデバイスでも TensorFlow.js モデルを実行できます。

速度

Node.js は JavaScript で記述されているため、ジャストインタイム コンパイルが有効です。つまり、Node.js を使用すると、実行時に最適化されるため、パフォーマンスが向上することがよくあります。これは特に前処理を行う場合に顕著です。その好例が、こちらのケーススタディで紹介されています。Hugging Face が Node.js を使用して自然言語処理モデルのパフォーマンスを 2 倍に高めた事例が紹介されています。

TensorFlow.js の基本や実行できる場所とメリットについて理解できたところで、実際に TensorFlow.js を使ってみましょう。

3. 転移学習

転移学習とは正確にはどのようなものですか。

転移学習とは、すでに習得した知識を、別のものの似たような学習を支援することです。

私たち人間は常にそうしています。脳には生涯にわたる経験があり、それを使用して、これまでに見たことのない新しいものを認識することができます。このヤナギの木の例を見てみましょう。

e28070392cd4afb9.png

お住まいの地域によっては、この種の樹木をこれまで見たことがないかもしれません。

下の新しい画像にヤナギの木が写っているかどうか教えてもらえば、先ほどお見せした元の画像とは少し違う角度からでも、すぐに見分けられるでしょう。

d9073a0d5df27222.png

脳には、木のような物体を識別する方法を理解するニューロンや、長い直線を見つけやすいニューロンがすでにたくさんあります。その知識を再利用して、ヤナギの木をすばやく分類できます。ヤナギの木は、長いまっすぐな垂直の枝をたくさん持つ、木のような物体です。

同様に、画像認識などの特定の分野ですでにトレーニング済みの ML モデルがある場合は、そのモデルを再利用して、別の関連タスクを実行できます。

MobileNet のような高度なモデルでも同じことができます。MobileNet は、1, 000 種類のオブジェクトに対して画像認識を実行できる非常に人気のある研究モデルです。犬から自動車まで、数百万ものラベル付き画像を含む ImageNet という巨大なデータセットでトレーニングされました。

このアニメーションでは、MobileNet V1 モデルに膨大な数のレイヤがあることがわかります。

7d4e1e35c1a89715.gif

このモデルはトレーニング中に、これら 1,000 個のオブジェクトすべてにとって重要な特徴を抽出する方法を学習しました。このようなオブジェクトの識別に使用する下位レベルの特徴の多くは、これまでに見たことのない新しいオブジェクトの検出にも活用できます。結局のところ、あらゆるものは線、テクスチャ、形状の組み合わせにすぎません。

従来の畳み込みニューラル ネットワーク(CNN)アーキテクチャ(MobileNet に類似)と、転移学習がこのトレーニングされたネットワークを活用して新しいことを学ぶ方法を見てみましょう。以下の画像は CNN の典型的なモデル アーキテクチャを示しています。この例では、0 ~ 9 の手書きの数字を認識するようにトレーニングされています。

baf4e3d434576106.png

左側に示されているように、既存のトレーニング済みモデルの事前トレーニング済みの下位レイヤを、右側に示されているモデルの終わりに近い分類レイヤ(モデルの分類ヘッドとも呼ばれます)から分離できるとしたら、その下位レイヤを使用して、トレーニングに使用した元のデータに基づいて任意の画像の出力特徴を生成できます。同じネットワークを分類ヘッドを除いたものは、次のようになります。

369a8a9041c6917d.png

認識しようとしている新しいものも、以前のモデルが学習した出力特徴を利用できると仮定すると、それらを新しい目的に再利用できる可能性が高いです。

上の図では、この架空のモデルは数字でトレーニングされているため、数字について学習した内容は a、b、c などの文字にも適用できる可能性があります。

そこで、次に示すように、代わりに a、b、c を予測する新しい分類ヘッドを追加できます。

db97e5e60ae73bbd.png

ここでは、下位レベルのレイヤは凍結されてトレーニングされません。新しい分類ヘッドのみが自身を更新して、左側の事前トレーニング済みのチョップアップ モデルから提供される特徴から学習します。

この処理は転移学習と呼ばれ、Teachable Machine はバックグラウンドで行われます。

また、ネットワークの最後で多層パーセプトロンをトレーニングするだけで済むため、ネットワーク全体をゼロからトレーニングするよりもはるかに速くトレーニングできることがわかります。

では、モデルのサブ部分を手に入れるにはどうすればよいのでしょうか。次のセクションに進みましょう。

4. TensorFlow Hub - ベースモデル

使用に適したベースモデルを見つける

MobileNet などのより高度で人気のあるリサーチモデルの場合は、TensorFlow Hub に移動し、MobileNet v3 アーキテクチャを使用する TensorFlow.js に適したモデルをフィルタして、以下に示すような結果を得ることができます。

c5dc1420c6238c14.png

なお、これらの結果の一部は「画像分類」タイプです。(詳細は各モデルカード結果の左上にある)、その他は「画像特徴ベクトル」タイプです。

これらの画像特徴ベクトルの結果は基本的に、あらかじめ切り取られた MobileNet のバージョンであり、最終的な分類の代わりに画像特徴ベクトルを取得するために使用できます。

このようなモデルは「ベースモデル」と呼ばれることもありますが、その後、新しい分類ヘッドを追加して独自のデータでトレーニングすることで、前のセクションで説明したのと同じ方法で転移学習を実行できます。

次に、対象のベースモデルについて、モデルがリリースされている TensorFlow.js 形式を確認します。これらの特徴ベクトル MobileNet v3 モデルのページを開くと、JS ドキュメントから、tf.loadGraphModel() を使用するドキュメントのサンプル コード スニペットに基づくグラフモデルの形式になっていることがわかります。

f97d903d2e46924b.png

モデルがグラフ形式ではなくレイヤ形式である場合は、トレーニングのために固定するレイヤと固定解除するレイヤを選択できます。これは、新しいタスクのモデルを作成するときに非常に効果的です。これはしばしば「移行モデル」と呼ばれます。ただし、このチュートリアルでは、デフォルトのグラフモデル タイプを使用します。ほとんどの TF Hub モデルは、このタイプでデプロイされます。レイヤモデルの操作について詳しくは、TensorFlow.js のゼロからヒーローへの変換コースをご覧ください。

転移学習のメリット

モデル アーキテクチャ全体をゼロからトレーニングする代わりに、転移学習を使用するメリットは何ですか。

まず、トレーニング済みのベースモデルがすでに構築されているため、転移学習アプローチを使用する際の重要なメリットはトレーニング時間です。

第 2 に、トレーニングがすでに実施されているため、分類しようとしている新しいもののサンプルを大幅に減らすことができます。

これは、分類対象のサンプルデータを収集するための時間とリソースが限られており、堅牢性を高めるために、より多くのトレーニング データを収集する前にプロトタイプを迅速に作成する必要がある場合に非常に便利です。

必要なデータ量が少なく、トレーニング速度が小さいことから、転移学習はリソース消費量が少なくなります。そのため、ブラウザ環境には非常に適しています。最新のマシンではわずか数十秒でモデルのフル トレーニングに数時間、数日、数週間かかっていましたが、

承知いたしました。転移学習の概要がわかったところで、次は独自のバージョンの Teachable Machine を作成しましょう。では始めましょう。

5. コーディングの準備

必要なもの

  • 最新のウェブブラウザ。
  • HTML、CSS、JavaScript、Chrome DevTools(コンソール出力の表示)に関する基本的な知識。

コーディングを始めましょう

Glitch.com または Codepen.io のボイラープレート テンプレートが用意されています。どちらのテンプレートも、この Codelab の基本状態としてワンクリックでクローンを作成できます。

Glitch で [remix this] ボタンをクリックしてフォークし、編集可能な新しいファイルセットを作成します。

または Codepen でfork] を選択します。

この非常にシンプルなスケルトンにより、次のファイルが提供されます。

  • HTML ページ(index.html)
  • スタイルシート(style.css)
  • JavaScript コードを記述するためのファイル(script.js)

HTML ファイルに TensorFlow.js ライブラリのインポートが追加されています。たとえば、次のようになります。

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

別の方法: お好みのウェブエディタを使用するか、ローカルで作業する

コードをダウンロードしてローカルに、または別のオンライン エディタで作業したい場合は、同じディレクトリに上記の 3 つのファイルを作成し、Glitch のボイラープレートのコードをコピーして、それぞれのファイルに貼り付けます。

6. アプリの HTML ボイラープレート

開始方法

すべてのプロトタイプには、結果をレンダリングできる基本的な HTML スキャフォールディングが必要です。今すぐ設定してください。以下を追加します。

  • ページのタイトル。
  • 説明文。
  • ステータスの段落。
  • 準備が整ったらウェブカメラのフィードを表示するための動画。
  • カメラの起動、データの収集、エクスペリエンスのリセットを行うための複数のボタン。
  • 後でコーディングする TensorFlow.js ファイルと JS ファイルのインポート。

index.html を開き、既存のコードを次のコードに上書きして、上記の機能を設定します。

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

ビートに乗る

上記の HTML コードの一部を分割して、追加した重要なポイントを強調しましょう。

  • ページタイトルに <h1> タグを追加し、ID が「status」の <p> タグを追加しました。出力の表示にシステムのさまざまな部分を使用するため、ここに情報を出力します。
  • ID が「webcam」の <video> 要素を追加しました。これに後でウェブカメラ ストリームをレンダリングします。
  • 5 つの <button> 要素を追加しました。1 つ目は「enableCam」という ID のカメラが有効になります。次の 2 つのボタンのクラスは「dataCollector」です。を使用すると、認識したいオブジェクトのサンプル画像を収集できます。後で記述するコードでは、これらのボタンを何個でも追加でき、意図したとおりに自動的に機能します。

これらのボタンには「data-1hot」という特別なユーザー定義の属性もあり、最初のクラスは 0 から始まる整数値になっています。これは、特定のクラスのデータを表すために使用する数値インデックスです。ML モデルは数値しか処理できないため、インデックスは文字列ではなく数値表現で出力クラスを正しくエンコードするために使用されます。

data-name 属性には、このクラスに使用する、人が読める形式の名前を含めることもできます。この属性を使用すると、1 ホット エンコーディングでの数値インデックス値の代わりに、より意味のある名前をユーザーに示すことができます。

最後に、[トレーニングとリセット] ボタンで、データが収集されたらトレーニング プロセスを開始したり、アプリをリセットしたりできます。

  • また、<script> のインポートを 2 つ追加しました。1 つは TensorFlow.js 用、もう 1 つは後で定義する script.js 用です。

7. スタイルを追加

要素のデフォルト

先ほど追加した HTML 要素のスタイルを追加して、正しくレンダリングされるようにします。要素の位置とサイズを正しく設定するために追加されるスタイルを以下に示します。特別なものではありません。学習可能な機械に関する動画で見たように、後でこれを追加してさらに UX を改善することもできます。

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

これで必要な処理はこれだけです。この時点で出力をプレビューすると、次のようになります。

81909685d7566dcb.png

8. JavaScript: キー定数とリスナー

キー定数を定義する

まず、アプリ全体で使用するキー定数を追加します。まず、script.js の内容を次の定数に置き換えます。

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

その用途を細かく見ていきましょう。

  • STATUS は、ステータスの更新を記述する段落タグへの参照を保持するだけです。
  • VIDEO には、ウェブカメラのフィードを表示する HTML 動画要素への参照が保持されます。
  • ENABLE_CAM_BUTTONRESET_BUTTONTRAIN_BUTTON は、HTML ページからすべての主要なボタンへの DOM 参照を取得します。
  • MOBILE_NET_INPUT_WIDTHMOBILE_NET_INPUT_HEIGHT は、MobileNet モデルに想定される入力の幅と高さをそれぞれ定義します。このように、ファイルの先頭付近に定数として保存しておくと、後で別のバージョンを使用する場合に、値を一度だけ更新しやすくなります。値をさまざまな場所で置き換える必要がなくなります。
  • STOP_DATA_GATHER は - 1 に設定されています。このイベントには state の値が保存され、ユーザーがウェブカメラのフィードからデータを収集するためのボタンのクリックを停止したことを知ることができます。この番号にわかりやすい名前を付けることで、後でコードが読みやすくなります。
  • CLASS_NAMES はルックアップとして機能し、想定されるクラス予測の名前を人が読める形式で保持します。この配列は後で入力されます。

鍵要素への参照を取得できたので、次はイベント リスナーをそれらに関連付けてみましょう。

キーイベント リスナーを追加する

まず、次に示すように、キーボタンにクリック イベント ハンドラを追加します。

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - クリックされたときに enableCam 関数を呼び出します。

TRAIN_BUTTON - クリックされると trainAndPredict を呼び出します。

RESET_BUTTON - クリックすると、呼び出しがリセットされます。

最後に、クラスが「dataCollector」のボタンをすべて確認します。document.querySelectorAll() を使用します。これは、次の条件と一致するドキュメントで見つかった要素の配列を返します。

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

コードの説明:

次に、見つかったボタンを反復処理し、2 つのイベント リスナーをそれぞれに関連付けます。1 つは「mousedown」、もう 1 つは「mouseup」です。ボタンが押されている間はサンプルを記録し続けることができるので、データ収集に便利です。

どちらのイベントも、後で定義する gatherDataForClass 関数を呼び出します。

この時点で、見つかった人が読める形式のクラス名を HTML ボタン属性の data-name から CLASS_NAMES 配列にプッシュすることもできます。

次に、後で使用する重要なものを格納する変数をいくつか追加します。

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

それぞれ見ていきましょう

まず、読み込まれたモバイルネット モデルを格納する変数 mobilenet を用意します。最初は未定義に設定します。

次は、gatherDataState という変数です。「dataCollector」がボタンが押されると、HTML で定義されているように、そのボタンのホットな ID に変わります。そのため、その時点で収集しているデータのクラスを把握できます。最初は STOP_DATA_GATHER に設定されており、後で作成するデータ収集ループは、ボタンが押されていないときにデータを収集しないようにします。

videoPlaying は、ウェブカメラ ストリームが正常に読み込まれて再生され、使用可能かどうかを追跡します。ENABLE_CAM_BUTTON. を押すまでウェブカメラがオンにならないため、最初は false に設定されています

次に、trainingDataInputstrainingDataOutputs の 2 つの配列を定義します。[dataCollector] をクリックすると、収集されたトレーニング データの値が保存されます。MobileNet ベースモデルによって生成された入力特徴と、サンプリングされた出力クラスのそれぞれのボタンがあります。

最後の配列として examplesCount, を定義して、クラスの追加を開始した後、各クラスに含まれるサンプルの数を追跡します。

最後は、予測ループを制御する predict という変数です。初期設定では false に設定されています。後でこれが true に設定されるまで、予測は行われません。

すべての主要な変数が定義されたので、分類の代わりに画像特徴ベクトルを提供する、あらかじめ切り取られた MobileNet v3 ベースモデルを読み込みましょう。

9. MobileNet ベースモデルを読み込む

まず、以下に示すように loadMobileNetFeatureModel という新しい関数を定義します。モデルの読み込みは非同期であるため、これは非同期関数である必要があります。

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

このコードでは、TFHub ドキュメントから、読み込むモデルが配置されている URL を定義します。

次に、await tf.loadGraphModel() を使用してモデルを読み込みます。この Google ウェブサイトからモデルを読み込むときは、特別なプロパティ fromTFHubtrue に設定します。これは、TF Hub でホストされているモデルのみを使用する特殊なケースで、この追加プロパティを設定する必要があります。

読み込みが完了したら、STATUS 要素の innerText にメッセージを設定して、正しく読み込まれたことを視覚的に確認し、データの収集を開始できます。

あとはモデルをウォームアップするだけです。このような大規模なモデルでは、モデルを初めて使用するときに、すべての設定に時間がかかることがあります。そのため、モデル全体に 0 を渡して、タイミングがより重要になりうる将来の待機を回避できるようにすると効果的です。

テンソルを tf.tidy() でラップした tf.zeros() を使用すると、バッチサイズ 1、開始時に定数で定義した正しい高さと幅でテンソルを正しく破棄できます。最後に、カラーチャンネルも指定します。この例では、モデルが RGB 画像を想定しているため、3 です。

次に、answer.shape() を使用して返されたテンソルの形状をログに記録し、このモデルが生成する画像特徴のサイズを把握できるようにします。

この関数を定義したら、すぐにこれを呼び出して、ページの読み込み時にモデルのダウンロードを開始できます。

ライブ プレビューを表示すると、しばらくするとステータス テキストが [Awaiting TF.js load](TF.js の読み込みを待機中)から変化します。「MobileNet v3 load successfully!」と表示されます。下に示します。続行する前に、これが機能することを確認してください。

a28b734e190afff.png

コンソールの出力で、このモデルで生成される出力特徴の印刷サイズを確認することもできます。MobileNet モデルにゼロを実行すると、[1, 1024] の形状が出力されます。最初のアイテムはバッチサイズ 1 だけです。実際には 1, 024 個の特徴が返され、これを使用して新しいオブジェクトを分類できます。

10. 新しいモデルヘッドの定義

次に、モデルヘッドを定義します。これは、基本的に非常に最小限の多層パーセプトロンです。

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

このコードを詳しく見ていきましょう。まず、モデルレイヤを追加する tf.sequence モデルを定義します。

次に、このモデルの入力レイヤとして Dense レイヤを追加します。MobileNet v3 機能からの出力がこのサイズのため、入力形状は 1024 です。これは、前のステップでモデルに渡した後のステップで検出しました。この層には、ReLU 活性化関数を使用する 128 個のニューロンがあります。

活性化関数とモデルレイヤを初めて使用する場合は、このワークショップの冒頭で紹介したコースを受講して、これらのプロパティがバックグラウンドで何をするのかを理解してください。

次に追加するレイヤは出力レイヤです。ニューロンの数は、予測するクラスの数と同じである必要があります。そのためには、CLASS_NAMES.length を使用して、分類予定のクラスの数を確認します。これは、ユーザー インターフェースにあるデータ収集ボタンの数と同じです。これは分類問題であるため、この出力レイヤでは softmax 活性化を使用します。これは、分類問題を解決するために回帰ではなく、モデルを作成するときに使用する必要があります。

次に、model.summary() を出力して、新しく定義されたモデルの概要をコンソールに出力します。

最後に、トレーニングの準備ができるようにモデルをコンパイルします。ここでは、オプティマイザーが adam に設定されており、CLASS_NAMES.length2 と等しい場合は損失が binaryCrossentropy になり、分類するクラスが 3 つ以上ある場合は categoricalCrossentropy が使用されます。精度指標もリクエストされ、後でデバッグのためにログでモニタリングできます。

コンソールに次のように表示されます。

22eaf32286fea4bb.png

これには 13 万を超えるトレーニング可能なパラメータがあります。しかし、これは規則正しいニューロンの単純な高密度層なので、かなり速くトレーニングされます。

プロジェクト完了後に行う作業として、最初のレイヤのニューロン数を変更して、妥当なパフォーマンスを維持しつつ、ニューロン数をどれだけ低くできるか試してみましょう。ML では多くの場合、リソース使用量と速度の最適なトレードオフとなる最適なパラメータ値を見つけるために、ある程度の試行錯誤が必要になります。

11. ウェブカメラを有効にする

次に、前に定義した enableCam() 関数を具体化します。以下に示すように、hasGetUserMedia() という名前の新しい関数を追加し、以前に定義した enableCam() 関数の内容を以下の対応するコードに置き換えます。

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

まず、hasGetUserMedia() という名前の関数を作成し、主要なブラウザ API プロパティの存在をチェックして、ブラウザが getUserMedia() をサポートしているかどうかを確認します。

enableCam() 関数で、先ほど定義した hasGetUserMedia() 関数を使用して、サポートされているかどうかを確認します。有効になっていない場合は、コンソールに警告を出力します。

サポートしている場合は、getUserMedia() 呼び出しにいくつかの制約(動画ストリームのみにする、動画の width のサイズを 640 ピクセル、height480 ピクセルにするなど)を定義します。その理由は、これより大きい動画は、MobileNet モデルにフィードするために 224 x 224 ピクセルにサイズ変更する必要があるため、あまり意味がありません。また、解像度を下げてリクエストするコンピューティング リソースを節約することもできます。ほとんどのカメラは、このサイズの解像度をサポートしています。

次に、前述の constraints を使用して navigator.mediaDevices.getUserMedia() を呼び出し、stream が返されるまで待ちます。stream が返されたら、VIDEO 要素を srcObject 値として設定することで、stream を再生させることができます。

また、stream が読み込まれて正常に再生されていることを把握するために、VIDEO 要素に eventListener を追加する必要があります。

スチームが読み込まれたら、videoPlaying を true に設定し、ENABLE_CAM_BUTTON を削除して、クラスを「removed」に設定することで再度クリックされないようにできます。

コードを実行し、[カメラを有効にする] ボタンをクリックして、ウェブカメラへのアクセスを許可します。この操作を初めて行う場合は、次のように、ページの動画要素に自身がレンダリングされているはずです。

b378eb1affa9b883.png

次に、dataCollector ボタンのクリックを処理する関数を追加します。

12. データ収集ボタンのイベント ハンドラ

次に、現在空になっている gatherDataForClass(). という関数に入力します。これは、Codelab の冒頭で dataCollector ボタンのイベント ハンドラ関数として割り当てたものです。

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

まず、属性名(この場合は data-1hot)を指定して this.getAttribute() を呼び出し、現在クリックされているボタンの data-1hot 属性を確認します。これは文字列であるため、parseInt() を使用して整数にキャストし、この結果を classNumber. という名前の変数に代入できます。

次に、それに応じて gatherDataState 変数を設定します。現在の gatherDataStateSTOP_DATA_GATHER(-1 に設定)と等しい場合は、現在データを収集しておらず、mousedown イベントが発生したことを意味します。gatherDataState を、先ほど見つけた classNumber に設定します。

それ以外の場合は、現在データを収集していて、呼び出されたイベントが mouseup イベントであったため、そのクラスのデータの収集を停止することを意味します。これを STOP_DATA_GATHER 状態に戻して、この後で定義するデータ収集ループを終了します。

最後に、クラスデータの記録を実際に実行する dataGatherLoop(), の呼び出しを開始します。

13. データ収集

ここで、dataGatherLoop() 関数を定義します。この関数は、ウェブカメラの動画から画像をサンプリングして MobileNet モデルに渡して、そのモデルの出力(1,024 個の特徴ベクトル)をキャプチャします。

次に、現在押されているボタンの gatherDataState ID とともに保存されるため、このデータが表すクラスを把握できます。

詳しく見ていきましょう。

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

videoPlaying が true(ウェブカメラがアクティブ)で、gatherDataStateSTOP_DATA_GATHER と等しくなく、クラスデータ収集のボタンが現在押されている場合にのみ、この関数の実行を続行します。

次に、コードを tf.tidy() でラップして、後続のコードで作成したテンソルを破棄します。この tf.tidy() コード実行の結果は、imageFeatures という変数に格納されます。

tf.browser.fromPixels() を使用して、ウェブカメラ VIDEO のフレームをキャプチャできるようになりました。生成された画像データを含むテンソルは、videoFrameAsTensor という変数に格納されます。

次に、MobileNet モデルの入力に適した形状になるように videoFrameAsTensor 変数のサイズを変更します。tf.image.resizeBilinear() 呼び出しを使用します。最初のパラメータとして再形成するテンソルを指定し、次に以前に作成した定数で定義された新しい高さと幅を定義するシェイプを指定します。最後に、3 つ目のパラメータを渡して角の位置を true に設定し、サイズ変更時に配置の問題が発生しないようにします。このサイズ変更の結果は resizedTensorFrame という変数に格納されます。

ウェブカメラの画像のサイズは 640 x 480 ピクセルで、モデルには 224 x 224 ピクセルの正方形の画像が必要なため、このプリミティブのサイズ変更によって画像が引き伸ばされます。

このデモでは、これで十分です。ただし、この Codelab を完了したら、後で作成する本番環境システムでより良い結果を得るために、この画像から正方形を切り抜くことをおすすめします。

次に、画像データを正規化します。tf.browser.frompixels() を使用する場合、画像データは常に 0 ~ 255 の範囲内となります。そのため、サイズ変更された TensorFrame を 255 で除算すると、すべての値が 0 ~ 1 の間に収まるようになります。これは MobileNet モデルが入力として想定する値です。

最後に、コードの tf.tidy() セクションで、mobilenet.predict() を呼び出して、この正規化されたテンソルを、読み込まれたモデルに push します。このテンソルには、expandDims() を使用して normalizedTensorFrame の拡張バージョンを渡します。これにより、モデルは入力のバッチを処理するために 1 のバッチになります。

結果が返されたら、その返された結果に対してすぐに squeeze() を呼び出して 1 次元のテンソルに戻します。その後、このテンソルを返して、tf.tidy() から結果を取得する imageFeatures 変数に代入します。

MobileNet モデルから imageFeatures を取得したので、前に定義した trainingDataInputs 配列にそれらを push して記録できます。

現在の gatherDataStatetrainingDataOutputs 配列に push して、この入力が表す内容を記録することもできます。

gatherDataState 変数は、以前に定義した gatherDataForClass() 関数でボタンがクリックされたときに、データを記録する現在のクラスの数値 ID に設定されます。

この時点で、特定のクラスに対して持つサンプルの数を増やすこともできます。これを行うには、まず examplesCount 配列内のインデックスが以前に初期化されているかどうかを確認します。未定義の場合は、0 に設定して特定のクラスの数値 ID のカウンタを初期化し、現在の gatherDataStateexamplesCount をインクリメントできます。

次に、ウェブページの STATUS 要素のテキストを更新して、キャプチャ時に各クラスの現在のカウントを表示します。これを行うには、CLASS_NAMES 配列をループし、人が読める形式の名前と、examplesCount の同じインデックスにあるデータ数を組み合わせて出力します。

最後に、パラメータとして dataGatherLoop を渡して window.requestAnimationFrame() を呼び出し、この関数を再帰的に再度呼び出します。これにより、ボタンの mouseup が検出され、gatherDataStateSTOP_DATA_GATHER, に設定され、データ収集ループが終了する時点まで、動画からフレームがサンプリングされ続けます。

ここでコードを実行すると、[Enable camera] ボタンをクリックし、ウェブカメラが読み込まれるのを待ちます。その後、各データ収集ボタンをクリックして長押しすると、各クラスのデータの例を収集できるようになります。ここには、私の携帯電話と手のデータがそれぞれ収集されています。

541051644a45131f.gif

上のスクリーン キャプチャに示すように、すべてのテンソルがメモリに保存されるため、ステータス テキストが更新されます。

14. トレーニングと予測

次のステップでは、現在空の trainAndPredict() 関数のコードを実装します。ここで、転移学習が行われます。コードを見てみましょう。

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

まず、predictfalse に設定して、現在の予測の実行を停止します。

次に、tf.util.shuffleCombo() を使用して入力配列と出力配列をシャッフルし、順序によってトレーニングで問題が発生しないようにします。

出力配列 trainingDataOutputs, を int32 型の tensor1d に変換して、ワンホット エンコーディングで使用できるようにします。これは outputsAsTensor という名前の変数に格納されます。

この outputsAsTensor 変数と、エンコードするクラスの最大数(CLASS_NAMES.length のみ)を指定して、tf.oneHot() 関数を使用します。ホット エンコードされた 1 つの出力が、oneHotOutputs という新しいテンソルに格納されます。

現在、trainingDataInputs は記録されたテンソルの配列です。これらをトレーニングに使用するには、テンソルの配列を通常の 2 次元テンソルに変換する必要があります。

そのために、TensorFlow.js ライブラリには tf.stack() という優れた関数が含まれています。

この関数では、テンソルの配列を受け取り、それらをスタックして、高次元のテンソルを出力として生成します。このケースでは 2 次元テンソルが返されます。これは、それぞれ 1,024 長の 1 次元の入力で、記録された特徴が含まれています。これがトレーニングに必要なものです。

次に、await model.fit() でカスタムモデルヘッドをトレーニングします。ここでは、入力例とターゲット出力のそれぞれに使用するトレーニング データを表すために、inputsAsTensor 変数と oneHotOutputs を渡しています。3 番目のパラメータの構成オブジェクトで、shuffletrue に設定し、epochs10 に設定して 5batchSize を使用します。次に、onEpochEndcallback を、この後で定義する logProgress 関数に指定します。

最後に、モデルのトレーニングが完了したら、作成したテンソルを破棄します。次に、predicttrue に戻して予測が再び行われるようにしてから、predictLoop() 関数を呼び出してライブ ウェブカメラの画像の予測を開始できます。

logProcess() 関数を定義してトレーニングの状態をログに記録することもできます。これは上記の model.fit() で使用され、トレーニングの各ラウンドの後に結果をコンソールに出力します。

あと少しです。predictLoop() 関数を追加して予測を行う時間。

コア予測ループ

ここでは、ウェブカメラのフレームをサンプリングし、ブラウザでリアルタイムの結果を使用して各フレーム内のものを連続的に予測するメイン予測ループを実装します。

コードを確認してみましょう。

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

まず、モデルがトレーニングされて使用可能になった後でのみ予測が行われるように、predict が true であることを確認します。

次に、dataGatherLoop() 関数の場合と同様に、現在の画像の画像特徴を取得します。基本的には、tf.browser.from pixels() を使用してウェブカメラからフレームを取得して正規化し、サイズを 224 x 224 ピクセルにサイズ変更し、そのデータを MobileNet モデルに渡して、結果として得られる画像の特徴を取得します。

ただし、トレーニング済みモデルの predict() 関数で結果の imageFeatures を渡すことで、新しくトレーニングされたモデルのヘッドを使用して実際に予測を実行できます。次に、結果のテンソルをスクイーズして再び 1 次元にし、prediction という変数に代入します。

この prediction では、argMax() を使用して最大値を持つインデックスを見つけ、arraySync() を使用してこの結果のテンソルを配列に変換し、JavaScript で基になるデータを取得し、最も高い値要素の位置を見つけることができます。この値は highestIndex という変数に格納されます。

同じ方法で実際の予測信頼スコアを取得するには、prediction テンソルで arraySync() を直接呼び出します。

これで、STATUS テキストを prediction データで更新するために必要なすべての準備が整いました。人が読める形式の文字列を取得するには、CLASS_NAMES 配列で highestIndex を検索し、predictionArray から信頼値を取得します。パーセンテージで読みやすくするには、100 を掛けて math.floor() を掛けます。

最後に、準備ができたら window.requestAnimationFrame() を使用して predictionLoop() をもう一度呼び出し、動画ストリームでリアルタイム分類を行うことができます。これは、新しいデータで新しいモデルをトレーニングすることを選択した場合、predictfalse に設定されるまで続きます。

そうすると、パズルの最後のピースが現れます。リセットボタンを実装する。

15. リセットボタンを実装する

あと少しで完了です。パズルの最後のピースは、最初からやり直すリセットボタンを実装することです。現在空の reset() 関数のコードを以下に示します。次のように更新します。

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

まず、predictfalse に設定して、実行中の予測ループを停止します。次に、長さを 0 に設定して examplesCount 配列の内容をすべて削除します。これにより、配列からすべての内容を簡単に消去できます。

次に、現在記録されているすべての trainingDataInputs を調べ、そこに含まれている各テンソルの dispose() を実行してメモリを再び解放します。テンソルは JavaScript ガベージ コレクタによってクリーンアップされないためです。

それが完了したら、trainingDataInputs 配列と trainingDataOutputs 配列の両方で配列の長さを安全に 0 に設定し、これらも消去します。

最後に、STATUS テキストを適切な内容に設定し、メモリに残っているテンソルをサニティ チェックとして出力します。

なお、MobileNet モデルと定義したマルチレイヤ パーセプトロンの両方が破棄されないため、メモリにはまだ数百のテンソルが残ります。このリセット後に再度トレーニングを行う場合は、新しいトレーニング データでモデルを再利用する必要があります。

16. やってみましょう

ぜひ、自分だけの Teachable Machine をお試しください。

ライブ プレビューに移動してウェブカメラを有効にし、部屋のオブジェクトについてクラス 1 のサンプルを 30 以上収集します。次に、別のオブジェクトについてクラス 2 でも同じ処理を行い、[トレーニング] をクリックしてコンソールログで進行状況を確認します。トレーニングはかなり速くなります。

bf1ac3cc5b15740.gif

トレーニングが完了したら、オブジェクトをカメラに見せ、ライブ予測を取得します。この予測は、ウェブページの上部にあるステータス テキスト領域に出力されます。問題が発生した場合は、完成した正常に機能するコードを確認して、コピーされていないものがないかご確認ください。

17. 完了

これで、ブラウザで TensorFlow.js を使用して、初めての転移学習の例を完了しました。

さまざまな物体で試してみてください。特に他のものと比べて、認識が難しいものがあることに気づくかもしれません。これらのクラスとトレーニング データを区別できるようにするには、さらにクラスまたはトレーニング データを追加する必要があります。

内容のまとめ

この Codelab では、以下のことを学びました。

  1. 転移学習の概要と、完全なモデルをトレーニングすることに対するメリット
  2. 再利用するモデルを TensorFlow Hub から取得する方法。
  3. 転移学習に適したウェブアプリを設定する方法。
  4. ベースモデルを読み込んで使用して画像の特徴を生成する方法。
  5. ウェブカメラの画像からカスタム オブジェクトを認識できる新しい予測ヘッドをトレーニングする方法。
  6. 結果のモデルを使用してデータをリアルタイムで分類する方法。

次のステップ

作業の土台が整ったところで、この機械学習モデルのボイラープレートを現在取り組んでいる実際のユースケースに拡張するために、どのようなクリエイティブなアイデアを思いつくことができるでしょうか。もしかしたら、現在の業界に革命を起こし、会社の従業員が日常業務で重要なことを分類するためのモデルをトレーニングできるようにできたらいいと思いませんか?可能性は無限大です

さらに学習を進めるには、こちらのフルコースを無料で受講することもできます。このコースでは、この Codelab で現在使用している 2 つのモデルを 1 つのモデルに統合して効率を向上させる方法をご紹介しています。

また、オリジナルのティータブル マシン アプリケーションの背後にある理論について詳しくは、こちらのチュートリアルをご覧ください。

成功事例を共有する

今作成したものは、他のクリエイティブなユースケースにも簡単に拡張できます。既成概念にとらわれず、ハッキングを続けることをおすすめします。

ソーシャル メディアで #MadeWithTFJS ハッシュタグを付けて、Google のプロジェクトが TensorFlow ブログ今後のイベントで紹介されるチャンスがありますので、お忘れなく。皆さんが作るものを楽しみにしています。

おすすめのウェブサイト