TensorFlow.js: 独自の「Teachable Machine」を作成するTensorFlow.js で転移学習を使用する

1. 始める前に

TensorFlow.js モデルの使用は過去数年間で飛躍的に増加しており、多くの JavaScript デベロッパーが、既存の最先端モデルを取得して、業界固有のカスタムデータで動作するように再トレーニングすることを検討しています。既存のモデル(ベースモデルと呼ばれることが多い)を取得し、類似しているが異なるドメインで使用する行為は、転移学習と呼ばれます。

転移学習には、完全に空白のモデルから始めるよりも多くの利点があります。以前にトレーニングしたモデルからすでに学習した知識を再利用できるため、分類する新しいアイテムの例が少なくて済みます。また、ネットワーク全体ではなくモデル アーキテクチャの最後の数層のみを再トレーニングすればよいことが多いため、トレーニングが大幅に高速化されます。このため、転移学習は、実行デバイスによってリソースが異なる可能性があるが、センサーに直接アクセスしてデータを簡単に取得できるウェブブラウザ環境に非常に適しています。

この Codelab では、Google の人気ウェブサイト「Teachable Machine」を再現して、白紙の状態からウェブアプリを構築する方法を説明します。このウェブサイトでは、ユーザーがウェブカメラから取得した数枚のサンプル画像を使用して、カスタム オブジェクトを認識できる機能的なウェブアプリを作成できます。この Codelab の ML の側面に集中できるように、ウェブサイトは意図的に最小限に抑えられています。ただし、元の Teachable Machine ウェブサイトと同様に、既存のウェブ デベロッパーの経験を活かして UX を改善する余地は十分にあります。

前提条件

この Codelab は、TensorFlow.js の既製モデルと基本的な API の使用方法をある程度理解しており、TensorFlow.js で転移学習を始めたいウェブ デベロッパーを対象としています。

  • このラボでは、TensorFlow.js、HTML5、CSS、JavaScript の基本的な知識があることを前提としています。

Tensorflow.js を初めて使用する場合は、まず この無料の入門コースを受講することをおすすめします。このコースでは、機械学習や TensorFlow.js の知識がないことを前提に、必要な知識を段階的に学習できます。

学習内容

  • TensorFlow.js の概要と、次のウェブアプリで TensorFlow.js を使用すべき理由。
  • Teachable Machine のユーザー エクスペリエンスを再現する簡素化された HTML/CSS /JS ウェブページを構築する方法。
  • TensorFlow.js を使用して、事前トレーニング済みのベースモデル(特に MobileNet)を読み込み、転移学習で使用できる画像特徴を生成する方法。
  • 認識する複数のクラスのデータについて、ユーザーのウェブカメラからデータを収集する方法。
  • 画像の特徴を取得し、それを使用して新しいオブジェクトを分類するように学習する多層パーセプトロンを作成して定義する方法。

ハッキングを始めましょう。

必要なもの

  • このチュートリアルを進めるには、Glitch.com アカウントを使用することをおすすめします。また、ご自身で編集して実行できるウェブ サービス環境を使用することもできます。

2. TensorFlow.js とは何ですか?

54e81d02971f53e8.png

TensorFlow.js は、JavaScript が実行できる場所であればどこでも実行できるオープンソースの ML ライブラリです。これは Python で記述された元の TensorFlow ライブラリに基づいており、このデベロッパー エクスペリエンスと API のセットを JavaScript エコシステムで再作成することを目的としています。

どこで使用できますか?

JavaScript の移植性を考慮すると、1 つの言語で記述し、次のすべてのプラットフォームで簡単に ML を実行できます。

  • バニラ JavaScript を使用したウェブブラウザのクライアントサイド
  • Node.js を使用したサーバーサイドや Raspberry Pi などの IoT デバイス
  • Electron を使用するデスクトップ アプリ
  • React Native を使用するネイティブ モバイルアプリ

TensorFlow.js は、これらの各環境内で複数のバックエンドもサポートしています(CPU や WebGL など、実行可能な実際のハードウェア ベースの環境)。このコンテキストの「バックエンド」はサーバーサイド環境を意味するものではありません。たとえば、実行のバックエンドは WebGL のクライアントサイドになる可能性があります。互換性を確保し、高速な実行を維持するためです。現在、TensorFlow.js は以下をサポートしています。

  • デバイスのグラフィック カード(GPU)での WebGL 実行 - GPU アクセラレーションを使用して、より大きなモデル(サイズが 3 MB を超えるもの)を実行する最も高速な方法です。
  • CPU での Web Assembly(WASM)実行 - たとえば、古い世代の携帯電話など、デバイス全体の CPU パフォーマンスを向上させます。これは、グラフィック プロセッサにコンテンツをアップロードするオーバーヘッドにより、WebGL よりも WASM を使用して CPU で高速に実行できる小規模なモデル(サイズが 3 MB 未満)に適しています。
  • CPU 実行 - 他の環境が使用できない場合のフォールバック。3 つの中で最も遅いですが、いつでも利用できます。

注: 実行するデバイスがわかっている場合は、これらのバックエンドのいずれかを強制的に選択できます。指定しない場合は、TensorFlow.js が自動的に選択します。

クライアントサイドのスーパー パワー

クライアント マシン上のウェブブラウザで TensorFlow.js を実行すると、考慮すべきメリットがいくつかあります。

プライバシー

データを第三者のウェブサーバーに送信することなく、クライアント マシンでデータのトレーニングと分類の両方を行うことができます。たとえば、GDPR などの現地の法律を遵守するために、またはユーザーが自分のマシンに保持し、第三者に送信したくないデータを処理するために、この機能が必要になることがあります。

速度

リモート サーバーにデータを送信する必要がないため、推論(データの分類)を高速化できます。さらに、ユーザーがアクセスを許可すれば、カメラ、マイク、GPS、加速度計などのデバイスのセンサーに直接アクセスできます。

リーチとスケーリング

世界中の誰もが、あなたが送ったリンクをクリックしてブラウザでウェブページを開き、あなたが作成したものを利用できます。ML システムを使用するために、CUDA ドライバなどを含む複雑なサーバーサイドの Linux 設定は必要ありません。

費用

サーバーがないため、HTML、CSS、JS、モデルファイルをホストする CDN の料金のみを支払う必要があります。CDN の費用は、サーバー(グラフィック カードが接続されている可能性あり)を 24 時間 365 日稼働させるよりもはるかに安価です。

サーバーサイドの機能

TensorFlow.js の Node.js 実装を活用すると、次の機能が有効になります。

CUDA のフルサポート

サーバー側では、グラフィック カードのアクセラレーションのために、TensorFlow がグラフィック カードで動作するように NVIDIA CUDA ドライバをインストールする必要があります(WebGL を使用するブラウザとは異なり、インストールは不要です)。ただし、CUDA を完全にサポートすることで、グラフィック カードの低レベルの機能を最大限に活用し、トレーニングと推論の時間を短縮できます。どちらも同じ C++ バックエンドを共有しているため、パフォーマンスは Python TensorFlow 実装と同等です。

モデルのサイズ

研究の最先端モデルでは、非常に大規模なモデル(ギガバイト単位のサイズ)を扱うことがあります。これらのモデルは、ブラウザタブあたりのメモリ使用量の制限により、現時点ではウェブブラウザで実行できません。これらの大規模なモデルを実行するには、モデルを効率的に実行するために必要なハードウェア仕様を備えた独自のサーバーで Node.js を使用できます。

IoT

Node.js は Raspberry Pi などの一般的なシングルボード コンピュータでサポートされています。つまり、このようなデバイスで TensorFlow.js モデルを実行することもできます。

速度

Node.js は JavaScript で記述されているため、ジャストインタイム コンパイルのメリットがあります。つまり、Node.js を使用すると、実行時に最適化されるため、パフォーマンスが向上することがよくあります。特に、実行する可能性のある前処理については、パフォーマンスが向上します。この好例は、こちらのケーススタディで確認できます。このケーススタディでは、Hugging Face が Node.js を使用して自然言語処理モデルのパフォーマンスを 2 倍に向上させた方法が紹介されています。

TensorFlow.js の基本、実行場所、メリットについて理解できたので、実際に使ってみましょう。

3. 転移学習

転移学習とは

転移学習では、すでに学習した知識を活用して、別の類似したものを学習します。

人間は常にこれを行っています。脳には、これまで見たことのない新しいものを認識するために使用できる、生涯にわたる経験が蓄積されています。たとえば、柳の木を考えてみましょう。

e28070392cd4afb9.png

お住まいの地域によっては、この種の木を見たことがないかもしれません。

しかし、下の新しい画像にヤナギの木があるかどうかを尋ねると、角度が異なり、最初に示したものと少し異なるにもかかわらず、すぐにそれを見つけることができるでしょう。

d9073a0d5df27222.png

脳には、木のような物体を識別する方法を知っているニューロンと、長い直線を見つけるのが得意なニューロンがすでにたくさんあります。この知識を再利用して、長いまっすぐな縦向きの枝がたくさんある木のようなオブジェクトである柳の木をすばやく分類できます。

同様に、画像認識などのドメインでトレーニング済みの ML モデルがある場合は、別の関連タスクを実行するために再利用できます。

MobileNet などの高度なモデルでも同様のことができます。MobileNet は、1, 000 種類のオブジェクト タイプで画像認識を実行できる非常に一般的な研究モデルです。犬から車まで、数百万枚のラベル付き画像を含む ImageNet と呼ばれる巨大なデータセットでトレーニングされました。

このアニメーションでは、この MobileNet V1 モデルに含まれるレイヤの数が非常に多いことがわかります。

7d4e1e35c1a89715.gif

このモデルは、トレーニング中に、1,000 個のオブジェクトすべてに共通する重要な特徴を抽出する方法を学習しました。このようなオブジェクトの識別に使用する下位レベルの特徴の多くは、これまで見たことのない新しいオブジェクトを検出するのにも役立ちます。結局のところ、すべては線、テクスチャ、図形の組み合わせにすぎません。

従来の畳み込みニューラル ネットワーク(CNN)アーキテクチャ(MobileNet と同様)を見て、転移学習でこのトレーニング済みネットワークを活用して新しいことを学習する方法を確認しましょう。次の図は、0 ~ 9 の手書き数字を認識するようにトレーニングされた CNN の一般的なモデル アーキテクチャを示しています。

baf4e3d434576106.png

左に示すように、既存のトレーニング済みモデルの事前トレーニング済み下位レイヤを、右に示すモデルの終盤にある分類レイヤ(モデルの分類ヘッドと呼ばれることもあります)から分離できれば、下位レイヤを使用して、トレーニングに使用した元のデータに基づいて、任意の画像に対する出力特徴を生成できます。以下は、分類ヘッドが削除された同じネットワークです。

369a8a9041c6917d.png

認識しようとしている新しいものが、以前のモデルが学習した出力特徴を利用できると仮定すると、新しい目的で再利用できる可能性が高くなります。

上の図の架空のモデルは数字でトレーニングされているため、数字について学習した内容を文字(a、b、c など)にも適用できる可能性があります。

次のように、a、b、c のいずれかを予測する新しい分類ヘッドを追加できます。

db97e5e60ae73bbd.png

ここでは、下位レイヤはフリーズされ、トレーニングされません。新しい分類ヘッドのみが更新され、左側の事前にトレーニングされた切り刻まれたモデルから提供された特徴を学習します。

この行為は転移学習と呼ばれ、Teachable Machine がバックグラウンドで行っていることです。

また、ネットワークの最後に多層パーセプトロンをトレーニングするだけで済むため、ネットワーク全体をゼロからトレーニングする場合よりもはるかに高速にトレーニングできます。

では、モデルのサブパーツを取得するにはどうすればよいでしょうか?次のセクションで確認しましょう。

4. TensorFlow Hub - ベースモデル

使用する適切なベースモデルを見つける

MobileNet などのより高度で一般的な研究モデルについては、TensorFlow Hub にアクセスし、MobileNet v3 アーキテクチャを使用する TensorFlow.js に適したモデルをフィルタして、次のような結果を見つけることができます。

c5dc1420c6238c14.png

これらの結果の一部は「画像分類」タイプ(各モデルカードの結果の左上に詳細が表示されます)で、その他は「画像特徴ベクトル」タイプです。

これらの画像特徴ベクトル結果は、基本的に MobileNet の事前分割バージョンであり、最終的な分類ではなく画像特徴ベクトルを取得するために使用できます。

このようなモデルは「ベースモデル」と呼ばれることが多く、新しい分類ヘッドを追加して独自のデータでトレーニングすることで、前のセクションで説明したのと同じ方法で転移学習を実行できます。

次に、目的のベースモデルがどの TensorFlow.js 形式でリリースされているかを確認します。これらの特徴ベクトル MobileNet v3 モデルのいずれかのページを開くと、JS ドキュメントから、tf.loadGraphModel() を使用するドキュメントのコード スニペットの例に基づくグラフモデルの形式であることがわかります。

f97d903d2e46924b.png

グラフ形式ではなくレイヤ形式でモデルが見つかった場合は、トレーニング用にフリーズするレイヤとフリーズ解除するレイヤを選択できます。これは、新しいタスクのモデル(「転移モデル」と呼ばれることが多い)を作成する場合に非常に有効です。ただし、このチュートリアルでは、ほとんどの TF Hub モデルがデプロイされているデフォルトのグラフモデル タイプを使用します。レイヤモデルの操作の詳細については、ゼロからヒーローまでの TensorFlow.js コースをご覧ください。

転移学習のメリット

モデル アーキテクチャ全体をゼロからトレーニングするのではなく、転移学習を使用する利点は何ですか?

まず、転移学習アプローチを使用する主なメリットは、トレーニング時間が短縮されることです。すでにトレーニング済みのベースモデルを基盤として使用できるためです。

第 2 に、すでにトレーニングが行われているため、分類しようとしている新しいものの例をはるかに少なくして表示できます。

分類したいもののサンプルデータを収集する時間やリソースが限られている場合に、より堅牢なモデルにするためのトレーニング データを収集する前に、プロトタイプを迅速に作成する必要がある場合に非常に便利です。

転移学習では、必要なデータが少なく、小規模なネットワークのトレーニング速度が速いため、リソースの消費量が少なくなります。そのため、ブラウザ環境に非常に適しており、最新のマシンではモデルの完全なトレーニングに数時間、数日、数週間ではなく、数十秒しかかかりません。

承知いたしました。転移学習の概要を理解できたところで、独自の Teachable Machine を作成してみましょう。では始めましょう。

5. コーディングの準備をする

必要なもの

  • 最新のウェブブラウザ。
  • HTML、CSS、JavaScript、Chrome DevTools(コンソール出力の表示)に関する基本的な知識。

コーディングを始めましょう

Glitch.com または Codepen.io で使用できるボイラープレート テンプレートが作成されています。この Codelab では、いずれかのテンプレートをベース状態としてワンクリックでクローンできます。

Glitch で、[remix this] ボタンをクリックしてフォークし、編集可能な新しいファイルセットを作成します。

または、Codepen で画面右下の [fork] をクリックします。

この非常にシンプルなスケルトンには、次のファイルが含まれています。

  • HTML ページ(index.html)
  • スタイルシート(style.css)
  • JavaScript コードを記述するファイル(script.js)

便宜上、TensorFlow.js ライブラリの HTML ファイルにインポートが追加されています。たとえば、次のようになります。

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

代替案: お好みのウェブ エディタを使用するか、ローカルで作業する

コードをダウンロードしてローカルで作業する場合や、別のオンライン エディタで作業する場合は、上記の 3 つのファイルを同じディレクトリに作成し、Glitch のボイラープレートから各ファイルにコードをコピーして貼り付けます。

6. アプリの HTML ボイラープレート

開始方法

すべてのプロトタイプには、検出結果をレンダリングできる基本的な HTML スキャフォールディングが必要です。今すぐ設定しましょう。追加する内容は次のとおりです。

  • ページのタイトル。
  • 説明文。
  • ステータス段落。
  • 準備ができたらウェブカメラのフィードを保持する動画。
  • カメラの起動、データの収集、エクスペリエンスのリセットを行うためのボタンがいくつかあります。
  • TensorFlow.js と後でコーディングする JS ファイルのインポート。

index.html を開き、既存のコードを次のコードに貼り付けて、上記の機能を設定します。

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

ビートに乗る

上記の HTML コードを分解して、追加した重要な点をいくつか見てみましょう。

  • ページのタイトル用の <h1> タグと、ID が「status」の <p> タグを追加しました。これは、出力の表示にシステムのさまざまな部分を使用するため、情報を出力する場所になります。
  • ID が「webcam」の <video> 要素を追加しました。この要素に、後でウェブカメラのストリームをレンダリングします。
  • 5 個の <button> 要素を追加しました。ID が「enableCam」の最初のボタンは、カメラを有効にします。次の 2 つのボタンには「dataCollector」というクラスが設定されています。このクラスを使用すると、認識するオブジェクトのサンプル画像を集めることができます。後で記述するコードは、これらのボタンをいくつでも追加でき、それらが意図したとおりに自動的に動作するように設計されます。

これらのボタンには、data-1hot という特別なユーザー定義属性もあります。この属性の値は整数で、最初のクラスは 0 から始まります。これは、特定のクラスのデータを表すために使用する数値インデックスです。インデックスは、ML モデルが数値のみを処理できるため、文字列ではなく数値表現で出力クラスを正しくエンコードするために使用されます。

また、このクラスに使用する人間が読める名前を含む data-name 属性もあります。これにより、1 ホット エンコードの数値インデックス値ではなく、より意味のある名前をユーザーに提供できます。

最後に、データが収集されたらトレーニング プロセスを開始するためのトレーニング ボタンと、アプリをリセットするためのリセット ボタンがあります。

  • また、2 つの <script> インポートも追加しました。1 つは TensorFlow.js 用、もう 1 つは後で定義する script.js 用です。

7. スタイルを追加する

要素のデフォルト

追加した HTML 要素のスタイルを追加して、正しくレンダリングされるようにします。位置とサイズの要素に正しく追加されるスタイルの例を次に示します。特に変わったことはありません。後でこのコードを拡張して、Teachable Machine の動画で見たような、さらに優れた UX を実現することもできます。

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

これで、必要な処理はこれだけです。この時点で出力をプレビューすると、次のようになります。

81909685d7566dcb.png

8. JavaScript: キー定数とリスナー

主要な定数を定義する

まず、アプリ全体で使用するキー定数を追加します。まず、script.js の内容を次の定数に置き換えます。

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

それぞれの用途は次のとおりです。

  • STATUS は、ステータスの更新を書き込む段落タグへの参照を保持するだけです。
  • VIDEO は、ウェブカメラのフィードをレンダリングする HTML 動画要素への参照を保持します。
  • ENABLE_CAM_BUTTONRESET_BUTTONTRAIN_BUTTON は、HTML ページからすべてのキーボタンへの DOM 参照を取得します。
  • MOBILE_NET_INPUT_WIDTHMOBILE_NET_INPUT_HEIGHT は、それぞれ MobileNet モデルの想定される入力の幅と高さを定義します。このようにファイルの先頭付近の定数に保存しておくと、後で別のバージョンを使用することになった場合に、多くの場所で値を置き換えるのではなく、1 か所で値を更新するだけで済むため、簡単に対応できます。
  • STOP_DATA_GATHER は -1 に設定されています。これは、ユーザーがボタンのクリックを停止してウェブカメラ フィードからデータを収集したタイミングを把握するために、状態値を保存します。この数値に意味のある名前を付けることで、後でコードを読みやすくすることができます。
  • CLASS_NAMES はルックアップとして機能し、予測可能なクラスの人間が読める名前を保持します。この配列は後で入力されます。

これで、キー要素への参照が取得できたので、イベント リスナーを関連付けます。

キーイベント リスナーを追加する

まず、次のようにキーボタンにクリック イベント ハンドラを追加します。

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - クリックされると enableCam 関数を呼び出します。

TRAIN_BUTTON - クリック時に trainAndPredict を呼び出します。

RESET_BUTTON - クリック時にリセットを呼び出します。

このセクションでは、document.querySelectorAll() を使用して、クラスが「dataCollector」のすべてのボタンを見つけることができます。これにより、ドキュメントから見つかった要素の配列が返されます。

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

コードの説明:

次に、見つかったボタンを反復処理し、それぞれに 2 つのイベント リスナーを関連付けます。1 つは mousedown、もう 1 つは mouseup です。これにより、ボタンが押されている間はサンプルの録音を続けることができ、データ収集に役立ちます。

どちらのイベントも、後で定義する gatherDataForClass 関数を呼び出します。

この時点で、HTML ボタン属性 data-name から見つかった人間が読めるクラス名を CLASS_NAMES 配列にプッシュすることもできます。

次に、後で使用する重要なものを保存するための変数を追加します。

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

詳しく見ていきましょう。

まず、読み込まれた mobilenet モデルを格納する変数 mobilenet があります。最初は undefined に設定します。

次に、gatherDataState という変数があります。「dataCollector」ボタンが押されると、HTML で定義されているように、そのボタンの 1 ホット ID に変更され、その時点で収集しているデータのクラスがわかるようになります。最初は STOP_DATA_GATHER に設定されています。これにより、後で記述するデータ収集ループは、ボタンが押されていないときにデータを収集しません。

videoPlaying は、ウェブカメラのストリームが正常に読み込まれて再生され、使用可能かどうかを追跡します。初期状態では、ENABLE_CAM_BUTTON. を押すまでウェブカメラがオンにならないため、false に設定されています。

次に、2 つの配列 trainingDataInputstrainingDataOutputs を定義します。これらは、MobileNet ベースモデルによって生成された入力特徴と、それぞれサンプリングされた出力クラスの「dataCollector」ボタンをクリックしたときに収集されたトレーニング データの値を保存します。

最後に、examplesCount, という配列を定義して、各クラスに含まれる例の数を追跡します。

最後に、予測ループを制御する predict という変数があります。初期設定では false に設定されています。この値が後で true に設定されるまで、予測は行われません。

すべてのキー変数が定義されたので、分類ではなく画像特徴ベクトルを提供する、事前に切り刻まれた MobileNet v3 ベースモデルを読み込みましょう。

9. MobileNet ベースモデルを読み込む

まず、次のように loadMobileNetFeatureModel という新しい関数を定義します。モデルの読み込みは非同期であるため、これは非同期関数である必要があります。

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

このコードでは、読み込むモデルが配置されている URL を TFHub ドキュメントから定義します。

次に、await tf.loadGraphModel() を使用してモデルを読み込みます。この Google ウェブサイトからモデルを読み込むため、特別なプロパティ fromTFHubtrue に設定することを忘れないでください。これは、この追加プロパティを設定する必要がある TF Hub でホストされているモデルを使用する場合にのみ該当する特別なケースです。

読み込みが完了したら、STATUS 要素の innerText にメッセージを設定して、正しく読み込まれたことを視覚的に確認し、データ収集を開始する準備ができたことを確認できます。

あとはモデルをウォームアップするだけです。このような大規模なモデルでは、モデルを初めて使用するときに、すべての設定に時間がかかることがあります。そのため、モデルにゼロを渡して、タイミングがより重要になる可能性がある将来の待機を回避することをおすすめします。

tf.tidy() でラップされた tf.zeros() を使用すると、テンソルが正しく破棄され、バッチサイズが 1 になり、最初に定数で定義した高さと幅が正しく設定されます。最後に、カラーチャンネルも指定します。この場合は、モデルが RGB 画像を想定しているため、3 になります。

次に、answer.shape() を使用して返されたテンソルの結果の形状を記録し、このモデルが生成する画像特徴のサイズを把握できるようにします。

この関数を定義したら、すぐに呼び出して、ページ読み込み時にモデルのダウンロードを開始できます。

ライブ プレビューを今すぐ表示すると、しばらくしてからステータス テキストが「Awaiting TF.js load」から「MobileNet v3 loaded successfully!」に変わります(下図参照)。続行する前に、これが機能することを確認してください。

a28b734e190afff.png

コンソール出力で、このモデルが生成する出力特徴の印刷サイズを確認することもできます。MobileNet モデルにゼロを実行すると、[1, 1024] の形状が出力されます。最初の項目はバッチサイズ 1 で、実際には 1024 個の特徴が返されます。これらの特徴は、新しいオブジェクトの分類に役立ちます。

10. 新しいモデルヘッドを定義する

モデルヘッドを定義します。これは、基本的に非常に最小限の多層パーセプトロンです。

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

このコードを詳しく見ていきましょう。まず、モデルレイヤを追加する tf.sequential モデルを定義します。

次に、このモデルに入力レイヤとして密レイヤを追加します。MobileNet v3 の特徴量からの出力はこのサイズであるため、入力形状は 1024 です。これは、前のステップでモデルに 1 を渡した後に検出されました。このレイヤには、ReLU 活性化関数を使用する 128 個のニューロンがあります。

活性化関数とモデルレイヤを初めて使用する場合は、このワークショップの冒頭で説明したコースを受講して、これらのプロパティがバックグラウンドで何を行うかを理解することをおすすめします。

次に、出力レイヤを追加します。ニューロンの数は、予測しようとしているクラスの数と一致する必要があります。これを行うには、CLASS_NAMES.length を使用して、分類する予定のクラスの数を確認します。これは、ユーザー インターフェースで見つかったデータ収集ボタンの数と同じです。これは分類問題であるため、この出力レイヤで softmax 活性化を使用します。これは、回帰ではなく分類問題を解決するモデルを作成しようとする場合に使用する必要があります。

model.summary() を出力して、新しく定義したモデルの概要をコンソールに出力します。

最後に、モデルをコンパイルしてトレーニングの準備を整えます。ここでは、オプティマイザーが adam に設定されています。CLASS_NAMES.length2 と等しい場合、損失は binaryCrossentropy になります。分類するクラスが 3 つ以上ある場合は、categoricalCrossentropy が使用されます。精度指標もリクエストされるため、後でデバッグ目的でログでモニタリングできます。

コンソールに次のように表示されます。

22eaf32286fea4bb.png

このモデルには 13 万を超えるトレーニング可能なパラメータがあります。ただし、これは通常のニューロンの単純な密結合レイヤであるため、トレーニングはかなり高速で行われます。

プロジェクトの完了後に一度試していただきたいのは、最初のレイヤのニューロン数を変更して、パフォーマンスを維持しながらどこまで減らせるかを確認することです。多くの場合、機械学習では、リソース使用量と速度の最適なトレードオフを実現する最適なパラメータ値を見つけるために、ある程度の試行錯誤が必要です。

11. ウェブカメラを有効にする

ここで、前に定義した enableCam() 関数を実装します。次の図に示すように hasGetUserMedia() という名前の新しい関数を追加し、以前に定義した enableCam() 関数の内容を次の対応するコードに置き換えます。

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

まず、主要なブラウザ API プロパティの存在を確認して、ブラウザが getUserMedia() をサポートしているかどうかを確認する hasGetUserMedia() という名前の関数を作成します。

enableCam() 関数で、上で定義した hasGetUserMedia() 関数を使用して、サポートされているかどうかを確認します。そうでない場合は、コンソールに警告を出力します。

サポートされている場合は、getUserMedia() 呼び出しの制約を定義します。たとえば、動画ストリームのみが必要で、動画の width のサイズは 640 ピクセル、height480 ピクセルが望ましいなどです。その理由は、MobileNet モデルにフィードするには、224 × 224 ピクセルにサイズ変更する必要があるため、これより大きい動画を取得してもあまり意味がありません。解像度を小さくすることで、コンピューティング リソースを節約することもできます。ほとんどのカメラがこのサイズの解像度をサポートしています。

次に、上記で説明した constraints を使用して navigator.mediaDevices.getUserMedia() を呼び出し、stream が返されるのを待ちます。stream が返されたら、VIDEO 要素を取得し、その srcObject 値として設定することで、stream を再生できます。

また、VIDEO 要素に eventListener を追加して、stream が読み込まれて正常に再生されたタイミングを把握する必要があります。

ストリームが読み込まれたら、videoPlaying を true に設定し、ENABLE_CAM_BUTTON を削除して、クラスを「removed」に設定することで、再度クリックされないようにします。

コードを実行し、[Enable Camera] ボタンをクリックして、ウェブカメラへのアクセスを許可します。初めてこの操作を行う場合は、次のようにページ上の動画要素に自分の姿がレンダリングされます。

b378eb1affa9b883.png

それでは、dataCollector ボタンのクリックを処理する関数を追加しましょう。

12. データ収集ボタンのイベント ハンドラ

次に、現在空の関数 gatherDataForClass(). に入力します。これは、Codelab の開始時に dataCollector ボタンのイベント ハンドラ関数として割り当てたものです。

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

まず、属性の名前(この場合は data-1hot)をパラメータとして this.getAttribute() を呼び出し、現在クリックされているボタンの data-1hot 属性を確認します。これは文字列であるため、parseInt() を使用して整数にキャストし、その結果を classNumber. という名前の変数に割り当てることができます。

次に、gatherDataState 変数を適宜設定します。現在の gatherDataStateSTOP_DATA_GATHER(-1 に設定)と等しい場合、現在データを収集しておらず、mousedown イベントがトリガーされたことを意味します。gatherDataState を、先ほど見つけた classNumber に設定します。

それ以外の場合は、現在データを収集しており、発生したイベントが mouseup イベントであり、そのクラスのデータの収集を停止したいことを意味します。STOP_DATA_GATHER 状態に戻すだけで、まもなく定義するデータ収集ループを終了できます。

最後に、クラスデータの記録を実際に実行する dataGatherLoop(), の呼び出しを開始します。

13. データ収集

ここで、dataGatherLoop() 関数を定義します。この関数は、ウェブカメラ動画から画像をサンプリングし、MobileNet モデルに渡して、そのモデルの出力(1,024 個の特徴ベクトル)を取得します。

次に、現在押されているボタンの gatherDataState ID とともに保存し、このデータがどのクラスを表しているかを確認できるようにします。

詳しく見ていきましょう。

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n< = 0; n  CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

videoPlaying が true の場合、つまりウェブカメラがアクティブで、gatherDataStateSTOP_DATA_GATHER と等しくなく、クラスデータ収集用のボタンが現在押されている場合にのみ、この関数の実行を続行します。

次に、コードを tf.tidy() でラップして、後続のコードで作成されたテンソルを破棄します。この tf.tidy() コードの実行結果は、imageFeatures という変数に格納されます。

tf.browser.fromPixels() を使用して、ウェブカメラの VIDEO のフレームを取得できるようになりました。画像データを含む結果のテンソルは、videoFrameAsTensor という変数に格納されます。

次に、MobileNet モデルの入力に適した形状になるように videoFrameAsTensor 変数のサイズを変更します。tf.image.resizeBilinear() 呼び出しを使用します。この呼び出しでは、再形成するテンソルを最初のパラメータとして、新しい高さと幅を定義する形状を、先ほど作成した定数で定義します。最後に、3 番目のパラメータを渡して align corners を true に設定し、サイズ変更時の配置に関する問題を回避します。このサイズ変更の結果は、resizedTensorFrame という変数に格納されます。

このプリミティブは画像を拡大縮小します。ウェブカメラの画像は 640 × 480 ピクセルですが、モデルには 224 × 224 ピクセルの正方形の画像が必要です。

このデモでは、これで問題ありません。ただし、この Codelab を完了したら、この画像から正方形を切り抜いて、後で作成する本番環境システムでより良い結果が得られるように試してみることをおすすめします。

次に、画像データを正規化します。tf.browser.frompixels() を使用する場合、画像データは常に 0 ~ 255 の範囲内にあるため、resizedTensorFrame を 255 で割るだけで、すべての値を 0 ~ 1 の範囲内に収めることができます。これは、MobileNet モデルが入力として想定している範囲です。

最後に、コードの tf.tidy() セクションで、mobilenet.predict() を呼び出して、この正規化されたテンソルを読み込まれたモデルにプッシュします。このとき、expandDims() を使用して normalizedTensorFrame の拡張バージョンを渡して、モデルが処理用の入力バッチを想定しているため、バッチが 1 になるようにします。

結果が返ってきたら、その結果に対して squeeze() を呼び出して 1D テンソルに圧縮し、それを返して tf.tidy() からの結果をキャプチャする imageFeatures 変数に割り当てます。

MobileNet モデルから imageFeatures を取得したので、以前に定義した trainingDataInputs 配列にプッシュして記録できます。

現在の gatherDataStatetrainingDataOutputs 配列にプッシュすることで、この入力が何を表しているかを記録することもできます。

gatherDataState 変数は、以前に定義した gatherDataForClass() 関数でボタンがクリックされたときに、データを記録する現在のクラスの数値 ID に設定されます。

この時点で、特定のクラスの例の数を増やすこともできます。これを行うには、まず examplesCount 配列内のインデックスが初期化されているかどうかを確認します。未定義の場合は、0 に設定して、特定のクラスの数値 ID のカウンタを初期化します。その後、現在の gatherDataStateexamplesCount を増分できます。

ウェブページの STATUS 要素のテキストを更新して、各クラスの現在のカウントをキャプチャされたときに表示します。これを行うには、CLASS_NAMES 配列をループ処理し、examplesCount の同じインデックスにあるデータ数と組み合わせて、人間が読める名前を出力します。

最後に、dataGatherLoop をパラメータとして渡して window.requestAnimationFrame() を呼び出し、この関数を再帰的に呼び出します。ボタンの mouseup が検出され、gatherDataStateSTOP_DATA_GATHER, に設定されるまで、動画からフレームのサンプリングが続行されます。この時点でデータ収集ループが終了します。

ここでコードを実行すると、[Enable Camera] ボタンをクリックしてウェブカメラの読み込みを待機し、各データ収集ボタンをクリックして長押しすることで、各クラスのデータの例を収集できるようになります。ここでは、スマートフォンと手のデータをそれぞれ収集しています。

541051644a45131f.gif

上記のスクリーン キャプチャに示すように、すべてのテンソルがメモリに保存されると、ステータス テキストが更新されます。

14. トレーニングと予測

次のステップでは、現在空の trainAndPredict() 関数のコードを実装します。この関数で転移学習が行われます。コードを見てみましょう。

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

まず、predictfalse に設定して、現在の予測が実行されないようにします。

次に、tf.util.shuffleCombo() を使用して入力配列と出力配列をシャッフルし、順序がトレーニングで問題を引き起こさないようにします。

出力配列 trainingDataOutputs, を int32 型の tensor1d に変換して、ワンホット エンコードで使用できるようにします。これは outputsAsTensor という名前の変数に格納されます。

tf.oneHot() 関数で、この outputsAsTensor 変数とエンコードするクラスの最大数(CLASS_NAMES.length)を使用します。ワンホット エンコードされた出力が、oneHotOutputs という新しいテンソルに保存されます。

現在、trainingDataInputs は記録されたテンソルの配列です。これらをトレーニングに使用するには、テンソルの配列を通常の 2D テンソルに変換する必要があります。

TensorFlow.js ライブラリには、そのための優れた関数 tf.stack() があります。

これは、テンソルの配列を受け取り、それらをスタックして高次元のテンソルを出力として生成します。この場合、テンソル 2D が返されます。これは、記録された特徴を含む長さ 1024 の 1 次元入力のバッチです。これはトレーニングに必要なものです。

次に、await model.fit() を使用してカスタムモデル ヘッドをトレーニングします。ここでは、inputsAsTensor 変数と oneHotOutputs を渡して、それぞれ入力例とターゲット出力に使用するトレーニング データを表します。3 番目のパラメータの構成オブジェクトで、shuffletrue に設定し、5batchSize を使用して、epochs10 に設定し、onEpochEndcallback を、まもなく定義する logProgress 関数に指定します。

最後に、モデルがトレーニングされたので、作成したテンソルを破棄できます。その後、predicttrue に戻して予測を再び実行できるようにし、predictLoop() 関数を呼び出してライブ ウェブカメラ画像の予測を開始します。

logProcess() 関数を定義してトレーニングの状態をロギングすることもできます。この関数は上記の model.fit() で使用され、トレーニングの各ラウンドの後に結果をコンソールに出力します。

あと少しです。predictLoop() 関数を追加して予測を行います。

コア予測ループ

ここでは、ウェブカメラからフレームをサンプリングし、各フレームの内容を継続的に予測して、ブラウザにリアルタイムで結果を表示するメインの予測ループを実装します。

コードを確認してみましょう。

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

まず、predict が true であることを確認します。これにより、モデルがトレーニングされ、使用可能になった後にのみ予測が行われます。

次に、dataGatherLoop() 関数で行ったのと同じように、現在の画像の画像特徴を取得できます。基本的には、tf.browser.from pixels() を使用してウェブカメラからフレームを取得し、正規化して 224 x 224 ピクセルのサイズに変更してから、そのデータを MobileNet モデルに渡して、結果の画像特徴を取得します。

ただし、新しくトレーニングしたモデルヘッドを使用して、トレーニング済みモデルの predict() 関数を介して見つかった imageFeatures を渡すことで、実際に予測を実行できるようになりました。次に、結果のテンソルを圧縮して 1 次元に戻し、prediction という変数に割り当てます。

この prediction を使用すると、argMax() を使用して最大値を持つインデックスを見つけ、arraySync() を使用して結果のテンソルを配列に変換し、JavaScript で基になるデータにアクセスして、最大値を持つ要素の位置を見つけることができます。この値は highestIndex という変数に格納されます。

同様に、prediction テンソルで arraySync() を直接呼び出すことで、実際の予測信頼スコアを取得することもできます。

これで、prediction データを使用して STATUS テキストを更新するために必要なものがすべて揃いました。クラスの人間が読める文字列を取得するには、CLASS_NAMES 配列で highestIndex を検索し、predictionArray から信頼値を取得します。パーセンテージとして読みやすくするには、100 を掛けて結果を math.floor() します。

最後に、準備ができたら window.requestAnimationFrame() を使用して predictionLoop() を再度呼び出し、動画ストリームでリアルタイム分類を取得できます。新しいデータで新しいモデルをトレーニングする場合は、predictfalse に設定されるまでこの処理が続行されます。

これで、パズルの最後のピースが完成します。リセットボタンを実装します。

15. リセットボタンを実装する

あと少しで完了です。最後に、最初からやり直すためのリセットボタンを実装します。現在空の reset() 関数のコードは次のとおりです。次のように更新します。

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

まず、predictfalse に設定して、実行中の予測ループを停止します。次に、examplesCount 配列の長さを 0 に設定して、配列のすべてのコンテンツを削除します。これは、配列からすべてのコンテンツを削除する便利な方法です。

次に、現在記録されているすべての trainingDataInputs を確認し、その中に含まれる各テンソルの dispose() を確認して、メモリを再度解放します。テンソルは JavaScript のガベージ コレクタによってクリーンアップされないためです。

完了したら、trainingDataInputs 配列と trainingDataOutputs 配列の両方で配列の長さを 0 に設定して、それらもクリアできます。

最後に、STATUS テキストを適切な値に設定し、メモリに残っているテンソルをサニティチェックとして出力します。

MobileNet モデルと定義した多層パーセプトロンの両方が破棄されないため、メモリには数百個のテンソルが残ります。このリセット後に再度トレーニングを行う場合は、新しいトレーニング データで再利用する必要があります。

16. 試してみましょう

それでは、独自のバージョンの Teachable Machine を試してみましょう。

ライブ プレビューに移動し、ウェブカメラを有効にして、部屋にあるオブジェクトのクラス 1 のサンプルを 30 個以上収集します。次に、別のオブジェクトのクラス 2 のサンプルを収集し、[トレーニング] をクリックして、コンソールログで進行状況を確認します。トレーニングはかなり高速に行われます。

bf1ac3cc5b15740.gif

トレーニングが完了したら、オブジェクトをカメラに映してライブ予測を取得します。ライブ予測は、ウェブページの上部付近にあるステータス テキスト エリアに表示されます。問題が発生した場合は、完成した動作コードを確認して、コピーし忘れたものがないか確認してください。

17. 完了

おめでとうございます!TensorFlow.js を使用してブラウザで転移学習の最初の例を完了しました。

さまざまなオブジェクトで試してテストしてみると、特に他のものと似ている場合、認識しにくいものがあることに気づくかもしれません。これらを区別するには、クラスやトレーニング データを追加する必要がある場合があります。

内容のまとめ

この Codelab では、以下のことを学びました。

  1. 転移学習の概要と、完全なモデルのトレーニングと比較したメリット。
  2. TensorFlow Hub から再利用可能なモデルを取得する方法。
  3. 転移学習に適したウェブアプリを設定する方法。
  4. ベースモデルを読み込んで使用し、画像の特徴を生成する方法。
  5. ウェブカメラの画像からカスタム オブジェクトを認識できる新しい予測ヘッドをトレーニングする方法。
  6. 結果のモデルを使用してデータをリアルタイムで分類する方法。

次のステップ

これで、実際に使える基礎的な知識を得ることができました。クリエイティブなアイデアを思いついたときは、この機械学習モデルのデプロイの原型を拡張して、取り組んでいる実際のユースケースに役立ててください。たとえば、現在働いている業界に革命を起こし、会社の仲間が日々の業務で重要なものを分類するモデルをトレーニングできるようにする、といったことも考えられます。可能性は無限大です

さらに詳しく知りたい場合は、このコース全体を無料で受講することを検討してください。この Codelab で現在使用している 2 つのモデルを 1 つのモデルに統合して効率を高める方法を学ぶことができます。

元の Teachable Machine アプリケーションの理論について詳しく知りたい場合は、こちらのチュートリアルをご覧ください。

作成したものを共有する

本日作成したものを他のクリエイティブなユースケースに簡単に拡張することもできます。ぜひ、既成概念にとらわれずにハッキングを続けてください。

ソーシャル メディアで #MadeWithTFJS ハッシュタグを使用すると、作成したプロジェクトが TensorFlow ブログ今後のイベントで取り上げられる可能性があります。皆様の作品をぜひお見せください。

参考になるウェブサイト