1. はじめに
この Codelab では、JavaScript 用の強力で柔軟な ML ライブラリである TensorFlow.js を使用して、サーバーサイドで野球の投球タイプのトレーニングと分類を行う Node.js ウェブサーバーを構築する方法を学びます。ウェブ アプリケーションを構築して、ピッチセンサーのデータからピッチのタイプを予測するモデルをトレーニングし、ウェブ クライアントから予測を呼び出すようにします。この Codelab の完全に機能するバージョンは、tfjs-examples GitHub リポジトリにあります。
学習内容
- Node.js で使用する tensorflow.js npm パッケージをインストールして設定する方法。
- Node.js 環境でトレーニング データとテストデータにアクセスする方法
- Node.js サーバーで TensorFlow.js を使用してモデルをトレーニングする方法。
- クライアント/サーバー アプリケーションで推論用のトレーニング済みモデルをデプロイする方法。
では始めましょう。
2. 要件
この Codelab を完了するには、以下が必要です。
- Chrome の最新バージョンまたは他の最新のブラウザ。
- マシンのローカルで実行されるテキスト エディタとコマンド ターミナル。
- HTML、CSS、JavaScript、Chrome DevTools(または推奨ブラウザ開発ツール)に関する知識。
- ニューラル ネットワークの概要レベルの概念。概要の説明や復習が必要な場合は、3blue1brown によるこちらの動画または Ashi Krishnan による JavaScript のディープ ラーニングに関する動画をご覧ください。
3. Node.js アプリを設定する
Node.js と npm をインストールします。サポートされているプラットフォームと依存関係については、tfjs-node インストール ガイドをご覧ください。
Node.js アプリ用に ./baseball という名前のディレクトリを作成します。リンクされている package.json と webpack.config.js をこのディレクトリにコピーして、npm パッケージの依存関係(@tensorflow/tfjs-node npm パッケージを含む)を構成します。次に、npm install を実行して依存関係をインストールします。
$ cd baseball
$ ls
package.json webpack.config.js
$ npm install
...
$ ls
node_modules package.json package-lock.json webpack.config.js
これで、コードを記述してモデルをトレーニングする準備が整いました。
4. トレーニング データとテストデータを設定する
以下のリンクから、トレーニング データとテストデータを CSV ファイルとして使用します。次のファイルのデータをダウンロードして探索します。
サンプルのトレーニング データを見てみましょう。
vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0
ピッチセンサーのデータを説明する 8 つの入力特徴があります。
- ボール速度(vx0、vy0、vz0)
- ボールの加速度(ax、ay、az)
- ピッチの開始速度
- 投手が左利きかどうか
1 つの出力ラベル:
- 7 つの提案タイプのいずれかを表す pitch_code:
Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball
目標は、指定されたピッチセンサーデータからピッチタイプを予測できるモデルを構築することです。
モデルを作成する前に、トレーニング データとテストデータを準備する必要があります。baseball/ dir に pick_type.js ファイルを作成し、次のコードをコピーします。このコードは、tf.data.csv API を使用してトレーニング データとテストデータを読み込みます。また、最小 / 最大正規化スケールを使用してデータを正規化します(これは常に推奨です)。
const tf = require('@tensorflow/tfjs');
// util function to normalize a value between a given range.
function normalize(value, min, max) {
if (min === undefined || max === undefined) {
return value;
}
return (value - min) / (max - min);
}
// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH = 'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';
// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;
const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;
// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
({xs, ys}) => {
const values = [
normalize(xs.vx0, VX0_MIN, VX0_MAX),
normalize(xs.vy0, VY0_MIN, VY0_MAX),
normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
xs.left_handed_pitcher
];
return {xs: values, ys: ys.pitch_code};
}
const trainingData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.shuffle(TRAINING_DATA_LENGTH)
.batch(100);
// Load all training data in one batch to use for evaluation
const trainingValidationData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TRAINING_DATA_LENGTH);
// Load all test data in one batch to use for evaluation
const testValidationData =
tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TEST_DATA_LENGTH);
5. 提案の種類を分類するモデルを作成する
これで、モデルを構築する準備が整いました。tf.layers API を使用して、入力([8] のピッチ センサー値の形状)を ReLU 活性化ユニットからなる 3 つの隠れ全接続層と、それぞれが出力ピッチタイプの 1 つを表す 7 つのユニットからなる 1 つのソフトマックス出力層に接続します。
adam オプティマイザーと sparseCategoricalCrossentropy 損失関数を使用してモデルをトレーニングする。これらの選択肢の詳細については、トレーニング モデルのトレーニング ガイドをご覧ください。
pick_type.js の末尾に次のコードを追加します。
const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(),
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy']
});
後で作成するメインサーバーのコードからトレーニングをトリガーします。
pick_type.js モジュールを完成させるために、検証用データセットとテスト用データセットを評価し、1 つのサンプルのピッチタイプを予測して、精度の指標を計算する関数を記述しましょう。以下のコードを pick_type.js の末尾に追加します。
// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
let results = {};
await trainingValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)] = {
training: calcPitchClassEval(i, classSize, values)
};
}
});
if (useTestData) {
await testValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)].validation =
calcPitchClassEval(i, classSize, values);
}
});
}
return results;
}
async function predictSample(sample) {
let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
var maxValue = 0;
var predictedPitch = 7;
for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
if (result[0][i] > maxValue) {
predictedPitch = i;
maxValue = result[0][i];
}
}
return pitchFromClassNum(predictedPitch);
}
// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
// Output has 7 different class values for each pitch, offset based on
// which pitch class (ordered by i)
let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
let total = 0;
for (let i = 0; i < classSize; i++) {
total += values[index];
index += NUM_PITCH_CLASSES;
}
return total / classSize;
}
// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
switch (classNum) {
case 0:
return 'Fastball (2-seam)';
case 1:
return 'Fastball (4-seam)';
case 2:
return 'Fastball (sinker)';
case 3:
return 'Fastball (cutter)';
case 4:
return 'Slider';
case 5:
return 'Changeup';
case 6:
return 'Curveball';
default:
return 'Unknown';
}
}
module.exports = {
evaluate,
model,
pitchFromClassNum,
predictSample,
testValidationData,
trainingData,
TEST_DATA_LENGTH
}
6. サーバーでモデルをトレーニングする
server.js という新しいファイルに、モデルのトレーニングと評価を実行するサーバーコードを記述します。まず、HTTP サーバーを作成し、socket.io API を使用して双方向のソケット接続を開きます。次に、model.fitDataset API を使用してモデルのトレーニングを実行し、先ほど作成した pitch_type.evaluate()
メソッドを使用してモデルの精度を評価します。10 回のイテレーションでトレーニングと評価を行い、指標をコンソールに出力する。
以下のコードを server.js にコピーします。
require('@tensorflow/tfjs-node');
const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');
const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;
// util function to sleep for a given ms
function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
const port = process.env.PORT || PORT;
const server = http.createServer();
const io = socketio(server);
server.listen(port, () => {
console.log(` > Running socket on port: ${port}`);
});
io.on('connection', (socket) => {
socket.on('predictSample', async (sample) => {
io.emit('predictResult', await pitch_type.predictSample(sample));
});
});
let numTrainingIterations = 10;
for (var i = 0; i < numTrainingIterations; i++) {
console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
console.log('accuracyPerClass', await pitch_type.evaluate(true));
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
}
io.emit('trainingComplete', true);
}
run();
これで、サーバーを実行してテストする準備が整いました。次に示すように、サーバーが反復処理ごとに 1 エポックをトレーニングすることがわかります(model.fitDataset
API を使用して、1 回の呼び出しで複数のエポックをトレーニングすることもできます)。この時点でエラーが発生した場合は、ノードと npm のインストールを確認してください。
$ npm run start-server
...
> Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49
Ctrl+C キーを押して、実行中のサーバーを停止します。次のステップで再度実行します。
7. クライアント ページと表示コードを作成する
サーバーの準備が完了したら、次はクライアント コードを記述します。コードはブラウザで実行されます。サーバーでモデル予測を呼び出して結果を表示するシンプルなページを作成します。クライアント/サーバー通信に socket.io を使用します。
まず、baseball/ フォルダに index.html を作成します。
<!doctype html>
<html>
<head>
<title>Pitch Training Accuracy</title>
</head>
<body>
<h3 id="waiting-msg">Waiting for server...</h3>
<p>
<span style="font-size:16px" id="trainingStatus"></span>
<p>
<div id="predictContainer" style="font-size:16px;display:none">
Sensor data: <span id="predictSample"></span>
<button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
</div>
<script src="dist/bundle.js"></script>
<style>
html,
body {
font-family: Roboto, sans-serif;
color: #5f6368;
}
body {
background-color: rgb(248, 249, 250);
}
</style>
</body>
</html>
次に、baseball/ フォルダに、以下のコードを使用して client.js という新しいファイルを作成します。
import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');
const socket =
io('http://localhost:8001',
{reconnectionDelay: 300, reconnectionDelayMax: 300});
const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball
predictButton.onclick = () => {
predictButton.disabled = true;
socket.emit('predictSample', testSample);
};
// functions to handle socket events
socket.on('connect', () => {
document.getElementById('waiting-msg').style.display = 'none';
document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});
socket.on('trainingComplete', () => {
document.getElementById('trainingStatus').innerHTML = 'Training Complete';
document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
predictContainer.style.display = 'block';
});
socket.on('predictResult', (result) => {
plotPredictResult(result);
});
socket.on('disconnect', () => {
document.getElementById('trainingStatus').innerHTML = '';
predictContainer.style.display = 'none';
document.getElementById('waiting-msg').style.display = 'block';
});
function plotPredictResult(result) {
predictButton.disabled = false;
document.getElementById('predictResult').innerHTML = result;
console.log(result);
}
クライアントが trainingComplete
ソケット メッセージを処理して予測ボタンを表示します。このボタンをクリックすると、クライアントはサンプル センサーデータを含むソケット メッセージを送信します。predictResult
メッセージを受信すると、ページに検索候補が表示されます。
8. アプリを実行する
サーバーとクライアントの両方を実行して、アプリの動作全体を確認します。
[In one terminal, run this first]
$ npm run start-client
[In another terminal, run this next]
$ npm run start-server
ブラウザでクライアント ページ(http://localhost:8080)を開きます。モデルのトレーニングが完了したら、[予測サンプル] ボタンをクリックします。ブラウザに予測結果が表示されます。サンプルのセンサーデータを自由に変更して、テスト用 CSV ファイルのサンプルをいくつか加え、モデルがどの程度正確に予測するかを確認してください。
9. 学習した内容
この Codelab では、TensorFlow.js を使用して簡単な ML ウェブ アプリケーションを実装しました。ここでは、センサーデータから野球の投球タイプを分類するカスタムモデルをトレーニングしました。サーバー上でトレーニングを実行し、クライアントから送信されたデータを使用してトレーニング済みモデルの推論を呼び出す Node.js コードを記述しました。
tensorflow.org/js にアクセスして、コードで TensorFlow.js を使用する方法を示すその他の例やデモをご覧ください。