TensorFlow.js - 転移学習を使用した音声認識

1. はじめに

この Codelab では、音声認識ネットワークを構築し、それを使用して音声を出すことでブラウザのスライダーを制御します。強力かつ柔軟な機械学習ライブラリである TensorFlow.js を使用します。

まず、20 個の音声コマンドを認識できる 事前トレーニング済みのモデル を読み込んで実行します。次に、マイクを使用して、音声を認識してスライダーを左右に動かすシンプルなニューラル ネットワークを構築してトレーニングします。

この Codelab では、音声認識モデルの理論については説明しません 。興味がある場合は、こちらのチュートリアルをご覧ください。

この Codelab で使用する機械学習用語の用語集も作成しました。

学習内容

  • 事前トレーニング済みの音声コマンド認識モデルを読み込む方法
  • マイクを使用してリアルタイムで予測を行う方法
  • ブラウザのマイクを使用してカスタム音声認識モデルをトレーニングして使用する方法

では始めましょう。

2. 要件

この Codelab を完了するには、次の準備が必要です。

  1. 最新バージョンの Chrome またはその他の最新のブラウザ。
  2. テキスト エディタ。お使いのマシンでローカルに実行するか、CodepenGlitch などを介してウェブ上で実行します。
  3. HTML、CSS、JavaScript、および Chrome DevTools(または、使い慣れたブラウザ DevTools)に関する知識。
  4. ニューラル ネットワークのコンセプトに関する深い理解。概要の説明や復習が必要な場合は、3blue1brown の動画や、Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。

3. TensorFlow.js と音声モデルを読み込む

エディタで index.html を開き、次のコンテンツを追加します。

<html>
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"></script>
  </head>
  <body>
    <div id="console"></div>
    <script src="index.js"></script>
  </body>
</html>

最初の <script> タグは TensorFlow.js ライブラリをインポートし、2 番目の <script> は事前トレーニング済みの Speech Commands モデルをインポートします。<div id="console"> タグは、モデルの出力を表示するために使用されます。

4. リアルタイムで予測する

次に、コードエディタでファイル index.js を開くか作成して、次のコードを追加します。

let recognizer;

function predictWord() {
 // Array of words that the recognizer is trained to recognize.
 const words = recognizer.wordLabels();
 recognizer.listen(({scores}) => {
   // Turn scores into a list of (score,word) pairs.
   scores = Array.from(scores).map((s, i) => ({score: s, word: words[i]}));
   // Find the most probable word.
   scores.sort((s1, s2) => s2.score - s1.score);
   document.querySelector('#console').textContent = scores[0].word;
 }, {probabilityThreshold: 0.75});
}

async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 predictWord();
}

app();

5. 予測をテストする

デバイスにマイクがあることを確認します。これはモバイルでも動作します。ウェブページを実行するには、ブラウザで index.html を開きます。ローカル ファイルから作業している場合は、マイクにアクセスするためにウェブサーバーを起動して http://localhost:port/ を使用する必要があります。

ポート 8000 でシンプルなウェブサーバーを起動するには:

python -m SimpleHTTPServer

モデルのダウンロードに時間がかかる場合がありますので、しばらくお待ちください。モデルが読み込まれると、ページの上部に単語が表示されます。このモデルは、0 ~ 9 の数字と、「left」、「right」、「yes」、「no」などの追加のコマンドを認識するようにトレーニングされています。

これらの単語のいずれかを話します。単語は正しく認識されますか?probabilityThreshold を調整して、モデルが発火する頻度を制御します。0.75 の場合、モデルは特定の単語を聞いたと 75% 以上の確信がある場合に発火します。

Speech Commands モデルとその API の詳細については、Github の README.md をご覧ください。

6. データを収集する

スライダーを制御するために、単語全体ではなく短い音を使用してみましょう。

スライダーを左右に動かす 3 つの異なるコマンド(「Left」、「Right」、「Noise」)を認識するようにモデルをトレーニングします。「Noise」(操作は不要)を認識することは、音声検出において非常に重要です。スライダーは、適切な音を出したときにのみ反応し、一般的に話したり動き回ったりしているときには反応しないようにする必要があります。

  1. まず、データを収集する必要があります。<body> タグ内に次のコードを追加して、シンプルな UI をアプリに追加します。<div id="console">
<button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
<button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
<button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
  1. 次のコードを index.js に追加します。
// One frame is ~23ms of audio.
const NUM_FRAMES = 3;
let examples = [];

function collect(label) {
 if (recognizer.isListening()) {
   return recognizer.stopListening();
 }
 if (label == null) {
   return;
 }
 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   examples.push({vals, label});
   document.querySelector('#console').textContent =
       `${examples.length} examples collected`;
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

function normalize(x) {
 const mean = -100;
 const std = 10;
 return x.map(x => (x - mean) / std);
}
  1. app() から predictWord() を削除します。
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // predictWord() no longer called.
}

詳細

このコードは最初は難解に思えるかもしれませんが、分解して見ていきましょう。

モデルに認識させたい 3 つのコマンドに対応する「Left」、「Right」、「Noise」というラベルの付いた 3 つのボタンを UI に追加しました。これらのボタンを押すと、新しく追加した collect() 関数が呼び出され、モデルのトレーニング サンプルが作成されます。

collect() は、labelrecognizer.listen() の出力に関連付けます。includeSpectrogram が true なので, recognizer.listen() は 1 秒間の音声の生のスペクトログラム(周波数データ)を 43 フレームに分割して返します。つまり、各フレームは ~23 ミリ秒の音声です。

recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
...
}, {includeSpectrogram: true});

スライダーを制御するために単語ではなく短い音を使用するため、最後の 3 フレーム(~70 ミリ秒)のみを考慮します。

let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));

数値の問題を回避するため、データの平均が 0、標準偏差が 1 になるようにデータを正規化します。この場合、スペクトログラムの値は通常、-100 前後の大きな負の数で、偏差は 10 です。

const mean = -100;
const std = 10;
return x.map(x => (x - mean) / std);

最後に、各トレーニング サンプルには次の 2 つのフィールドがあります。

  • label****: 「Left」、「Right」、「Noise」の場合はそれぞれ 0、1、2。
  • vals****: 周波数情報(スペクトログラム)を保持する 696 個の数値。

すべてのデータは examples 変数に保存されます。

examples.push({vals, label});

7. データの収集をテストする

ブラウザで index.html を開くと、3 つのコマンドに対応する 3 つのボタンが表示されます。ローカル ファイルから作業している場合は、マイクにアクセスするためにウェブサーバーを起動して http://localhost:port/ を使用する必要があります。

ポート 8000 でシンプルなウェブサーバーを起動するには:

python -m SimpleHTTPServer

各コマンドのサンプルを収集するには、各ボタンを 3 ~ 4 秒間長押し しながら、一定の音を繰り返し(または継続的に)出します。ラベルごとに ~150 個のサンプルを収集する必要があります。たとえば、「Left」の場合は指を鳴らし、「Right」の場合は口笛を吹き、「Noise」の場合は無音と会話を交互に行います。

サンプルを収集すると、ページに表示されるカウンタが増加します。コンソールで examples 変数に対して console.log() を呼び出して、データを検査することもできます。この時点での目標は、データの収集プロセスをテストすることです。後で、アプリ全体のテスト時にデータを再収集します。

8. モデルをトレーニングする

  1. index.html の本文で、[Noise] ボタンの直後に [Train] ボタンを追加します。
<br/><br/>
<button id="train" onclick="train()">Train</button>
  1. index.js の既存のコードに次のコードを追加します。
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

async function train() {
 toggleButtons(false);
 const ys = tf.oneHot(examples.map(e => e.label), 3);
 const xsShape = [examples.length, ...INPUT_SHAPE];
 const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

 await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });
 tf.dispose([xs, ys]);
 toggleButtons(true);
}

function buildModel() {
 model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
 const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });
}

function toggleButtons(enable) {
 document.querySelectorAll('button').forEach(b => b.disabled = !enable);
}

function flatten(tensors) {
 const size = tensors[0].length;
 const result = new Float32Array(tensors.length * size);
 tensors.forEach((arr, i) => result.set(arr, i * size));
 return result;
}
  1. アプリが読み込まれたときに buildModel() を呼び出します。
async function app() {
 recognizer = speechCommands.create('BROWSER_FFT');
 await recognizer.ensureModelLoaded();
 // Add this line.
 buildModel();
}

この時点でアプリを更新すると、新しい [Train] ボタンが表示されます。データを再収集して [Train] をクリックしてトレーニングをテストすることも、ステップ 10 まで待って予測とともにトレーニングをテストすることもできます。

詳細

大まかに言うと、buildModel() はモデル アーキテクチャを定義し、train() は収集したデータを使用してモデルをトレーニングします。

モデル アーキテクチャ

このモデルには 4 つのレイヤがあります。音声データ(スペクトログラムとして表される)を処理する畳み込みレイヤ、最大プーリング レイヤ、フラット化レイヤ、3 つのアクションにマッピングする密結合レイヤです。

model = tf.sequential();
 model.add(tf.layers.depthwiseConv2d({
   depthMultiplier: 8,
   kernelSize: [NUM_FRAMES, 3],
   activation: 'relu',
   inputShape: INPUT_SHAPE
 }));
 model.add(tf.layers.maxPooling2d({poolSize: [1, 2], strides: [2, 2]}));
 model.add(tf.layers.flatten());
 model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

モデルの入力形状は [NUM_FRAMES, 232, 1] です。各フレームは 23 ミリ秒の音声で、異なる周波数に対応する 232 個の数値が含まれています(232 は、人間の声をキャプチャするために必要な周波数バケットの量であるため選択されました)。この Codelab では、スライダーを制御するために単語全体ではなく音を出すため、3 フレームの長さのサンプル(~70 ミリ秒のサンプル)を使用します。

トレーニングの準備として、モデルをコンパイルします。

const optimizer = tf.train.adam(0.01);
 model.compile({
   optimizer,
   loss: 'categoricalCrossentropy',
   metrics: ['accuracy']
 });

ディープ ラーニングで一般的に使用されるオプティマイザーである Adam オプティマイザー を使用し、損失には分類に使用される標準の損失関数である categoricalCrossEntropy を使用します。つまり、予測確率(クラスごとに 1 つの確率)が、真のクラスで 100% の確率、他のすべてのクラスで 0% の確率からどれだけ離れているかを測定します。また、モニタリングする指標として accuracy も提供します。これにより、トレーニングのエポックごとにモデルが正しく取得したサンプルの割合がわかります。

トレーニング

トレーニングは、バッチサイズ 16(一度に 16 個のサンプルを処理)を使用して、データを 10 回(エポック)処理し、現在の精度を UI に表示します。

await model.fit(xs, ys, {
   batchSize: 16,
   epochs: 10,
   callbacks: {
     onEpochEnd: (epoch, logs) => {
       document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
     }
   }
 });

9. スライダーをリアルタイムで更新する

モデルをトレーニングできるようになったので、リアルタイムで予測を行ってスライダーを動かすコードを追加しましょう。index.html の [Train] ボタンの直後に次のコードを追加します。

<br/><br/>
<button id="listen" onclick="listen()">Listen</button>
<input type="range" id="output" min="0" max="10" step="0.1">

index.js に次のコードを追加します。

async function moveSlider(labelTensor) {
 const label = (await labelTensor.data())[0];
 document.getElementById('console').textContent = label;
 if (label == 2) {
   return;
 }
 let delta = 0.1;
 const prevValue = +document.getElementById('output').value;
 document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);
}

function listen() {
 if (recognizer.isListening()) {
   recognizer.stopListening();
   toggleButtons(true);
   document.getElementById('listen').textContent = 'Listen';
   return;
 }
 toggleButtons(false);
 document.getElementById('listen').textContent = 'Stop';
 document.getElementById('listen').disabled = false;

 recognizer.listen(async ({spectrogram: {frameSize, data}}) => {
   const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
   const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
   const probs = model.predict(input);
   const predLabel = probs.argMax(1);
   await moveSlider(predLabel);
   tf.dispose([input, probs, predLabel]);
 }, {
   overlapFactor: 0.999,
   includeSpectrogram: true,
   invokeCallbackOnNoiseAndUnknown: true
 });
}

詳細

リアルタイム予測

listen() はマイクをリッスンし、リアルタイムで予測を行います。このコードは collect() メソッドとよく似ています。生のスペクトログラムを正規化し、最後の NUM_FRAMES フレーム以外のすべてを削除します。唯一の違いは、トレーニング済みのモデルを呼び出して予測を取得することです。

const probs = model.predict(input);
const predLabel = probs.argMax(1);
await moveSlider(predLabel);

model.predict(input) の出力は、クラス数に対する確率分布を表す形状 [1, numClasses] のテンソルです。簡単に言うと、これは可能な出力クラスごとに信頼度のセットであり、合計は 1 になります。Tensor の外部ディメンションは 1 です。これはバッチサイズ(1 つのサンプル)です。

確率分布を、最も可能性の高いクラスを表す単一の整数に変換するには、probs.argMax(1) を呼び出します。これにより、確率が最も高いクラスのインデックスが返されます。軸パラメータとして「1」を渡します。これは、最後のディメンション numClassesargMax を計算するためです。

スライダーを更新する

moveSlider() は、ラベルが 0(「Left」)の場合はスライダーの値を減らし、ラベルが 1(「Right」)の場合は増やし、ラベルが 2(「Noise」)の場合は無視します。

Tensor を破棄する

GPU メモリをクリーンアップするには、出力 Tensor で tf.dispose() を手動で呼び出すことが重要です。手動の tf.dispose() の代わりに、関数呼び出しを tf.tidy() でラップすることもできますが、これは非同期関数では使用できません。

   tf.dispose([input, probs, predLabel]);

10. 最終的なアプリをテストする

ブラウザで index.html を開き、前のセクションと同じように、3 つのコマンドに対応する 3 つのボタンを使用してデータを収集します。データを収集するときは、各ボタンを 3 ~ 4 秒間長押し してください。

サンプルを収集したら、[Train] ボタンを押します。これによりモデルのトレーニングが開始され、モデルの精度が 90% を超えるはずです。モデルのパフォーマンスが向上しない場合は、より多くのデータを収集してみてください。

トレーニングが完了したら、[Listen] ボタンを押してマイクから予測を行い、スライダーを制御します。

その他のチュートリアルについては、http://js.tensorflow.org/ をご覧ください。