TensorFlow.js 転移学習による画像分類器

この Codelab では、シンプルな「Teachable Machine」を構築する方法を学びます。これは、強力かつ柔軟な機械学習ライブラリである TensorFlow.js を使用して、ブラウザで即座にトレーニングを行うカスタム画像分類器です。まず、ブラウザでの画像分類用に、MobileNet という一般的な事前トレーニング済みモデルを読み込んで実行します。次に、「転移学習」という手法を使用します。この手法では、事前トレーニング済みの MobileNet モデルでトレーニングをブートストラップし、アプリケーションに合わせてトレーニングを行うようにカスタマイズします。

この Codelab では、Teachable Machine アプリケーションの理論については説明しません。ご興味のある場合は、こちらのチュートリアルをご覧ください。

ラボの内容

  • 事前トレーニング済みの MobileNet モデルを読み込んで新しいデータに対する予測を行う方法
  • ウェブカメラを使用して予測を行う方法
  • MobileNet の中間アクティベーションを使用し、ウェブカメラで即座に定義した新しいクラスセットに対して転移学習を行う方法

では、始めましょう。

この Codelab を完了するには、次の準備が必要です。

  1. Chrome の最新バージョンか、その他の最新のブラウザ。
  2. テキスト エディタ。お使いのマシンでローカルに実行するか、CodepenGlitch などを介してウェブ上で実行します。
  3. HTML、CSS、JavaScript、Chrome DevTools(または、使い慣れたブラウザ DevTools)に関する知識。
  4. ニューラル ネットワークのコンセプトに関する深い理解。概要の説明や復習が必要な場合は、3blue1brown の動画や、Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。

エディタで index.html を開き、次のコンテンツを追加します。

<html>
  <head>
    <!-- Load the latest version of TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
  </head>
  <body>
    <div id="console"></div>
    <!-- Add an image that we will use to test -->
    <img id="img" crossorigin src="https://i.imgur.com/JlUvsxa.jpg" width="227" height="227"/>
    <!-- Load index.js after the content of the page -->
    <script src="index.js"></script>
  </body>
</html>

次に、コードエディタで index.js ファイルを開いて作成し、次のコードを含めます。

let net;

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Make a prediction through the model on our image.
  const imgEl = document.getElementById('img');
  const result = await net.classify(imgEl);
  console.log(result);
}

app();

ウェブページを実行するには、ウェブブラウザで index.html を開きます。Cloud Console を使用している場合は、プレビュー ページを更新します。

犬の写真と、デベロッパー ツールの JavaScript コンソールに MobileNet の上位予測が表示されるはずです。モデルのダウンロードには少し時間がかかる場合があります。しばらくお待ちください。

画像は正しく分類されましたか?

また、この機能はスマートフォンでも動作します。

それでは、この機能をさらにインタラクティブかつリアルタイムにしていきましょう。ウェブカメラ経由の画像に対して予測を実行するようウェブカメラを設定します。

まず、ウェブカメラの動画要素を設定します。index.html ファイルを開き、<body> セクション内に次の行を追加して、犬の画像を読み込むために使用していた <img> タグを削除します。

<video autoplay playsinline muted id="webcam" width="224" height="224"></video>

index.js ファイルを開き、webcamElement をファイルの先頭に追加します。

const webcamElement = document.getElementById('webcam');

これで、前に追加した app() 関数で画像を介した予測を削除し、代わりにウェブカメラ要素を介して予測を行う無限ループを作成できます。

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);
  while (true) {
    const img = await webcam.capture();
    const result = await net.classify(img);

    document.getElementById('console').innerText = `
      prediction: ${result[0].className}\n
      probability: ${result[0].probability}
    `;
    // Dispose the tensor to release the memory.
    img.dispose();

    // Give some breathing room by waiting for the next animation frame to
    // fire.
    await tf.nextFrame();
  }
}

ウェブページでコンソールを開くと、ウェブカメラで収集された各フレームの確率を示す MobileNet の予測が表示されます。

ImageNet データセットは、通常ウェブカメラに表示される画像とはあまり似ていないため、意味をなさない場合があります。これをテストする方法の 1 つとして、ノートパソコンのカメラの前にスマートフォンの犬の写真をかざす方法があります。

では、これをさらに活用しましょう。ウェブカメラを使用し、その場で 3 つのクラス オブジェクトのカスタム分類器を作成します。これから MobileNet を使用して分類を行いますが、今回は特定のウェブカメラ画像に対するモデルの内部表現(アクティベーション)を取得し、それを分類に使用します。

「K 近傍法(KNN)分類器」と呼ばれるモジュールを使用します。このモジュールでは、ウェブカメラの画像(実際には MobileNet のアクティベーション)をさまざまなカテゴリ(つまり「クラス」)に効果的に分類でき、ユーザーが予測を行うよう求めた場合、予測対象のアクティベーションに最も類似したアクティベーションを持つクラスを選択するだけですみます。

index.html の <head> タグで、インポートの末尾に KNN 分類器のインポートを追加します(MobileNet は引き続き必要なため、そのインポートを削除しないでください)。

...
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
...

動画要素の下にある index.html の各ボタンに、3 つのボタンを追加します。これらのボタンを使用して、トレーニング画像をモデルに追加します。

...
<button id="class-a">Add A</button>
<button id="class-b">Add B</button>
<button id="class-c">Add C</button>
...

index.js の先頭で分類器を作成します。

const classifier = knnClassifier.create();

アプリの関数を更新します。

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);

  // Reads an image from the webcam and associates it with a specific class
  // index.
  const addExample = async classId => {
    // Capture an image from the web camera.
    const img = await webcam.capture();

    // Get the intermediate activation of MobileNet 'conv_preds' and pass that
    // to the KNN classifier.
    const activation = net.infer(img, true);

    // Pass the intermediate activation to the classifier.
    classifier.addExample(activation, classId);

    // Dispose the tensor to release the memory.
    img.dispose();
  };

  // When clicking a button, add an example for that class.
  document.getElementById('class-a').addEventListener('click', () => addExample(0));
  document.getElementById('class-b').addEventListener('click', () => addExample(1));
  document.getElementById('class-c').addEventListener('click', () => addExample(2));

  while (true) {
    if (classifier.getNumClasses() > 0) {
      const img = await webcam.capture();

      // Get the activation from mobilenet from the webcam.
      const activation = net.infer(img, 'conv_preds');
      // Get the most likely class and confidence from the classifier module.
      const result = await classifier.predictClass(activation);

      const classes = ['A', 'B', 'C'];
      document.getElementById('console').innerText = `
        prediction: ${classes[result.label]}\n
        probability: ${result.confidences[result.label]}
      `;

      // Dispose the tensor to release the memory.
      img.dispose();
    }

    await tf.nextFrame();
  }
}

これで、index.html のページを読み込むと、一般的なオブジェクトや顔 / 体のジェスチャーを使用して、3 つのクラスそれぞれの画像をキャプチャできます。「追加」ボタンをクリックするたびに、トレーニング サンプルとして 1 つの画像がそのクラスに追加されます。その間も、モデルはウェブカメラの画像に対する予測を続行し、結果をリアルタイムで表示します。

アクションなしを表す別のクラスを追加してみてください。

この Codelab では、TensorFlow.js を使用して簡単な機械学習ウェブ アプリケーションを実装しました。ウェブカメラの画像を分類するため、事前トレーニング済みの MobileNet モデルを読み込んで使用しました。次に、モデルをカスタマイズして、画像を 3 つのカスタム カテゴリに分類しました。

TensorFlow.js のコードサンプルやデモの詳細については、js.tensorflow.org をご覧ください。アプリケーションでの TensorFlow.js の使用方法をご確認いただけます。