1. Introduction
Dans cet atelier de programmation, vous allez apprendre à créer un serveur Web Node.js pour entraîner et classifier des types de lancers de baseball côté serveur à l'aide de TensorFlow.js, une bibliothèque de machine learning puissante et flexible pour JavaScript. Vous allez créer une application Web pour entraîner un modèle à prédire le type de lancer à partir des données du capteur de lancer et à appeler la prédiction à partir d'un client Web. Une version entièrement fonctionnelle de cet atelier de programmation est disponible dans le dépôt GitHub tfjs-examples.
Points abordés
- Découvrez comment installer et configurer le package npm tensorflow.js pour l'utiliser avec Node.js.
- Comment accéder aux données d'entraînement et de test dans l'environnement Node.js.
- Comment entraîner un modèle avec TensorFlow.js dans un serveur Node.js.
- Découvrez comment déployer le modèle entraîné pour l'inférence dans une application client/serveur.
C'est parti !
2. Conditions requises
Pour suivre cet atelier de programmation, vous aurez besoin :
- Une version récente de Chrome ou d'un autre navigateur moderne.
- Un éditeur de texte et un terminal de commande s'exécutant localement sur votre ordinateur.
- Bonne connaissance des langages HTML, CSS et JavaScript ainsi que des Outils pour les développeurs Chrome (ou des outils de développement de votre navigateur préféré)
- Bonne compréhension du concept de réseau de neurones Si vous avez besoin d'une présentation ou d'un rappel, vous pouvez regarder cette vidéo de 3blue1brown ou cette vidéo sur le deep learning en JavaScript d'Ashi Krishnan.
3. Configurer une application Node.js
Installez Node.js et npm. Pour connaître les plates-formes et les dépendances compatibles, veuillez consulter le guide d'installation de tfjs-node.
Créez un répertoire appelé ./baseball pour notre application Node.js. Copiez les fichiers package.json et webpack.config.js liés dans ce répertoire pour configurer les dépendances du package npm (y compris le package npm @tensorflow/tfjs-node). Exécutez ensuite npm install pour installer les dépendances.
$ cd baseball
$ ls
package.json webpack.config.js
$ npm install
...
$ ls
node_modules package.json package-lock.json webpack.config.js
Vous êtes maintenant prêt à écrire du code et à entraîner un modèle.
4. Configurer les données d'entraînement et de test
Vous utiliserez les données d'entraînement et de test sous forme de fichiers CSV à partir des liens ci-dessous. Téléchargez et explorez les données de ces fichiers :
Examinons quelques exemples de données d'entraînement :
vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0
Huit caractéristiques d'entrée décrivent les données du capteur de tangage :
- vitesse de la balle (vx0, vy0, vz0)
- Accélération de la balle (ax, ay, az)
- vitesse initiale du lancer
- si le lanceur est gaucher ou non.
et un libellé de sortie :
- pitch_code qui désigne l'un des sept types de lancers :
Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball
L'objectif est de créer un modèle capable de prédire le type de lancer à partir des données du capteur de lancer.
Avant de créer le modèle, vous devez préparer les données d'entraînement et de test. Créez le fichier pitch_type.js dans le répertoire baseball/ et copiez-y le code suivant. Ce code charge les données d'entraînement et de test à l'aide de l'API tf.data.csv. Il normalise également les données (ce qui est toujours recommandé) à l'aide d'une échelle de normalisation min-max.
const tf = require('@tensorflow/tfjs');
// util function to normalize a value between a given range.
function normalize(value, min, max) {
if (min === undefined || max === undefined) {
return value;
}
return (value - min) / (max - min);
}
// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH = 'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';
// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;
const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;
// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
({xs, ys}) => {
const values = [
normalize(xs.vx0, VX0_MIN, VX0_MAX),
normalize(xs.vy0, VY0_MIN, VY0_MAX),
normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
xs.left_handed_pitcher
];
return {xs: values, ys: ys.pitch_code};
}
const trainingData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.shuffle(TRAINING_DATA_LENGTH)
.batch(100);
// Load all training data in one batch to use for evaluation
const trainingValidationData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TRAINING_DATA_LENGTH);
// Load all test data in one batch to use for evaluation
const testValidationData =
tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TEST_DATA_LENGTH);
5. Créer un modèle pour classer les types de lancers
Vous êtes maintenant prêt à créer le modèle. Utilisez l'API tf.layers pour connecter les entrées (forme des valeurs du capteur de hauteur [8]) à trois couches cachées entièrement connectées composées d'unités d'activation ReLU, suivies d'une couche de sortie softmax composée de sept unités, chacune représentant l'un des types de hauteur de sortie.
Entraînez le modèle avec l'optimiseur Adam et la fonction de perte sparseCategoricalCrossentropy. Pour en savoir plus sur ces choix, consultez le guide d'entraînement des modèles.
Ajoutez le code suivant à la fin de pitch_type.js :
const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(),
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy']
});
Déclenchez l'entraînement à partir du code du serveur principal que vous écrirez plus tard.
Pour terminer le module pitch_type.js, écrivons une fonction permettant d'évaluer l'ensemble de données de validation et de test, de prédire un type de lancer pour un seul échantillon et de calculer les métriques de précision. Ajoutez ce code à la fin de pitch_type.js :
// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
let results = {};
await trainingValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)] = {
training: calcPitchClassEval(i, classSize, values)
};
}
});
if (useTestData) {
await testValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)].validation =
calcPitchClassEval(i, classSize, values);
}
});
}
return results;
}
async function predictSample(sample) {
let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
var maxValue = 0;
var predictedPitch = 7;
for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
if (result[0][i] > maxValue) {
predictedPitch = i;
maxValue = result[0][i];
}
}
return pitchFromClassNum(predictedPitch);
}
// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
// Output has 7 different class values for each pitch, offset based on
// which pitch class (ordered by i)
let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
let total = 0;
for (let i = 0; i < classSize; i++) {
total += values[index];
index += NUM_PITCH_CLASSES;
}
return total / classSize;
}
// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
switch (classNum) {
case 0:
return 'Fastball (2-seam)';
case 1:
return 'Fastball (4-seam)';
case 2:
return 'Fastball (sinker)';
case 3:
return 'Fastball (cutter)';
case 4:
return 'Slider';
case 5:
return 'Changeup';
case 6:
return 'Curveball';
default:
return 'Unknown';
}
}
module.exports = {
evaluate,
model,
pitchFromClassNum,
predictSample,
testValidationData,
trainingData,
TEST_DATA_LENGTH
}
6. Entraîner le modèle sur le serveur
Écrivez le code du serveur pour effectuer l'entraînement et l'évaluation du modèle dans un nouveau fichier appelé server.js. Commencez par créer un serveur HTTP et ouvrez une connexion de socket bidirectionnelle à l'aide de l'API socket.io. Exécutez ensuite l'entraînement du modèle à l'aide de l'API model.fitDataset et évaluez la précision du modèle à l'aide de la méthode pitch_type.evaluate() que vous avez écrite précédemment. Entraînez et évaluez le modèle pendant 10 itérations, en imprimant les métriques dans la console.
Copiez le code ci-dessous dans server.js :
require('@tensorflow/tfjs-node');
const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');
const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;
// util function to sleep for a given ms
function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
const port = process.env.PORT || PORT;
const server = http.createServer();
const io = socketio(server);
server.listen(port, () => {
console.log(` > Running socket on port: ${port}`);
});
io.on('connection', (socket) => {
socket.on('predictSample', async (sample) => {
io.emit('predictResult', await pitch_type.predictSample(sample));
});
});
let numTrainingIterations = 10;
for (var i = 0; i < numTrainingIterations; i++) {
console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
console.log('accuracyPerClass', await pitch_type.evaluate(true));
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
}
io.emit('trainingComplete', true);
}
run();
À ce stade, vous êtes prêt à exécuter et à tester le serveur. Vous devriez obtenir un résultat semblable à celui-ci, avec le serveur entraînant une époque à chaque itération (vous pouvez également utiliser l'API model.fitDataset pour entraîner plusieurs époques en un seul appel). Si vous rencontrez des erreurs à ce stade, veuillez vérifier l'installation de Node.js et de npm.
$ npm run start-server
...
> Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49
Appuyez sur Ctrl+C pour arrêter le serveur en cours d'exécution. Nous l'exécuterons à nouveau à l'étape suivante.
7. Créer une page client et afficher le code
Maintenant que le serveur est prêt, l'étape suivante consiste à écrire le code client qui s'exécute dans le navigateur. Créez une page simple pour appeler la prédiction du modèle sur le serveur et afficher le résultat. Il utilise socket.io pour la communication client/serveur.
Tout d'abord, créez index.html dans le dossier baseball/ :
<!doctype html>
<html>
<head>
<title>Pitch Training Accuracy</title>
</head>
<body>
<h3 id="waiting-msg">Waiting for server...</h3>
<p>
<span style="font-size:16px" id="trainingStatus"></span>
<p>
<div id="predictContainer" style="font-size:16px;display:none">
Sensor data: <span id="predictSample"></span>
<button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
</div>
<script src="dist/bundle.js"></script>
<style>
html,
body {
font-family: Roboto, sans-serif;
color: #5f6368;
}
body {
background-color: rgb(248, 249, 250);
}
</style>
</body>
</html>
Créez ensuite un fichier client.js dans le dossier baseball/ avec le code ci-dessous :
import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');
const socket =
io('http://localhost:8001',
{reconnectionDelay: 300, reconnectionDelayMax: 300});
const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball
predictButton.onclick = () => {
predictButton.disabled = true;
socket.emit('predictSample', testSample);
};
// functions to handle socket events
socket.on('connect', () => {
document.getElementById('waiting-msg').style.display = 'none';
document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});
socket.on('trainingComplete', () => {
document.getElementById('trainingStatus').innerHTML = 'Training Complete';
document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
predictContainer.style.display = 'block';
});
socket.on('predictResult', (result) => {
plotPredictResult(result);
});
socket.on('disconnect', () => {
document.getElementById('trainingStatus').innerHTML = '';
predictContainer.style.display = 'none';
document.getElementById('waiting-msg').style.display = 'block';
});
function plotPredictResult(result) {
predictButton.disabled = false;
document.getElementById('predictResult').innerHTML = result;
console.log(result);
}
Le client gère le message de socket trainingComplete pour afficher un bouton de prédiction. Lorsque l'utilisateur clique sur ce bouton, le client envoie un message de socket avec des exemples de données de capteur. Lorsqu'il reçoit un message predictResult, il affiche la prédiction sur la page.
8. Exécuter l'application
Exécutez le serveur et le client pour voir l'application complète en action :
[In one terminal, run this first]
$ npm run start-client
[In another terminal, run this next]
$ npm run start-server
Ouvrez la page client dans votre navigateur ( http://localhost:8080). Une fois l'entraînement du modèle terminé, cliquez sur le bouton Prédire un échantillon. Un résultat de prédiction devrait s'afficher dans le navigateur. N'hésitez pas à modifier les exemples de données de capteur avec des exemples tirés du fichier CSV de test pour voir la précision des prédictions du modèle.
9. Ce que vous avez appris
Dans cet atelier de programmation, vous avez implémenté une application Web simple de machine learning à l'aide de TensorFlow.js. Vous avez entraîné un modèle personnalisé pour classer les types de lancers de baseball à partir de données de capteurs. Vous avez écrit du code Node.js pour exécuter l'entraînement sur le serveur et appeler l'inférence sur le modèle entraîné à l'aide des données envoyées par le client.
Rendez-vous sur tensorflow.org/js pour voir d'autres exemples et démonstrations de codage, et apprendre à utiliser TensorFlow.js dans vos applications.