Node.js での TensorFlow.js トレーニングの Codelab

1. はじめに

この Codelab では、強力かつ柔軟な JavaScript 用機械学習ライブラリである TensorFlow.js を使用して、サーバーサイドで野球の投球の種類をトレーニングして分類する Node.js ウェブサーバーを構築する方法を学びます。ウェブ アプリケーションを構築して、投球センサーのデータから投球の種類を予測するモデルをトレーニングし、ウェブ クライアントから予測を呼び出します。この Codelab の完全な動作バージョンは、tfjs-examples GitHub リポジトリにあります。

学習内容

  • Node.js で使用する tensorflow.js npm パッケージをインストールして設定する方法。
  • Node.js 環境でトレーニング データとテストデータにアクセスする方法。
  • Node.js サーバーで TensorFlow.js を使用してモデルをトレーニングする方法。
  • クライアント/サーバー アプリケーションで推論用にトレーニング済みモデルをデプロイする方法。

では始めましょう。

2. 要件

この Codelab を完了するには、次のものが必要です。

  1. Chrome の最新バージョンか、その他の最新のブラウザ。
  2. お使いのマシンでローカルに実行されるテキスト エディタとコマンド ターミナル。
  3. HTML、CSS、JavaScript、Chrome DevTools(または、使い慣れたブラウザ DevTools)に関する知識。
  4. ニューラル ネットワークのコンセプトに関する深い理解。概要の説明や復習が必要な場合は、3blue1brown の動画や、Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。

3. Node.js アプリを設定する

Node.js と npm をインストールします。サポートされているプラットフォームと依存関係については、tfjs-node インストール ガイドをご覧ください。

Node.js アプリ用に ./baseball というディレクトリを作成します。リンクされた package.jsonwebpack.config.js をこのディレクトリにコピーして、npm パッケージの依存関係(@tensorflow/tfjs-node npm パッケージを含む)を構成します。次に、npm install を実行して依存関係をインストールします。

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

これで、コードを記述してモデルをトレーニングする準備が整いました。

4. トレーニング データとテストデータを設定する

トレーニング データとテストデータは、次のリンクから CSV ファイルとして使用します。これらのファイルでデータをダウンロードして探索します。

pitch_type_training_data.csv

pitch_type_test_data.csv

トレーニング データの例を見てみましょう。

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

ピッチ センサーデータを記述する 8 つの入力特徴があります。

  • ボールの速度(vx0、vy0、vz0)
  • ボールの加速度(ax、ay、az)
  • ピッチの開始速度
  • ピッチャーが左利きかどうか

出力ラベルは 1 つです。

  • 7 種類のピッチタイプのいずれかを示す pitch_code: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

目標は、投球センサーのデータから投球の種類を予測できるモデルを構築することです。

モデルを作成する前に、トレーニング データとテストデータを準備する必要があります。baseball/ ディレクトリに pitch_type.js ファイルを作成し、次のコードをコピーします。このコードは、tf.data.csv API を使用してトレーニング データとテストデータを読み込みます。また、Min-Max 正規化スケールを使用してデータを正規化します(これは常に推奨されます)。

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. 球種を分類するモデルを作成する

これで、モデルをビルドする準備が整いました。tf.layers API を使用して、入力([8] ピッチ センサー値の形状)を ReLU アクティベーション ユニットで構成される 3 つの隠し全結合レイヤに接続し、その後に 7 つのユニットで構成される softmax 出力レイヤを接続します。各ユニットは出力ピッチタイプの 1 つを表します。

adam オプティマイザーと sparseCategoricalCrossentropy 損失関数を使用してモデルをトレーニングします。これらの選択肢の詳細については、モデルのトレーニング ガイドをご覧ください。

pitch_type.js の末尾に次のコードを追加します。

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

後で作成するメイン サーバーコードからトレーニングをトリガーします。

pitch_type.js モジュールを完成させるために、検証データセットとテスト データセットを評価し、単一のサンプルに対して球種を予測し、精度指標を計算する関数を作成しましょう。このコードを pitch_type.js の末尾に追加します。

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. サーバーでモデルをトレーニングする

server.js という新しいファイルに、モデルのトレーニングと評価を行うサーバーコードを記述します。まず、HTTP サーバーを作成し、socket.io API を使用して双方向ソケット接続を開きます。次に、model.fitDataset API を使用してモデルのトレーニングを実行し、前に作成した pitch_type.evaluate() メソッドを使用してモデルの精度を評価します。10 回の反復でトレーニングと評価を行い、指標をコンソールに出力します。

次のコードを server.js にコピーします。

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

これで、サーバーを実行してテストする準備が整いました。次のような出力が表示されます。各イテレーションでサーバーが 1 エポックをトレーニングします(model.fitDataset API を使用して、1 回の呼び出しで複数のエポックをトレーニングすることもできます)。この時点でエラーが発生した場合は、Node と npm のインストールを確認してください。

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ctrl+C キーを押して、実行中のサーバーを停止します。次のステップで再度実行します。

7. クライアント ページを作成してコードを表示する

サーバーの準備ができたので、次はブラウザで実行されるクライアント コードを記述します。サーバーでモデル予測を呼び出して結果を表示する簡単なページを作成します。クライアント/サーバー間の通信には socket.io が使用されます。

まず、baseball/ フォルダに index.html を作成します。

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

次に、baseball/ フォルダに次のコードを含む client.js という新しいファイルを作成します。

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

クライアントは trainingComplete ソケット メッセージを処理して、予測ボタンを表示します。このボタンをクリックすると、クライアントはサンプル センサーデータを含むソケット メッセージを送信します。predictResult メッセージを受信すると、ページに予測が表示されます。

8. アプリを実行する

サーバーとクライアントの両方を実行して、アプリの動作を確認します。

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

ブラウザでクライアント ページ(http://localhost:8080)を開きます。モデルのトレーニングが終了したら、[サンプルを予測] ボタンをクリックします。ブラウザに予測結果が表示されます。テスト CSV ファイルの例を使用してサンプル センサーデータを変更し、モデルの予測精度を確認してください。

9. 学習した内容

この Codelab では、TensorFlow.js を使用して簡単な機械学習ウェブ アプリケーションを実装しました。センサーデータから野球の投球の種類を分類するカスタムモデルをトレーニングしました。Node.js コードを記述して、サーバーでトレーニングを実行し、クライアントから送信されたデータを使用してトレーニング済みモデルで推論を呼び出しました。

TensorFlow.js のコードサンプルやデモの詳細については、tensorflow.org/js をご覧ください。アプリケーションでの TensorFlow.js の使用方法をご確認いただけます。