Node.js での TensorFlow.js トレーニングの Codelab

1. はじめに

この Codelab では、JavaScript 用の強力で柔軟な ML ライブラリである TensorFlow.js を使用して、サーバーサイドで野球の投球タイプのトレーニングと分類を行う Node.js ウェブサーバーを構築する方法を学びます。ウェブ アプリケーションを構築して、ピッチセンサーのデータからピッチのタイプを予測するモデルをトレーニングし、ウェブ クライアントから予測を呼び出すようにします。この Codelab の完全に機能するバージョンは、tfjs-examples GitHub リポジトリにあります。

学習内容

  • Node.js で使用する tensorflow.js npm パッケージをインストールして設定する方法。
  • Node.js 環境でトレーニング データとテストデータにアクセスする方法
  • Node.js サーバーで TensorFlow.js を使用してモデルをトレーニングする方法。
  • クライアント/サーバー アプリケーションで推論用のトレーニング済みモデルをデプロイする方法。

では始めましょう。

2. 要件

この Codelab を完了するには、以下が必要です。

  1. Chrome の最新バージョンまたは他の最新のブラウザ。
  2. マシンのローカルで実行されるテキスト エディタとコマンド ターミナル。
  3. HTML、CSS、JavaScript、Chrome DevTools(または推奨ブラウザ開発ツール)に関する知識。
  4. ニューラル ネットワークの概要レベルの概念。概要の説明や復習が必要な場合は、3blue1brown によるこちらの動画または Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。

3. Node.js アプリを設定する

Node.js と npm をインストールします。サポートされているプラットフォームと依存関係については、tfjs-node インストール ガイドをご覧ください。

Node.js アプリ用に ./baseball という名前のディレクトリを作成します。リンクされている package.jsonwebpack.config.js をこのディレクトリにコピーして、npm パッケージの依存関係(@tensorflow/tfjs-node npm パッケージを含む)を構成します。次に、npm install を実行して依存関係をインストールします。

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

これで、コードを記述してモデルをトレーニングする準備が整いました。

4. トレーニング データとテストデータを設定する

以下のリンクから、トレーニング データとテストデータを CSV ファイルとして使用します。次のファイルのデータをダウンロードして探索します。

pitch_type_training_data.csv

pitch_type_test_data.csv

サンプルのトレーニング データを見てみましょう。

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

ピッチセンサーのデータを説明する 8 つの入力特徴があります。

  • ボール速度(vx0、vy0、vz0)
  • ボールの加速度(ax、ay、az)
  • ピッチの開始速度
  • 投手が左利きかどうか

1 つの出力ラベル:

  • 7 つの提案タイプのいずれかを表す pitch_code: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

目標は、指定されたピッチセンサーデータからピッチタイプを予測できるモデルを構築することです。

モデルを作成する前に、トレーニング データとテストデータを準備する必要があります。baseball/ dir に pick_type.js ファイルを作成し、次のコードをコピーします。このコードは、tf.data.csv API を使用してトレーニング データとテストデータを読み込みます。また、最小 / 最大正規化スケールを使用してデータを正規化します(これは常に推奨です)。

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. 提案の種類を分類するモデルを作成する

これで、モデルを構築する準備が整いました。tf.layers API を使用して、入力([8] のピッチ センサー値の形状)を ReLU 活性化ユニットからなる 3 つの隠れ全接続層と、それぞれが出力ピッチタイプの 1 つを表す 7 つのユニットからなる 1 つのソフトマックス出力層に接続します。

adam オプティマイザーと sparseCategoricalCrossentropy 損失関数を使用してモデルをトレーニングする。これらの選択肢の詳細については、トレーニング モデルのトレーニング ガイドをご覧ください。

pick_type.js の末尾に次のコードを追加します。

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

後で作成するメインサーバーのコードからトレーニングをトリガーします。

pick_type.js モジュールを完成させるために、検証用データセットとテスト用データセットを評価し、1 つのサンプルのピッチタイプを予測して、精度の指標を計算する関数を記述しましょう。以下のコードを pick_type.js の末尾に追加します。

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. サーバーでモデルをトレーニングする

server.js という新しいファイルに、モデルのトレーニングと評価を実行するサーバーコードを記述します。まず、HTTP サーバーを作成し、socket.io API を使用して双方向のソケット接続を開きます。次に、model.fitDataset API を使用してモデルのトレーニングを実行し、先ほど作成した pitch_type.evaluate() メソッドを使用してモデルの精度を評価します。10 回のイテレーションでトレーニングと評価を行い、指標をコンソールに出力する。

以下のコードを server.js にコピーします。

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

これで、サーバーを実行してテストする準備が整いました。次に示すように、サーバーが反復処理ごとに 1 エポックをトレーニングすることがわかります(model.fitDataset API を使用して、1 回の呼び出しで複数のエポックをトレーニングすることもできます)。この時点でエラーが発生した場合は、ノードと npm のインストールを確認してください。

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ctrl+C キーを押して、実行中のサーバーを停止します。次のステップで再度実行します。

7. クライアント ページと表示コードを作成する

サーバーの準備が完了したら、次はクライアント コードを記述します。コードはブラウザで実行されます。サーバーでモデル予測を呼び出して結果を表示するシンプルなページを作成します。クライアント/サーバー通信に socket.io を使用します。

まず、baseball/ フォルダに index.html を作成します。

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

次に、baseball/ フォルダに、以下のコードを使用して client.js という新しいファイルを作成します。

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

クライアントが trainingComplete ソケット メッセージを処理して予測ボタンを表示します。このボタンをクリックすると、クライアントはサンプル センサーデータを含むソケット メッセージを送信します。predictResult メッセージを受信すると、ページに検索候補が表示されます。

8. アプリを実行する

サーバーとクライアントの両方を実行して、アプリの動作全体を確認します。

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

ブラウザでクライアント ページ(http://localhost:8080)を開きます。モデルのトレーニングが完了したら、[予測サンプル] ボタンをクリックします。ブラウザに予測結果が表示されます。サンプルのセンサーデータを自由に変更して、テスト用 CSV ファイルのサンプルをいくつか加え、モデルがどの程度正確に予測するかを確認してください。

9. 学習した内容

この Codelab では、TensorFlow.js を使用して簡単な ML ウェブ アプリケーションを実装しました。ここでは、センサーデータから野球の投球タイプを分類するカスタムモデルをトレーニングしました。サーバー上でトレーニングを実行し、クライアントから送信されたデータを使用してトレーニング済みモデルの推論を呼び出す Node.js コードを記述しました。

tensorflow.org/js にアクセスして、コードで TensorFlow.js を使用する方法を示すその他の例やデモをご覧ください。