TensorFlow.js-Training in Node.js-Codelab

1. Einführung

In diesem Codelab erfahren Sie, wie Sie einen Node.js-Webserver erstellen, um Baseball-Pitch-Typen serverseitig mit TensorFlow.js ist eine leistungsstarke und flexible Bibliothek für maschinelles Lernen in JavaScript. Sie erstellen eine Webanwendung, mit der ein Modell trainiert wird, um den Typ des Wurfs anhand von Daten des Wurfsensors vorherzusagen und Vorhersagen von einem Webclient aus aufzurufen. Eine vollständig funktionierende Version dieses Codelabs ist im tfjs-examples-GitHub-Repository verfügbar.

Lerninhalte

  • So installieren und richten Sie das tensorflow.js-npm-Paket für die Verwendung mit Node.js ein.
  • So greifen Sie in der Node.js-Umgebung auf Trainings- und Testdaten zu.
  • So trainieren Sie ein Modell mit TensorFlow.js auf einem Node.js-Server.
  • So stellen Sie das trainierte Modell für die Inferenz in einer Client-/Server-Anwendung bereit.

Legen wir los!

2. Voraussetzungen

Für dieses Codelab benötigen Sie Folgendes:

  1. Eine aktuelle Version von Chrome oder einem anderen modernen Browser.
  2. Ein Texteditor und ein Befehlsterminal, die lokal auf Ihrem Computer ausgeführt werden.
  3. Kenntnisse von HTML, CSS, JavaScript und den Chrome-Entwicklertools (oder den Entwicklertools Ihres bevorzugten Browsers).
  4. Ein grundlegendes konzeptionelles Verständnis von neuronalen Netzwerken. Wenn Sie eine Einführung oder Auffrischung benötigen, können Sie sich dieses Video von 3blue1brown oder dieses Video zu Deep Learning in JavaScript von Ashi Krishnan ansehen.

3. Node.js-Anwendung einrichten

Installieren Sie Node.js und npm. Informationen zu unterstützten Plattformen und Abhängigkeiten finden Sie in der Installationsanleitung für tfjs-node.

Erstellen Sie ein Verzeichnis namens „./baseball“ für unsere Node.js-App. Kopieren Sie die verlinkten Dateien package.json und webpack.config.js in dieses Verzeichnis, um die npm-Paketabhängigkeiten zu konfigurieren, einschließlich des npm-Pakets @tensorflow/tfjs-node. Führen Sie dann „npm install“ aus, um die Abhängigkeiten zu installieren.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Jetzt können Sie Code schreiben und ein Modell trainieren.

4. Trainings- und Testdaten einrichten

Sie verwenden die Trainings- und Testdaten als CSV-Dateien über die folgenden Links. Laden Sie die Daten in diesen Dateien herunter und sehen Sie sie sich an:

pitch_type_training_data.csv

pitch_type_test_data.csv

Sehen wir uns einige Beispieldaten für das Training an:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Es gibt acht Eingabefeatures, die die Daten des Neigungssensors beschreiben:

  • Ballgeschwindigkeit (vx0, vy0, vz0)
  • Beschleunigung des Balls (ax, ay, az)
  • Startgeschwindigkeit des Tonhöhenwechsels
  • ob der Pitcher Linkshänder ist

und ein Ausgabelabel:

  • pitch_code, der einen von sieben Pitch-Typen angibt: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

Ziel ist es, ein Modell zu erstellen, das den Pitch-Typ anhand von Daten des Pitch-Sensors vorhersagen kann.

Bevor Sie das Modell erstellen, müssen Sie die Trainings- und Testdaten vorbereiten. Erstellen Sie die Datei „pitch_type.js“ im Verzeichnis „baseball/“ und kopieren Sie den folgenden Code hinein. Mit diesem Code werden Trainings- und Testdaten über die tf.data.csv API geladen. Außerdem werden die Daten mithilfe einer Min-Max-Normalisierungsskala normalisiert, was immer empfehlenswert ist.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Modell zum Klassifizieren von Pitch-Typen erstellen

Jetzt können Sie das Modell erstellen. Verwenden Sie die tf.layers-API, um die Eingaben (Form von [8] Werten des Neigungssensors) mit drei verborgenen, vollständig verbundenen Ebenen zu verbinden, die aus ReLU-Aktivierungseinheiten bestehen, gefolgt von einer Softmax-Ausgabeebene mit sieben Einheiten, die jeweils einen der Ausgabetypen für die Neigung darstellen.

Trainieren Sie das Modell mit dem Adam-Optimizer und der Verlustfunktion „sparseCategoricalCrossentropy“. Weitere Informationen zu diesen Optionen finden Sie im Leitfaden zu Trainingsmodellen.

Fügen Sie am Ende von „pitch_type.js“ den folgenden Code ein:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Lösen Sie das Training über den Hauptservercode aus, den Sie später schreiben.

Um das Modul „pitch_type.js“ zu vervollständigen, schreiben wir eine Funktion, mit der das Validierungs- und Test-Dataset ausgewertet, ein Pitch-Typ für eine einzelne Stichprobe vorhergesagt und Genauigkeitsmesswerte berechnet werden. Hängen Sie diesen Code an das Ende von „pitch_type.js“ an:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Modell auf dem Server trainieren

Schreiben Sie den Servercode für das Modelltraining und die ‑bewertung in eine neue Datei namens „server.js“. Erstellen Sie zuerst einen HTTP-Server und öffnen Sie eine bidirektionale Socket-Verbindung mit der Socket.IO API. Führen Sie dann das Modelltraining mit der model.fitDataset API aus und bewerten Sie die Modellgenauigkeit mit der pitch_type.evaluate()-Methode, die Sie zuvor geschrieben haben. Trainieren und bewerten Sie das Modell 10 Mal und geben Sie die Messwerte in der Konsole aus.

Kopieren Sie den folgenden Code in die Datei „server.js“:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Jetzt können Sie den Server ausführen und testen. Sie sollten etwas Ähnliches sehen, wobei der Server in jeder Iteration eine Epoche trainiert. Sie können auch die model.fitDataset API verwenden, um mit einem Aufruf mehrere Epochen zu trainieren. Wenn an dieser Stelle Fehler auftreten, prüfen Sie Ihre Node- und npm-Installation.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Geben Sie Strg + C ein, um den laufenden Server zu beenden. Wir führen sie im nächsten Schritt noch einmal aus.

7. Clientseite erstellen und Code anzeigen

Nachdem der Server bereit ist, müssen Sie als Nächstes den Clientcode schreiben, der im Browser ausgeführt wird. Erstellen Sie eine einfache Seite, um die Modellvorhersage auf dem Server aufzurufen und das Ergebnis anzuzeigen. Dabei wird socket.io für die Client-/Server-Kommunikation verwendet.

Erstellen Sie zuerst die Datei „index.html“ im Ordner „baseball“:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Erstellen Sie dann im Ordner „baseball/“ eine neue Datei namens „client.js“ mit dem folgenden Code:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

Der Client verarbeitet die trainingComplete-Socket-Nachricht, um eine Vorhersageschaltfläche anzuzeigen. Wenn auf diese Schaltfläche geklickt wird, sendet der Client eine Socket-Nachricht mit Beispiel-Sensordaten. Nachdem eine predictResult-Nachricht empfangen wurde, wird die Vorhersage auf der Seite angezeigt.

8. Anwendung ausführen

Führen Sie sowohl den Server als auch den Client aus, um die vollständige App in Aktion zu sehen:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Öffnen Sie die Clientseite in Ihrem Browser ( http://localhost:8080). Wenn das Training des Modells abgeschlossen ist, klicken Sie auf die Schaltfläche Predict Sample (Beispiel vorhersagen). Im Browser sollte ein Vorhersageergebnis angezeigt werden. Sie können die Beispieldaten der Sensoren mit einigen Beispielen aus der CSV-Testdatei ändern und sehen, wie genau das Modell Vorhersagen trifft.

9. Das haben Sie gelernt

In diesem Codelab haben Sie eine einfache Webanwendung für maschinelles Lernen mit TensorFlow.js implementiert. Sie haben ein benutzerdefiniertes Modell zum Klassifizieren von Baseball-Pitch-Typen anhand von Sensordaten trainiert. Sie haben Node.js-Code geschrieben, um das Training auf dem Server auszuführen und die Inferenz für das trainierte Modell mit Daten aufzurufen, die vom Client gesendet werden.

Auf tensorflow.org/js finden Sie weitere Beispiele und Demos mit Code, die zeigen, wie Sie TensorFlow.js in Ihren Anwendungen verwenden können.