การฝึก TensorFlow.js ใน Node.js Codelab

1. บทนำ

ใน Codelab นี้ คุณจะได้เรียนรู้วิธีสร้างเว็บเซิร์ฟเวอร์ Node.js เพื่อฝึกและจำแนกประเภทสนามเบสบอลฝั่งเซิร์ฟเวอร์โดยใช้ TensorFlow.js ซึ่งเป็นไลบรารีแมชชีนเลิร์นนิงที่มีประสิทธิภาพและยืดหยุ่นสำหรับ JavaScript คุณจะสร้างเว็บแอปพลิเคชันเพื่อฝึกโมเดลให้คาดการณ์ประเภทของระดับเสียงจากเซ็นเซอร์ระดับเสียง และเรียกใช้การคาดการณ์จากเว็บไคลเอ็นต์ Codelab เวอร์ชันที่ทำงานอย่างเต็มรูปแบบมีอยู่ในที่เก็บ tfjs-examples ของ GitHub

สิ่งที่คุณจะได้เรียนรู้

  • วิธีติดตั้งและตั้งค่าแพ็กเกจ tensorflow.js npm เพื่อใช้กับ Node.js
  • วิธีเข้าถึงข้อมูลการฝึกและการทดสอบในสภาพแวดล้อม Node.js
  • วิธีฝึกโมเดลด้วย TensorFlow.js ในเซิร์ฟเวอร์ Node.js
  • วิธีทำให้โมเดลที่ผ่านการฝึกสำหรับการอนุมานใช้งานได้ในแอปพลิเคชันไคลเอ็นต์/เซิร์ฟเวอร์

มาเริ่มกันเลย

2. ข้อกำหนด

ในการทำให้ Codelab นี้เสร็จสมบูรณ์ คุณจะต้องมีสิ่งต่อไปนี้

  1. Chrome เวอร์ชันล่าสุด หรือเบราว์เซอร์ที่ทันสมัยอื่นๆ
  2. เครื่องมือแก้ไขข้อความและเทอร์มินัลคำสั่งที่ทำงานในเครื่องของคุณ
  3. ความรู้เกี่ยวกับ HTML, CSS, JavaScript และเครื่องมือสำหรับนักพัฒนาเว็บใน Chrome (หรือเครื่องมือสำหรับนักพัฒนาเว็บในเบราว์เซอร์ที่คุณต้องการ)
  4. ความเข้าใจในแนวคิดระดับสูงเกี่ยวกับโครงข่ายประสาท หากต้องการข้อมูลเบื้องต้นหรือทบทวนความรู้ โปรดดูวิดีโอนี้โดย 3blue1brown หรือวิดีโอเกี่ยวกับการเรียนรู้เชิงลึกใน JavaScript โดย Ashi Krishnan

3. ตั้งค่าแอป Node.js

ติดตั้ง Node.js และ npm สำหรับแพลตฟอร์มและทรัพยากร Dependency ที่รองรับ โปรดดูคู่มือการติดตั้งโหนด tfjs

สร้างไดเรกทอรีชื่อ ./baseball สำหรับแอป Node.js คัดลอก package.json และ webpack.config.js ที่ลิงก์ไว้ลงในไดเรกทอรีนี้เพื่อกำหนดค่าทรัพยากร Dependency ของแพ็กเกจ npm (รวมถึงแพ็กเกจ @tensorflow/tfjs-node npm) จากนั้นเรียกใช้การติดตั้ง npm เพื่อติดตั้งทรัพยากร Dependency

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

ตอนนี้คุณพร้อมที่จะเขียนโค้ดและฝึกโมเดลแล้ว

4. ตั้งค่าข้อมูลการฝึกและการทดสอบ

คุณจะใช้ข้อมูลการฝึกอบรมและการทดสอบในรูปแบบไฟล์ CSV จากลิงก์ด้านล่าง ดาวน์โหลดและสำรวจข้อมูลในไฟล์เหล่านี้

pitch_type_training_data.csv

pitch_type_test_data.csv

มาดูตัวอย่างข้อมูลการฝึกกัน

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

มีฟีเจอร์อินพุต 8 แบบ ได้แก่ การอธิบายข้อมูลเซ็นเซอร์ระดับเสียง

  • ความเร็วลูกบอล (vx0, vy0, vz0)
  • ความเร่งของลูกบอล (ax, Ay, Az)
  • ความเร็วเริ่มต้นของระดับเสียงสูงต่ำ
  • พิตเชอร์ถนัดมือซ้ายหรือไม่

และป้ายกำกับเอาต์พุต 1 รายการ:

  • Pitch_code ที่บ่งบอกประเภทการเสนอเพลง 1 ใน 7 ประเภท: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

เป้าหมายคือการสร้างโมเดลที่สามารถคาดการณ์ประเภทระดับความสูงต่ำจากข้อมูลเซ็นเซอร์ระดับเสียง

ก่อนที่จะสร้างโมเดล คุณต้องเตรียมข้อมูลการฝึกและการทดสอบ สร้างไฟล์ Pitch_type.js ในเซิร์ฟเวอร์เบสบอล/ ไดเรกทอรี และคัดลอกโค้ดต่อไปนี้ลงในไฟล์ โค้ดนี้จะโหลดข้อมูลการฝึกและทดสอบโดยใช้ tf.data.csv API นอกจากนี้ ยังปรับข้อมูลให้เป็นมาตรฐาน (ซึ่งแนะนำให้ใช้เสมอ) โดยใช้สเกลการปรับมาตรฐานต่ำสุด

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. สร้างโมเดลเพื่อจำแนกประเภทการเสนอเพลง

ตอนนี้คุณก็พร้อมที่จะสร้างโมเดลแล้ว ใช้ API tf.layers เพื่อเชื่อมต่ออินพุต (รูปร่างของค่าเซ็นเซอร์ระดับเสียง [8]) กับเลเยอร์ที่ซ่อนไว้โดยสมบูรณ์ 3 เลเยอร์ซึ่งประกอบด้วยหน่วยเปิดใช้งาน ReLU ตามด้วยเลเยอร์เอาต์พุต softmax 1 เลเยอร์ที่ประกอบด้วย 7 หน่วย โดยแต่ละเลเยอร์จะแสดงประเภทเอาต์พุตอย่างใดอย่างหนึ่ง

ฝึกโมเดลด้วยเครื่องมือเพิ่มประสิทธิภาพ adam และฟังก์ชันการสูญเสีย sparseCategorical Crossentropy โปรดดูข้อมูลเพิ่มเติมเกี่ยวกับตัวเลือกเหล่านี้ในคู่มือโมเดลการฝึก

เพิ่มโค้ดต่อไปนี้ที่ส่วนท้ายของ Pitch_type.js

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

เรียกใช้การฝึกจากโค้ดเซิร์ฟเวอร์หลักที่คุณจะเขียนในภายหลัง

มาเขียนโมดูล Pitch_type.js กัน เรามาเขียนฟังก์ชันเพื่อประเมินการตรวจสอบความถูกต้องและทดสอบชุดข้อมูล คาดการณ์ประเภทการเสนอขายสำหรับตัวอย่างเดี่ยว และเมตริกความถูกต้องในการคำนวณ เพิ่มโค้ดนี้ต่อท้าย Pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. ฝึกโมเดลในเซิร์ฟเวอร์

เขียนโค้ดเซิร์ฟเวอร์เพื่อดำเนินการฝึกและประเมินโมเดลในไฟล์ใหม่ที่ชื่อ "server.js" ขั้นแรก ให้สร้างเซิร์ฟเวอร์ HTTP และเปิดการเชื่อมต่อซ็อกเก็ตแบบ 2 ทิศทางโดยใช้ API socket.io จากนั้นจึงดำเนินการฝึกโมเดลโดยใช้ model.fitDataset API และประเมินความถูกต้องของโมเดลโดยใช้เมธอด pitch_type.evaluate() ที่เขียนไว้ก่อนหน้า ฝึกฝนและประเมินผลและทำซ้ำ 10 ครั้ง ตลอดจนพิมพ์เมตริกไปยังคอนโซล

คัดลอกโค้ดด้านล่างไปยัง Server.js

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

เมื่อถึงจุดนี้ คุณก็พร้อมที่จะเรียกใช้และทดสอบเซิร์ฟเวอร์แล้ว คุณควรจะเห็นบางอย่างแบบนี้ เมื่อเซิร์ฟเวอร์ฝึก 1 Epoch ในการทำซ้ำแต่ละครั้ง (คุณยังสามารถใช้ model.fitDataset API เพื่อฝึกหลาย Epoch ด้วยการเรียกครั้งเดียว) หากคุณพบข้อผิดพลาดใดๆ ในขั้นตอนนี้ โปรดตรวจสอบโหนดและการติดตั้ง NPM

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

กด Ctrl-C เพื่อหยุดเซิร์ฟเวอร์ที่ทำงานอยู่ เราจะเรียกใช้อีกครั้งในขั้นตอนถัดไป

7. สร้างหน้าลูกค้าและรหัสแสดง

เมื่อเซิร์ฟเวอร์พร้อมแล้ว ขั้นตอนถัดไปคือการเขียนโค้ดไคลเอ็นต์และทำงานในเบราว์เซอร์ สร้างหน้าเว็บง่ายๆ เพื่อเรียกใช้การคาดการณ์โมเดลบนเซิร์ฟเวอร์และแสดงผลลัพธ์ การดำเนินการนี้ใช้ socket.io สำหรับการสื่อสารระหว่างไคลเอ็นต์/เซิร์ฟเวอร์

ก่อนอื่น ให้สร้าง index.html ในโฟลเดอร์ เบสบอล/ ดังนี้

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

จากนั้นสร้างไฟล์ client.js ใหม่ในโฟลเดอร์ เบสบอล/ ด้วยโค้ดต่อไปนี้

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

ไคลเอ็นต์จะจัดการข้อความซ็อกเก็ต trainingComplete เพื่อแสดงปุ่มการคาดการณ์ เมื่อคลิกปุ่มนี้ ไคลเอ็นต์จะส่งข้อความซ็อกเก็ตที่มีข้อมูลเซ็นเซอร์ตัวอย่าง เมื่อได้รับข้อความ predictResult ระบบจะแสดงการคาดคะเนในหน้า

8. เรียกใช้แอป

เรียกใช้ทั้งเซิร์ฟเวอร์และไคลเอ็นต์เพื่อดูการทำงานของแอปตัวเต็ม

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

เปิดหน้าไคลเอ็นต์ในเบราว์เซอร์ ( http://localhost:8080) เมื่อการฝึกโมเดลเสร็จสิ้น ให้คลิกปุ่มคาดการณ์ตัวอย่าง คุณควรเห็นผลการคาดการณ์แสดงในเบราว์เซอร์ คุณสามารถแก้ไขข้อมูลเซ็นเซอร์ตัวอย่างได้ด้วยตัวอย่างบางส่วนจากไฟล์ CSV ทดสอบ และดูว่าโมเดลคาดการณ์ได้แม่นยำเพียงใด

9. สิ่งที่คุณได้เรียนรู้

ใน Codelab นี้ คุณได้ติดตั้งเว็บแอปพลิเคชันสำหรับแมชชีนเลิร์นนิงแบบง่ายโดยใช้ TensorFlow.js คุณได้ฝึกโมเดลที่กำหนดเองเพื่อจำแนกประเภทพิตช์เบสบอลจากข้อมูลเซ็นเซอร์ คุณเขียนโค้ด Node.js เพื่อดำเนินการฝึกในเซิร์ฟเวอร์ และอนุมานการเรียกใช้บนโมเดลที่ได้รับการฝึกโดยใช้ข้อมูลที่ส่งจากไคลเอ็นต์

อย่าลืมไปที่ tensorflow.org/js เพื่อดูตัวอย่างเพิ่มเติมและการสาธิตที่มีโค้ดเพื่อดูวิธีการใช้ TensorFlow.js ในแอปพลิเคชัน