Node.js Codelab의 TensorFlow.js 학습

1. 소개

이 Codelab에서는 강력하면서도 유연한 JavaScript 머신러닝 라이브러리인 TensorFlow.js를 사용하여 서버 측에서 야구 투구 유형을 학습시키고 분류하는 Node.js 웹 서버를 빌드하는 방법을 알아봅니다. 웹 애플리케이션을 빌드하여 피치 센서 데이터에서 피치 유형을 예측하는 모델을 학습시키고 웹 클라이언트에서 예측을 호출합니다. 이 Codelab의 완전히 작동하는 버전은 tfjs-examples GitHub 저장소에 있습니다.

학습할 내용

  • Node.js와 함께 사용할 tensorflow.js npm 패키지를 설치하고 설정하는 방법
  • Node.js 환경에서 학습 및 테스트 데이터에 액세스하는 방법
  • Node.js 서버에서 TensorFlow.js로 모델을 학습시키는 방법
  • 클라이언트/서버 애플리케이션에서 추론을 위해 학습된 모델을 배포하는 방법

그럼 시작해 보겠습니다.

2. 요구사항

이 Codelab을 완료하려면 다음이 필요합니다.

  1. 최신 버전의 Chrome 또는 다른 최신 브라우저
  2. 머신에서 로컬로 실행되는 텍스트 편집기와 명령 터미널
  3. HTML, CSS, JavaScript, Chrome DevTools (또는 선호하는 브라우저의 개발 도구)에 대한 지식
  4. 신경망에 대한 대략적인 개념 이해. 소개 또는 복습이 필요하다면 3blue1brown의 동영상 또는 아시 크리슈난의 자바스크립트 딥 러닝 동영상을 확인하세요.

3. Node.js 앱 설정

Node.js 및 npm을 설치합니다. 지원되는 플랫폼과 종속 항목은 tfjs-node 설치 가이드를 참고하세요.

Node.js 앱을 위한 ./baseball 디렉터리를 만듭니다. 연결된 package.jsonwebpack.config.js를 이 디렉터리에 복사하여 npm 패키지 종속 항목 (@tensorflow/tfjs-node npm 패키지 포함)을 구성합니다. 그런 다음 npm install을 실행하여 종속 항목을 설치합니다.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

이제 코드를 작성하고 모델을 학습시킬 준비가 되었습니다.

4. 학습 및 테스트 데이터 설정

아래 링크에서 학습 및 테스트 데이터를 CSV 파일로 사용합니다. 다음 파일에서 데이터를 다운로드하고 탐색합니다.

pitch_type_training_data.csv

pitch_type_test_data.csv

샘플 학습 데이터를 살펴보겠습니다.

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

피치 센서 데이터를 설명하는 8개의 입력 특성이 있습니다.

  • 공 속도 (vx0, vy0, vz0)
  • 볼 가속도 (ax, ay, az)
  • 피치의 시작 속도
  • 투수가 왼손잡이인지 여부

하나의 출력 라벨:

  • 7가지 피치 유형 중 하나를 나타내는 pitch_code: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

목표는 투구 센서 데이터가 주어졌을 때 투구 유형을 예측할 수 있는 모델을 빌드하는 것입니다.

모델을 만들기 전에 학습 데이터와 테스트 데이터를 준비해야 합니다. baseball/ 디렉터리에 pitch_type.js 파일을 만들고 다음 코드를 복사합니다. 이 코드는 tf.data.csv API를 사용하여 학습 데이터와 테스트 데이터를 로드합니다. 또한 최소-최대 정규화 스케일을 사용하여 데이터를 정규화합니다 (항상 권장됨).

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. 구종을 분류하는 모델 만들기

이제 모델을 빌드할 준비가 되었습니다. tf.layers API를 사용하여 입력 (모양이 [8] 인 피치 센서 값)을 ReLU 활성화 단위로 구성된 3개의 완전 연결형 숨겨진 레이어에 연결한 후 각 출력이 출력 피치 유형 중 하나를 나타내는 7개 단위로 구성된 소프트맥스 출력 레이어 하나에 연결합니다.

adam 옵티마이저와 sparseCategoricalCrossentropy 손실 함수를 사용하여 모델을 학습시킵니다. 이러한 선택사항에 대한 자세한 내용은 모델 학습 가이드를 참고하세요.

pitch_type.js 끝에 다음 코드를 추가합니다.

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

나중에 작성할 기본 서버 코드에서 학습을 트리거합니다.

pitch_type.js 모듈을 완료하려면 검증 및 테스트 데이터 세트를 평가하고, 단일 샘플의 투구 유형을 예측하고, 정확도 측정항목을 계산하는 함수를 작성하세요. 이 코드를 pitch_type.js 끝에 추가합니다.

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. 서버에서 모델 학습

server.js라는 새 파일에 모델 학습 및 평가를 실행하는 서버 코드를 작성합니다. 먼저 HTTP 서버를 만들고 socket.io API를 사용하여 양방향 소켓 연결을 엽니다. 그런 다음 model.fitDataset API를 사용하여 모델 학습을 실행하고 앞에서 작성한 pitch_type.evaluate() 메서드를 사용하여 모델 정확도를 평가합니다. 10회 반복하여 학습 및 평가하고 측정항목을 콘솔에 출력합니다.

아래 코드를 server.js에 복사합니다.

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

이제 서버를 실행하고 테스트할 준비가 되었습니다. 각 반복에서 서버가 한 에포크를 학습하는 다음과 같은 내용이 표시됩니다. model.fitDataset API를 사용하여 한 번의 호출로 여러 에포크를 학습할 수도 있습니다. 이 시점에서 오류가 발생하면 노드 및 npm 설치를 확인하세요.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Ctrl-C를 입력하여 실행 중인 서버를 중지합니다. 다음 단계에서 다시 실행합니다.

7. 클라이언트 페이지를 만들고 코드를 표시합니다.

이제 서버가 준비되었으므로 다음 단계는 브라우저에서 실행되는 클라이언트 코드를 작성하는 것입니다. 서버에서 모델 예측을 호출하고 결과를 표시하는 간단한 페이지를 만듭니다. 클라이언트/서버 통신에 socket.io를 사용합니다.

먼저 baseball/ 폴더에 index.html을 만듭니다.

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

그런 다음 baseball/ 폴더에 아래 코드를 사용하여 client.js라는 새 파일을 만듭니다.

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

클라이언트는 trainingComplete 소켓 메시지를 처리하여 예측 버튼을 표시합니다. 이 버튼을 클릭하면 클라이언트가 샘플 센서 데이터가 포함된 소켓 메시지를 전송합니다. predictResult 메시지를 수신하면 페이지에 예측을 표시합니다.

8. 앱 실행

서버와 클라이언트를 모두 실행하여 전체 앱이 작동하는지 확인합니다.

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

브라우저에서 클라이언트 페이지 ( http://localhost:8080)를 엽니다. 모델 학습이 완료되면 샘플 예측 버튼을 클릭합니다. 브라우저에 예측 결과가 표시됩니다. 테스트 CSV 파일의 몇 가지 예시를 사용하여 샘플 센서 데이터를 수정하고 모델이 얼마나 정확하게 예측하는지 확인해 보세요.

9. 학습한 내용

이 Codelab에서는 TensorFlow.js를 사용해 간단한 머신러닝 웹 애플리케이션을 구현했습니다. 센서 데이터에서 야구 투구 유형을 분류하는 맞춤 모델을 학습시켰습니다. 서버에서 학습을 실행하고 클라이언트에서 전송된 데이터를 사용하여 학습된 모델에서 추론을 호출하는 Node.js 코드를 작성했습니다.

더 많은 예시 및 코드가 포함된 데모가 제공되는 tensorflow.org/js를 방문해 애플리케이션에서 TensorFlow.js를 사용하는 방법에 대해 알아보세요.