אימון TensorFlow.js ב-Codelab של Node.js

1. מבוא

ב-Codelab הזה תלמדו איך ליצור שרת אינטרנט של Node.js כדי לאמן ולסווג סוגים של זריקות בייסבול בצד השרת באמצעות TensorFlow.js, ספרייה חזקה וגמישה של למידת מכונה ל-JavaScript. תבנו אפליקציית אינטרנט לאימון מודל לחיזוי סוג הזריקה מנתוני חיישנים של הזריקה, ולהפעלת החיזוי מלקוח אינטרנט. גרסה תקינה לחלוטין של ה-Codelab הזה נמצאת במאגר tfjs-examples ב-GitHub.

מה תלמדו

  • איך להתקין ולהגדיר את חבילת ה-npm של tensorflow.js לשימוש עם Node.js.
  • איך ניגשים לנתוני אימון ולנתוני בדיקה בסביבת Node.js.
  • איך מאמנים מודל באמצעות TensorFlow.js בשרת Node.js.
  • איך פורסים את המודל המאומן להסקת מסקנות באפליקציית לקוח/שרת.

אז בואו נתחיל!

2. דרישות

כדי להשלים את ה-Codelab הזה, תצטרכו:

  1. גרסה עדכנית של Chrome או דפדפן מודרני אחר.
  2. עורך טקסט ומסוף פקודות שפועלים באופן מקומי במחשב.
  3. ידע ב-HTML,‏ CSS,‏ JavaScript ובכלי הפיתוח ל-Chrome (או בכלי הפיתוח של הדפדפן המועדף).
  4. הבנה מושגית ברמה גבוהה של רשתות נוירונים. אם אתם צריכים מבוא או רענון, כדאי לצפות בסרטון הזה של 3blue1brown או בסרטון הזה על למידה עמוקה ב-JavaScript של Ashi Krishnan.

3. הגדרת אפליקציית Node.js

מתקינים את Node.js ואת npm. במדריך ההתקנה של tfjs-node מפורטות הפלטפורמות הנתמכות ויחסי התלות.

יוצרים ספרייה בשם ‎ ./baseball לאפליקציית Node.js. מעתיקים את הקבצים המקושרים package.json ו-webpack.config.js לספרייה הזו כדי להגדיר את יחסי התלות של חבילת npm (כולל חבילת npm‏ ‎@tensorflow/tfjs-node). לאחר מכן מריצים את הפקודה npm install כדי להתקין את התלויות.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

עכשיו אתם מוכנים לכתוב קוד ולאמן מודל.

4. הגדרת נתוני האימון והבדיקה

תשתמשו בנתוני האימון והבדיקה כקובצי CSV מהקישורים שלמטה. הורידו את הנתונים בקבצים האלה ובדקו אותם:

pitch_type_training_data.csv

pitch_type_test_data.csv

הנה דוגמה לנתוני אימון:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

יש שמונה תכונות קלט – שמתארות נתוני חיישנים של גובה הצליל:

  • מהירות הכדור (vx0, ‏ vy0, ‏ vz0)
  • האצת הכדור (ax, ‏ ay, ‏ az)
  • מהירות התחלתית של המגרש
  • אם הפיצ'ר הוא שמאלי או לא

ותווית פלט אחת:

  • ‫pitch_code שמציין אחד משבעה סוגי זריקות: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

המטרה היא לבנות מודל שיכול לחזות את סוג הזריקה בהינתן נתוני חיישן הזריקה.

לפני שיוצרים את המודל, צריך להכין את נתוני האימון והבדיקה. יוצרים את הקובץ pitch_type.js בספרייה baseball/ ומעתיקים אליו את הקוד הבא. הקוד הזה טוען נתוני אימון ובדיקה באמצעות ה-API‏ tf.data.csv. בנוסף, הוא מבצע נרמול של הנתונים (מומלץ תמיד) באמצעות סולם נרמול מינימום-מקסימום.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. יצירת מודל לסיווג סוגי זריקות

עכשיו אתם מוכנים לבנות את המודל. משתמשים ב-API‏ tf.layers כדי לחבר את נתוני הקלט (צורה של [8] ערכי חיישן גובה הצליל) ל-3 שכבות מוסתרות שמחוברות באופן מלא ומורכבות מיחידות הפעלה של ReLU, ואחריהן שכבת פלט אחת של softmax שמורכבת מ-7 יחידות, שכל אחת מהן מייצגת אחד מסוגי גובה הצליל של הפלט.

מאמנים את המודל באמצעות אופטימיזציית אדם ופונקציית ההפסד sparseCategoricalCrossentropy. מידע נוסף על האפשרויות האלה זמין במדריך לאימון מודלים.

מוסיפים את הקוד הבא לסוף הקובץ pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

מפעילים את האימון מקוד השרת הראשי שתכתבו בהמשך.

כדי להשלים את המודול pitch_type.js, נכתוב פונקציה להערכת האימות ולבדיקת מערך הנתונים, לחיזוי סוג המגרש עבור דגימה אחת ולחישוב מדדי הדיוק. מוסיפים את הקוד הזה לסוף הקובץ pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. אימון המודל בשרת

כותבים את קוד השרת כדי לבצע אימון והערכה של המודל בקובץ חדש בשם server.js. קודם יוצרים שרת HTTP ופותחים חיבור דו-כיווני של שקע באמצעות socket.io API. לאחר מכן מריצים את אימון המודל באמצעות ה-API‏ model.fitDataset, ומעריכים את דיוק המודל באמצעות השיטה pitch_type.evaluate() שכתבתם קודם. מאמנים ומעריכים במשך 10 איטרציות, ומדפיסים את המדדים במסוף.

מעתיקים את הקוד שלמטה אל server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

בשלב הזה, השרת מוכן להרצה ולבדיקה. אמור להופיע פלט דומה לזה שמוצג כאן, עם אימון השרת באפוקה אחת בכל איטרציה (אפשר גם להשתמש ב-API‏ model.fitDataset כדי לאמן כמה אפוקות בקריאה אחת). אם נתקלתם בשגיאות בשלב הזה, בדקו את הצומת ואת התקנת npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

מקלידים Ctrl-C כדי לעצור את השרת הפועל. נריץ אותו שוב בשלב הבא.

7. יצירת דף לקוח והצגת הקוד

עכשיו, כשהשרת מוכן, השלב הבא הוא לכתוב את קוד הלקוח שפועל בדפדפן. יוצרים דף פשוט כדי להפעיל חיזוי של המודל בשרת ולהציג את התוצאה. התקשורת בין הלקוח לשרת מתבצעת באמצעות socket.io.

קודם כל, יוצרים את הקובץ index.html בתיקייה baseball/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

לאחר מכן יוצרים קובץ חדש בשם client.js בתיקייה baseball/ עם הקוד הבא:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

הלקוח מטפל בהודעת ה-socket‏ trainingComplete כדי להציג לחצן של תחזית. כשלוחצים על הלחצן הזה, הלקוח שולח הודעת Socket עם נתוני חיישן לדוגמה. כשמקבלים הודעה predictResult, התחזית מוצגת בדף.

8. הפעלת האפליקציה

מריצים גם את השרת וגם את הלקוח כדי לראות את האפליקציה המלאה בפעולה:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

פותחים את דף הלקוח בדפדפן ( http://localhost:8080). כשאימון המודל מסתיים, לוחצים על הלחצן חיזוי של דוגמה. תוצאת החיזוי תוצג בדפדפן. אתם יכולים לשנות את נתוני החיישן לדוגמה עם כמה דוגמאות מקובץ ה-CSV של הבדיקה ולראות עד כמה המודל מדויק בתחזיות שלו.

9. מה למדתם

ב-Codelab הזה, הטמעתם אפליקציית אינטרנט פשוטה ללמידת מכונה באמצעות TensorFlow.js. אימנתם מודל בהתאמה אישית לסיווג סוגי זריקות בבייסבול מנתוני חיישנים. כתבת קוד Node.js כדי להפעיל אימון בשרת, וקראת למסקנה במודל המאומן באמצעות נתונים שנשלחו מהלקוח.

כדי לראות איך אפשר להשתמש ב-TensorFlow.js באפליקציות, מומלץ לעיין בדוגמאות נוספות ובסרטוני הדגמה עם קוד בכתובת tensorflow.org/js.