Treinamento do TensorFlow.js no codelab Node.js

1. Introdução

Neste codelab, você vai aprender a criar um servidor da Web Node.js para treinar e classificar tipos de arremessos de beisebol do lado do servidor usando o TensorFlow.js, uma biblioteca de machine learning avançada e flexível para JavaScript. Você vai criar um aplicativo da Web para treinar um modelo que prevê o tipo de arremesso com base nos dados do sensor de arremesso e invocar a previsão de um cliente da Web. Uma versão totalmente funcional deste codelab está presente no repositório do GitHub tfjs-examples (link em inglês).

O que você vai aprender

  • Como instalar e configurar o pacote npm do tensorflow.js para uso com o Node.js.
  • Como acessar dados de treinamento e teste no ambiente Node.js.
  • Como treinar um modelo com o TensorFlow.js em um servidor Node.js.
  • Como implantar o modelo treinado para inferência em um aplicativo cliente/servidor.

Então, vamos começar!

2. Requisitos

Para concluir este codelab, você vai precisar de:

  1. Uma versão moderna do Chrome ou de outro navegador mais recente
  2. Um editor de texto e um terminal de comando executados localmente na sua máquina.
  3. Conhecimentos sobre HTML, CSS, JavaScript e Chrome DevTools (ou as DevTools do seu navegador preferido)
  4. Conhecimento conceitual de alto nível sobre redes neurais. Se você precisar de uma introdução ou revisão, assista a este vídeo da 3blue1brown ou este vídeo sobre aprendizado profundo em JavaScript de Ashi Krishnan (links em inglês)

3. Configurar um app Node.js

Instale o Node.js e o npm. Para saber mais sobre plataformas e dependências compatíveis, consulte o guia de instalação do tfjs-node.

Crie um diretório chamado ./baseball para nosso app Node.js. Copie os arquivos package.json e webpack.config.js vinculados para esse diretório e configure as dependências do pacote npm, incluindo o pacote @tensorflow/tfjs-node. Em seguida, execute npm install para instalar as dependências.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Agora você já pode escrever um código e treinar um modelo.

4. Configurar os dados de treinamento e teste

Você vai usar os dados de treinamento e teste como arquivos CSV dos links abaixo. Faça o download e analise os dados nestes arquivos:

pitch_type_training_data.csv

pitch_type_test_data.csv

Confira alguns exemplos de dados de treinamento:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Há oito atributos de entrada que descrevem os dados do sensor de inclinação:

  • velocidade da bola (vx0, vy0, vz0)
  • aceleração da bola (ax, ay, az)
  • velocidade inicial da apresentação
  • se o arremessador é canhoto ou não

e um rótulo de saída:

  • pitch_code que significa um dos sete tipos de tom: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

O objetivo é criar um modelo capaz de prever o tipo de arremesso com base nos dados do sensor.

Antes de criar o modelo, é necessário preparar os dados de treinamento e teste. Crie o arquivo pitch_type.js no diretório baseball/ e copie o código a seguir nele. Esse código carrega dados de treinamento e teste usando a API tf.data.csv. Ele também normaliza os dados (o que é sempre recomendado) usando uma escala de normalização mín-máx.

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Criar um modelo para classificar tipos de arremesso

Agora você está pronto para criar o modelo. Use a API tf.layers para conectar as entradas (formato de valores do sensor de afinação [8]) a três camadas ocultas totalmente conectadas que consistem em unidades de ativação ReLU, seguidas por uma camada de saída softmax com sete unidades, cada uma representando um dos tipos de afinação de saída.

Treine o modelo com o otimizador Adam e a função de perda sparseCategoricalCrossentropy. Para mais informações sobre essas opções, consulte o guia de treinamento de modelos.

Adicione o seguinte código ao final de pitch_type.js:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Acione o treinamento com o código do servidor principal que você vai escrever mais tarde.

Para concluir o módulo pitch_type.js, vamos escrever uma função para avaliar o conjunto de dados de validação e teste, prever um tipo de arremesso para uma única amostra e calcular métricas de acurácia. Adicione este código ao final de pitch_type.js:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Treinar o modelo no servidor

Escreva o código do servidor para realizar o treinamento de modelo e a avaliação em um novo arquivo chamado server.js. Primeiro, crie um servidor HTTP e abra uma conexão de soquete bidirecional usando a API socket.io. Em seguida, execute o treinamento de modelo usando a API model.fitDataset e avalie a acurácia do modelo usando o método pitch_type.evaluate() que você escreveu anteriormente. Treine e avalie por 10 iterações, imprimindo métricas no console.

Copie o código abaixo para server.js:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Neste ponto, você já pode executar e testar o servidor. Você vai ver algo assim, com o servidor treinando uma época em cada iteração. Também é possível usar a API model.fitDataset para treinar várias épocas com uma chamada. Se você encontrar erros neste ponto, verifique a instalação do nó e do npm.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Digite Ctrl-C para interromper o servidor em execução. Vamos executar de novo na próxima etapa.

7. Criar página do cliente e mostrar código

Agora que o servidor está pronto, a próxima etapa é escrever o código do cliente que será executado no navegador. Crie uma página simples para invocar a previsão do modelo no servidor e mostrar o resultado. Isso usa socket.io para comunicação cliente/servidor.

Primeiro, crie index.html na pasta baseball/:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Em seguida, crie um arquivo client.js na pasta baseball/ com o código abaixo:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

O cliente processa a mensagem de soquete trainingComplete para mostrar um botão de previsão. Quando esse botão é clicado, o cliente envia uma mensagem de soquete com dados do sensor de amostra. Ao receber uma mensagem predictResult, ele mostra a previsão na página.

8. Executar o app

Execute o servidor e o cliente para ver o app completo em ação:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Abra a página do cliente no navegador ( http://localhost:8080). Quando o treinamento de modelo terminar, clique no botão Prever amostra. Um resultado de previsão vai aparecer no navegador. Modifique os dados do sensor de amostra com alguns exemplos do arquivo CSV de teste e veja a precisão da previsão do modelo.

9. O que você aprendeu

Neste codelab, você implementou um app da Web de machine learning simples usando o TensorFlow.js. Você treinou um modelo personalizado para classificar tipos de arremessos de beisebol com base em dados do sensor. Você escreveu código Node.js para executar o treinamento no servidor e chamar a inferência no modelo treinado usando dados enviados do cliente.

Acesse tensorflow.org/js para ver mais exemplos e demonstrações com código para saber como você pode usar o TensorFlow.js nos seus aplicativos.