Capacitación sobre TensorFlow.js en el codelab de Node.js

1. Introducción

En este codelab, aprenderás a compilar un servidor web de Node.js para entrenar y clasificar tipos de lanzamiento de béisbol en el servidor con TensorFlow.js, una biblioteca de aprendizaje automático potente y flexible para JavaScript. Compilarás una aplicación web para entrenar un modelo que prediga el tipo de tono a partir de datos de sensores de tono y, también, invocarás la predicción de un cliente web. Hay una versión completamente funcional de este codelab en el repositorio tfjs-examples de GitHub.

Qué aprenderás

  • Cómo instalar y configurar el paquete npm de tensorflow.js para usarlo con Node.js
  • Cómo acceder a los datos de entrenamiento y prueba en el entorno de Node.js
  • Cómo entrenar un modelo con TensorFlow.js en un servidor Node.js
  • Cómo implementar el modelo entrenado para la inferencia en una aplicación cliente-servidor

¡Empecemos!

2. Requisitos

Para completar este codelab, necesitarás lo siguiente:

  1. Una versión reciente de Chrome o de otro navegador actualizado
  2. Un editor de texto y una terminal de comandos que se ejecuten de forma local en tu máquina
  3. Conocimiento de HTML, CSS, JavaScript y las Herramientas para desarrolladores de Chrome (o las de tu navegador preferido)
  4. Una comprensión conceptual de alto nivel de las redes neuronales. Si necesitas una introducción o un repaso, mira este video de 3Blue1Brown o este video sobre aprendizaje profundo en JavaScript de Ashi Krishnan.

3. Configura una app de Node.js

Instala Node.js y npm. Para conocer las plataformas y dependencias compatibles, consulta la guía de instalación de tfjs-node.

Crea un directorio llamado ./baseball para nuestra app de Node.js. Copia los archivos vinculados package.json y webpack.config.js en este directorio para configurar las dependencias del paquete npm (incluido el paquete npm @tensorflow/tfjs-node). Luego, ejecuta npm install para instalar las dependencias.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Ya está todo listo para que escribas código y entrenes un modelo.

4. Configura los datos de entrenamiento y prueba

Usarás los datos de entrenamiento y prueba como archivos CSV de los vínculos a continuación. Descarga y explora los datos de estos archivos:

pitch_type_training_data.csv

pitch_type_test_data.csv

Veamos algunos datos de entrenamiento de muestra:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Existen ocho funciones de entrada que describen los datos del sensor de tono:

  • velocidad de la pelota (vx0, vy0, vz0)
  • aceleración de la bola (ax, ay, az)
  • velocidad inicial del tono
  • independientemente de si el lanzador es zurdo o no

y una etiqueta de salida:

  • Pitch_code, que representa uno de los siete tipos de sugerencias disponibles: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

El objetivo es crear un modelo que pueda predecir el tipo de tono a partir de los datos del sensor de tono.

Antes de crear el modelo, debes preparar los datos de entrenamiento y prueba. Crea el archivo pitch_type.js en el archivo béisbol/ dir y copia en él el siguiente código. Este código carga datos de entrenamiento y prueba con la API de tf.data.csv. También normaliza los datos (lo que siempre se recomienda) con una escala de normalización de mínimos y máximos.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Crear un modelo para clasificar tipos de presentaciones

Ya está todo listo para crear el modelo. Usa la API de tf.layers para conectar las entradas (forma de [8] valores del sensor de tono) a 3 capas ocultas completamente conectadas que constan de unidades de activación ReLU, seguidas de una capa de salida softmax que consta de 7 unidades, cada una de las cuales representa uno de los tipos de tono de salida.

Entrena el modelo con el optimizador adam y la función de pérdida sparseCategoricalCrossentropy. Para obtener más información sobre estas opciones, consulta la guía de entrenamiento de modelos.

Agrega el siguiente código al final de pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Activa el entrenamiento desde el código del servidor principal que escribirás más tarde.

Para completar el módulo pitch_type.js, escribamos una función que evalúe el conjunto de datos de validación y prueba, prediga el tipo de presentación para una sola muestra y calcule las métricas de precisión. Agrega este código al final de pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Entrena el modelo en el servidor

Escribe el código del servidor para realizar el entrenamiento y la evaluación de modelos en un archivo nuevo llamado server.js. Primero, crea un servidor HTTP y abre una conexión de socket bidireccional con la API de socket.io. Luego, ejecuta el entrenamiento de modelos con la API model.fitDataset y evalúa la exactitud del modelo con el método pitch_type.evaluate() que escribiste antes. Entrena y evalúa 10 iteraciones e imprimir métricas en la consola.

Copia el siguiente código en server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

En este punto, ya tienes todo listo para ejecutar y probar el servidor. Deberías ver algo como esto, con el entrenamiento del servidor un ciclo de entrenamiento en cada iteración (también puedes usar la API de model.fitDataset para entrenar varios ciclos de entrenamiento con una llamada). Si encuentras algún error en este punto, revisa la instalación de tu nodo y npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Presiona Ctrl-C para detener el servidor en ejecución. Volveremos a ejecutarlo en el siguiente paso.

7. Crea una página de cliente y muestra el código

Ahora que el servidor está listo, el próximo paso es escribir el código del cliente que se ejecutará en el navegador. Crea una página simple para invocar la predicción de modelos en el servidor y mostrar el resultado. Utiliza socket.io para la comunicación entre el cliente y el servidor.

Primero, crea index.html en la carpeta béisbol/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Luego, crea un nuevo archivo client.js en la carpeta béisbol/ con el siguiente código:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

El cliente controla el mensaje del socket trainingComplete para mostrar un botón de predicción. Cuando se hace clic en este botón, el cliente envía un mensaje de socket con datos del sensor de muestra. Cuando se recibe un mensaje predictResult, se muestra la predicción en la página.

8. Ejecuta la app

Ejecuta el servidor y el cliente para ver la app completa en acción:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Abre la página del cliente en tu navegador ( http://localhost:8080). Cuando finalice el entrenamiento del modelo, haz clic en el botón Predecir muestra. Deberías ver un resultado de predicción en el navegador. Puedes modificar los datos del sensor de muestra con algunos ejemplos del archivo CSV de prueba y ver con qué precisión el modelo predice.

9. Qué aprendiste

En este codelab, implementaste una aplicación web de aprendizaje automático simple con TensorFlow.js. Entrenaste un modelo personalizado para clasificar tipos de lanzamiento de béisbol a partir de datos de sensores. Escribiste código de Node.js para ejecutar el entrenamiento en el servidor y llamar a la inferencia en el modelo entrenado con los datos enviados desde el cliente.

Asegúrate de visitar tensorflow.org/js para obtener más ejemplos y demostraciones con código para ver cómo puedes usar TensorFlow.js en tus aplicaciones.