Capacitación sobre TensorFlow.js en el codelab de Node.js

1. Introducción

En este codelab, aprenderás a compilar un servidor web de Node.js para entrenar y clasificar tipos de lanzamientos de béisbol en el servidor con TensorFlow.js, una biblioteca de aprendizaje automático potente y flexible para JavaScript. Compilarás una aplicación web para entrenar un modelo que prediga el tipo de lanzamiento a partir de los datos del sensor de lanzamiento y para invocar la predicción desde un cliente web. En el repositorio de GitHub tfjs-examples, se encuentra una versión completamente funcional de este codelab.

Qué aprenderás

  • Cómo instalar y configurar el paquete npm de TensorFlow.js para usarlo con Node.js
  • Cómo acceder a los datos de entrenamiento y de prueba en el entorno de Node.js
  • Cómo entrenar un modelo con TensorFlow.js en un servidor de Node.js
  • Cómo implementar el modelo entrenado para la inferencia en una aplicación cliente/servidor

¡Comencemos!

2. Requisitos

Para completar este codelab, necesitarás lo siguiente:

  1. Una versión reciente de Chrome o de otro navegador actualizado
  2. Un editor de texto y una terminal de comandos que se ejecuten de forma local en tu máquina
  3. Conocimientos sobre HTML, CSS, JavaScript y las Herramientas para desarrolladores de Chrome (o las de tu navegador preferido)
  4. Comprensión conceptual de alto nivel de las redes neuronales Si necesitas una introducción o un repaso, te recomendamos mirar este video de 3blue1brown o este video sobre aprendizaje profundo en JavaScript de Ashi Krishnan.

3. Configura una app de Node.js

Instala Node.js y npm. Para conocer las plataformas y dependencias compatibles, consulta la guía de instalación de tfjs-node.

Crea un directorio llamado ./baseball para nuestra app de Node.js. Copia los archivos vinculados package.json y webpack.config.js en este directorio para configurar las dependencias del paquete npm (incluido el paquete npm @tensorflow/tfjs-node). Luego, ejecuta npm install para instalar las dependencias.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Ahora puedes escribir código y entrenar un modelo.

4. Configura los datos de entrenamiento y de prueba

Usarás los datos de entrenamiento y prueba como archivos CSV de los vínculos que se indican a continuación. Descarga y explora los datos de estos archivos:

pitch_type_training_data.csv

pitch_type_test_data.csv

Veamos algunos datos de entrenamiento de muestra:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Hay ocho atributos de entrada que describen los datos del sensor de inclinación:

  • Velocidad de la pelota (vx0, vy0, vz0)
  • Aceleración de la pelota (ax, ay, az)
  • velocidad inicial del lanzamiento
  • Si el lanzador es zurdo o no

y una etiqueta de salida:

  • pitch_code que indica uno de los siete tipos de lanzamientos: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

El objetivo es crear un modelo que pueda predecir el tipo de lanzamiento a partir de los datos del sensor de lanzamiento.

Antes de crear el modelo, debes preparar los datos de entrenamiento y de prueba. Crea el archivo pitch_type.js en el directorio baseball/ y copia el siguiente código en él. Este código carga datos de entrenamiento y prueba con la API de tf.data.csv. También normaliza los datos (lo que siempre se recomienda) con una escala de normalización min-max.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Crea un modelo para clasificar los tipos de lanzamientos

Ya puedes compilar el modelo. Usa la API de tf.layers para conectar las entradas (forma de valores del sensor de tono [8]) a 3 capas ocultas completamente conectadas que constan de unidades de activación ReLU, seguidas de una capa de salida softmax que consta de 7 unidades, cada una de las cuales representa uno de los tipos de tono de salida.

Entrena el modelo con el optimizador adam y la función de pérdida sparseCategoricalCrossentropy. Para obtener más información sobre estas opciones, consulta la guía para entrenar modelos.

Agrega el siguiente código al final de pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Activa el entrenamiento desde el código del servidor principal que escribirás más adelante.

Para completar el módulo pitch_type.js, escribamos una función para evaluar el conjunto de datos de validación y prueba, predecir un tipo de lanzamiento para una sola muestra y calcular las métricas de precisión. Agrega este código al final de pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Entrena el modelo en el servidor

Escribe el código del servidor para realizar el entrenamiento y la evaluación del modelo en un archivo nuevo llamado server.js. Primero, crea un servidor HTTP y abre una conexión de socket bidireccional con la API de socket.io. Luego, ejecuta el entrenamiento del modelo con la API de model.fitDataset y evalúa la precisión del modelo con el método pitch_type.evaluate() que escribiste antes. Entrena y evalúa durante 10 iteraciones, y muestra las métricas en la consola.

Copia el siguiente código en server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

En este punto, ya puedes ejecutar y probar el servidor. Deberías ver algo como esto, con el servidor entrenando una época en cada iteración (también podrías usar la API de model.fitDataset para entrenar varias épocas con una sola llamada). Si encuentras algún error en este punto, verifica la instalación de Node y npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Escribe Ctrl + C para detener el servidor en ejecución. Lo volveremos a ejecutar en el siguiente paso.

7. Crea la página del cliente y muestra el código

Ahora que el servidor está listo, el siguiente paso es escribir el código del cliente que se ejecuta en el navegador. Crea una página simple para invocar la predicción del modelo en el servidor y mostrar el resultado. Esto usa socket.io para la comunicación entre el cliente y el servidor.

Primero, crea index.html en la carpeta baseball/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Luego, crea un archivo nuevo llamado client.js en la carpeta baseball/ con el siguiente código:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

El cliente controla el mensaje de socket trainingComplete para mostrar un botón de predicción. Cuando se hace clic en este botón, el cliente envía un mensaje de socket con datos de muestra del sensor. Cuando recibe un mensaje predictResult, muestra la predicción en la página.

8. Ejecuta la app

Ejecuta el servidor y el cliente para ver la app completa en acción:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Abre la página del cliente en tu navegador ( http://localhost:8080). Cuando finalice el entrenamiento del modelo, haz clic en el botón Predict Sample. Deberías ver un resultado de predicción que se muestra en el navegador. Puedes modificar los datos de muestra del sensor con algunos ejemplos del archivo CSV de prueba y ver con qué precisión realiza las predicciones el modelo.

9. Qué aprendiste

En este codelab, implementaste una aplicación web de aprendizaje automático simple con TensorFlow.js. Entrenaste un modelo personalizado para clasificar los tipos de lanzamientos de béisbol a partir de datos de sensores. Escribiste código de Node.js para ejecutar el entrenamiento en el servidor y llamar a la inferencia en el modelo entrenado con los datos enviados desde el cliente.

Visita tensorflow.org/js para obtener más ejemplos y demostraciones con código y ver cómo puedes usar TensorFlow.js en tus aplicaciones.