TensorFlow.js - 転移学習を使用した音声認識

1. はじめに

この Codelab では、音声認識ネットワークを構築し、それを使用して音を立てることでブラウザのスライダーを制御します。この Codelab では、強力かつ柔軟な機械学習ライブラリである TensorFlow.js を使用します。

まず、20 個の音声コマンドを認識できる事前トレーニング済みモデルを読み込んで実行します。次に、マイクを使用して、音を認識してスライダーを左右に動かすシンプルなニューラル ネットワークを構築してトレーニングします。

この Codelab では、音声認識モデルの理論については説明しません。ご興味のある場合は、こちらのチュートリアルをご覧ください。

また、この Codelab で使用されている ML 用語の用語集も作成しました。

学習内容

  • 事前トレーニング済みの音声コマンド認識モデルを読み込む方法
  • マイクを使用してリアルタイム予測を行う方法
  • ブラウザのマイクを使用してカスタム音声認識モデルをトレーニングして使用する方法

では始めましょう。

2. 要件

この Codelab を完了するには、次の準備が必要です。

  1. Chrome の最新バージョンか、その他の最新のブラウザ。
  2. テキスト エディタ。お使いのマシンでローカルに実行するか、CodepenGlitch などを介してウェブ上で実行します。
  3. HTML、CSS、JavaScript、Chrome DevTools(または、使い慣れたブラウザ DevTools)に関する知識。
  4. ニューラル ネットワークのコンセプトに関する深い理解。概要の説明や復習が必要な場合は、3blue1brown の動画や、Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。

3. TensorFlow.js と音声モデルを読み込む

エディタで index.html を開き、次のコンテンツを追加します。

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

最初の <script> タグは TensorFlow.js ライブラリをインポートし、2 番目の <script> は事前トレーニング済みの Speech Commands モデルをインポートします。<div id="console"> タグは、モデルの出力を表示するために使用されます。

4. リアルタイムで予測する

次に、コードエディタで index.js ファイルを開いて作成し、次のコードを含めます。

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. 予測をテストする

デバイスにマイクが搭載されていることを確認します。また、この機能はスマートフォンでも動作します。ウェブページを実行するには、ブラウザで index.html を開きます。ローカル ファイルから作業している場合は、マイクにアクセスするためにウェブサーバーを起動して http://localhost:port/ を使用する必要があります。

ポート 8000 でシンプルなウェブサーバーを起動するには:

python -m SimpleHTTPServer

モデルのダウンロードには時間がかかることがあります。モデルが読み込まれると、ページの上部に単語が表示されます。このモデルは、0 ~ 9 の数字と、「left」、「right」、「yes」、「no」などのいくつかの追加コマンドを認識するようにトレーニングされています。

いずれかの単語を発音します。単語は正しく認識されますか?probabilityThreshold を調整して、モデルが発火する頻度を制御します。0.75 の場合、モデルは特定の単語が聞こえたと 75% 以上の確信がある場合に発火します。

音声コマンド モデルとその API の詳細については、Github の README.md をご覧ください。

6. データを収集する

スライダーを操作する際に、単語全体ではなく短い音を使用することで、楽しく操作できます。

ここでは、スライダーを左右に移動させる「左」、「右」、「ノイズ」の 3 つのコマンドを認識するモデルをトレーニングします。「ノイズ」を認識する(アクションは不要)ことは、音声検出において非常に重要です。スライダーは、正しい音を発したときのみ反応し、一般的に話したり動き回ったりしているときには反応しないようにする必要があるためです。

  1. まず、データを収集する必要があります。<div id="console"> の前の <body> タグ内に次のコードを追加して、アプリにシンプルな UI を追加します。
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. index.js に以下を追加します。
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. predictWord()app() から削除します。
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

詳細

このコードは最初はわかりにくい可能性があるため、詳しく説明します。

UI に「Left」、「Right」、「Noise」という 3 つのボタンを追加しました。これらは、モデルに認識させたい 3 つのコマンドに対応しています。これらのボタンを押すと、新たに追加された collect() 関数が呼び出され、モデルのトレーニング例が作成されます。

collect() は、labelrecognizer.listen() の出力に関連付けます。includeSpectrogram が true, のため、recognizer.listen() は 1 秒の音声の生スペクトログラム(周波数データ)を 43 フレームに分割して返します。つまり、各フレームは ~23 ミリ秒の音声です。

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

スライダーを制御するために単語ではなく短い音を使用したいので、最後の 3 フレーム(約 70 ミリ秒)のみを考慮します。

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

数値に関する問題を回避するため、データの平均が 0、標準偏差が 1 になるように正規化します。この場合、スペクトログラムの値は通常、-100 前後の大きな負の数で、偏差は 10 です。

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

最後に、各トレーニング例には 2 つのフィールドがあります。

  • label****: 「左」、「右」、「ノイズ」にそれぞれ 0、1、2。
  • vals****: 周波数情報(スペクトログラム)を保持する 696 個の数値

すべてのデータを examples 変数に保存します。

examples.push({vals, label});

7. テストデータの収集

ブラウザで index.html を開くと、3 つのコマンドに対応する 3 つのボタンが表示されます。ローカル ファイルから作業している場合は、ウェブサーバーを起動して http://localhost:port/ を使用してマイクにアクセスする必要があります。

ポート 8000 でシンプルなウェブサーバーを起動するには:

python -m SimpleHTTPServer

各コマンドの例を収集するには、各ボタンを 3 ~ 4 秒間長押ししながら、一貫した音を繰り返し(または継続的に)出します。各ラベルについて約 150 個の例を収集する必要があります。たとえば、「左」は指を鳴らし、「右」は口笛を吹き、「ノイズ」は無音と発言を交互に行います。

例を収集するにつれて、ページに表示されるカウンタが増加します。コンソールで examples 変数に対して console.log() を呼び出して、データを検査することもできます。この時点での目標は、データ収集プロセスをテストすることです。後でアプリ全体をテストするときに、データを再収集します。

8. モデルをトレーニングする

  1. index.html: の本文の [ノイズ] ボタンの直後に [電車] ボタンを追加します。
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. index.js の既存のコードに以下を追加します。
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. アプリの読み込み時に buildModel() を呼び出します。
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

この時点でアプリを更新すると、新しい [トレーニング] ボタンが表示されます。トレーニングをテストするには、データを再収集して [トレーニング] をクリックするか、ステップ 10 まで待って予測とともにトレーニングをテストします。

詳細

大まかに言うと、buildModel() はモデル アーキテクチャを定義し、train() は収集されたデータを使用してモデルをトレーニングします。

モデル アーキテクチャ

このモデルには 4 つのレイヤがあります。音声データ(スペクトログラムとして表される)を処理する畳み込みレイヤ、最大プーリング レイヤ、フラット化レイヤ、3 つのアクションにマッピングする密レイヤです。

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

モデルの入力形状は [NUM_FRAMES, 232, 1] です。各フレームは 23 ms の音声で、さまざまな周波数に対応する 232 個の数値が含まれています(232 は、人間の声をキャプチャするために必要な周波数バケットの数であるため、選択されました)。この Codelab では、スライダーを操作するために単語全体を発音するのではなく、音を出すため、3 フレームの長さのサンプル(約 70 ミリ秒のサンプル)を使用しています。

モデルをコンパイルして、トレーニングの準備をします。

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

ディープ ラーニングでよく使用されるオプティマイザーである Adam オプティマイザーと、分類に使用される標準的な損失関数である categoricalCrossEntropy を損失に使用します。簡単に言うと、予測された確率(クラスごとに 1 つの確率)が、真のクラスで 100% の確率、他のすべてのクラスで 0% の確率になるまでの距離を測定します。また、モニタリングする指標として accuracy も提供します。これにより、トレーニングの各エポック後にモデルが正しく分類した例の割合がわかります。

トレーニング

トレーニングは、バッチサイズ 16(一度に 16 個のサンプルを処理)を使用してデータを 10 回(エポック)実行し、現在の精度を UI に表示します。

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. スライダーをリアルタイムで更新する

モデルをトレーニングできるようになったので、リアルタイムで予測を行い、スライダーを移動するコードを追加しましょう。index.html の [Train] ボタンの直後に次のコードを追加します。

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

index.js には次のコードを記述します。

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

詳細

リアルタイム予測

listen() はマイクをリッスンし、リアルタイムの予測を行います。このコードは、未加工のスペクトログラムを正規化し、最後の NUM_FRAMES フレーム以外のすべてのフレームをドロップする collect() メソッドとよく似ています。唯一の違いは、トレーニング済みモデルを呼び出して予測を取得していることです。

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

model.predict(input) の出力は、クラス数の確率分布を表す形状 [1, numClasses] の Tensor です。簡単に言うと、これは可能な出力クラスごとの信頼度のセットであり、合計すると 1 になります。テンソルの外側のディメンションは 1 です。これはバッチのサイズ(単一の例)であるためです。

確率分布を最も可能性の高いクラスを表す単一の整数に変換するには、probs.argMax(1) を呼び出します。これにより、最も確率の高いクラスのインデックスが返されます。最後のディメンション numClassesargMax を計算するため、軸パラメータとして「1」を渡します。

スライダーを更新する

moveSlider() は、ラベルが 0(「左」)の場合はスライダーの値を減らし、ラベルが 1(「右」)の場合は増やし、ラベルが 2(「ノイズ」)の場合は無視します。

テンソルの破棄

GPU メモリをクリーンアップするには、出力 Tensor で tf.dispose() を手動で呼び出すことが重要です。手動の tf.dispose() の代替手段は、関数呼び出しを tf.tidy() でラップすることですが、これは非同期関数では使用できません。

   tf.dispose([input, probs, predLabel]);

10. 最終的なアプリをテストする

ブラウザで index.html を開き、前のセクションと同じように、3 つのコマンドに対応する 3 つのボタンを使用してデータを収集します。データを収集する際は、各ボタンを 3 ~ 4 秒間長押ししてください。

例を収集したら、[トレーニング] ボタンを押します。これにより、モデルのトレーニングが開始され、モデルの精度が 90% を超えることがわかります。モデルのパフォーマンスが向上しない場合は、より多くのデータを収集してみてください。

トレーニングが完了したら、[聞く] ボタンを押して、マイクから予測を行い、スライダーを操作します。

その他のチュートリアルについては、http://js.tensorflow.org/ をご覧ください。