Node.js Codelab'de TensorFlow.js Eğitimi

1. Giriş

Bu Codelab'de, JavaScript için güçlü ve esnek bir makine öğrenimi kitaplığı olan TensorFlow.js'yi kullanarak sunucu tarafında beyzbol atış türlerini eğitmek ve sınıflandırmak için bir Node.js web sunucusu oluşturmayı öğreneceksiniz. Bir modeli, atış sensörü verilerinden atış türünü tahmin edecek şekilde eğitmek ve web istemcisinden tahmin çağırmak için bir web uygulaması oluşturacaksınız. Bu Codelab'in tam olarak çalışan bir sürümü tfjs-examples GitHub deposunda mevcuttur.

Neler öğreneceksiniz?

  • Node.js ile kullanılmak üzere tensorflow.js npm paketini yükleme ve ayarlama.
  • Node.js ortamında eğitim ve test verilerine erişme
  • Node.js sunucusunda TensorFlow.js ile modeli eğitme
  • Eğitilmiş modeli, istemci/sunucu uygulamasında çıkarım için dağıtma

Haydi başlayalım!

2. Şartlar

Bu Codelab'i tamamlamak için ihtiyacınız olanlar:

  1. Chrome'un veya başka bir modern tarayıcının son sürümü
  2. Makinenizde yerel olarak çalışan bir metin düzenleyici ve komut terminali.
  3. HTML, CSS, JavaScript ve Chrome Geliştirici Araçları (veya tercih ettiğiniz tarayıcıların geliştirici araçları) hakkında bilgi sahibi olmak
  4. Nöral ağlar hakkında üst düzeyde kavramsal bilgi sahibi olmalısınız. Giriş veya hatırlatma için 3blue1brown'un bu videosunu ya da Ashi Krishnan'ın JavaScript'te Derin Öğrenme videosunu izleyebilirsiniz.

3. Node.js uygulaması oluşturma

Node.js ve npm'yi yükleyin. Desteklenen platformlar ve bağımlılıklar için lütfen tfjs-node yükleme kılavuzuna bakın.

Node.js uygulamamız için ./baseball adlı bir dizin oluşturun. npm paket bağımlılıklarını (@tensorflow/tfjs-node npm paketi dahil) yapılandırmak için bağlı package.json ve webpack.config.js dosyalarını bu dizine kopyalayın. Ardından, bağımlılıkları yüklemek için npm install komutunu çalıştırın.

$ cd baseball
$ ls
package.json  webpack.config.js
$ npm install
...
$ ls
node_modules  package.json  package-lock.json  webpack.config.js

Artık kod yazmaya ve model eğitmeye hazırsınız.

4. Eğitim ve test verilerini ayarlama

Aşağıdaki bağlantılardan eğitim ve test verilerini CSV dosyaları olarak kullanacaksınız. Aşağıdaki dosyalardaki verileri indirip inceleyin:

pitch_type_training_data.csv

pitch_type_test_data.csv

Şimdi bazı örnek eğitim verilerine bakalım:

vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0

Pitch sensörü verilerini açıklayan sekiz giriş özelliği vardır:

  • topun hızı (vx0, vy0, vz0)
  • top ivmesi (ax, ay, az)
  • topun başlangıç hızı
  • atıcı oyuncunun solak olup olmadığı

ve bir çıkış etiketi:

  • Yedi yayın türünden birini ifade eden pitch_code: Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball

Amaç, atış sensörü verileri verildiğinde atış türünü tahmin edebilen bir model oluşturmaktır.

Modeli oluşturmadan önce eğitim ve test verilerini hazırlamanız gerekir. baseball/ dizininde pitch_type.js dosyasını oluşturun ve aşağıdaki kodu bu dosyaya kopyalayın. Bu kod, tf.data.csv API'sini kullanarak eğitim ve test verilerini yükler. Ayrıca, minimum ve maksimum ölçeklendirme kullanarak verileri normalleştirir (bu her zaman önerilir).

const tf = require('@tensorflow/tfjs');

// util function to normalize a value between a given range.
function normalize(value, min, max) {
  if (min === undefined || max === undefined) {
    return value;
  }
  return (value - min) / (max - min);
}

// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH =    'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';

// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;

const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;

// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
    ({xs, ys}) => {
      const values = [
        normalize(xs.vx0, VX0_MIN, VX0_MAX),
        normalize(xs.vy0, VY0_MIN, VY0_MAX),
        normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
        normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
        normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
        xs.left_handed_pitcher
      ];
      return {xs: values, ys: ys.pitch_code};
    }

const trainingData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .shuffle(TRAINING_DATA_LENGTH)
        .batch(100);

// Load all training data in one batch to use for evaluation
const trainingValidationData =
    tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TRAINING_DATA_LENGTH);

// Load all test data in one batch to use for evaluation
const testValidationData =
    tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
        .map(csvTransform)
        .batch(TEST_DATA_LENGTH);

5. Atış türlerini sınıflandırmak için model oluşturma

Artık modeli oluşturmaya hazırsınız. Girişleri (8 şekilli perde sensörü değerleri) ReLU etkinleştirme birimlerinden oluşan 3 gizli tam bağlantılı katmana ve ardından her biri çıkış perde türlerinden birini temsil eden 7 birimden oluşan bir softmax çıkış katmanına bağlamak için tf.layers API'sini kullanın.

Modeli adam optimizer ve sparseCategoricalCrossentropy kayıp fonksiyonuyla eğitin. Bu seçenekler hakkında daha fazla bilgi için eğitim modelleri kılavuzuna bakın.

Aşağıdaki kodu pitch_type.js dosyasının sonuna ekleyin:

const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(),
  loss: 'sparseCategoricalCrossentropy',
  metrics: ['accuracy']
});

Eğitimi, daha sonra yazacağınız ana sunucu kodundan tetikleyin.

pitch_type.js modülünü tamamlamak için doğrulama ve test veri kümesini değerlendiren, tek bir örnek için atış türünü tahmin eden ve doğruluk metriklerini hesaplayan bir işlev yazalım. Bu kodu pitch_type.js dosyasının sonuna ekleyin:

// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
  let results = {};
  await trainingValidationData.forEachAsync(pitchTypeBatch => {
    const values = model.predict(pitchTypeBatch.xs).dataSync();
    const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
    for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
      results[pitchFromClassNum(i)] = {
        training: calcPitchClassEval(i, classSize, values)
      };
    }
  });

  if (useTestData) {
    await testValidationData.forEachAsync(pitchTypeBatch => {
      const values = model.predict(pitchTypeBatch.xs).dataSync();
      const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
      for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
        results[pitchFromClassNum(i)].validation =
            calcPitchClassEval(i, classSize, values);
      }
    });
  }
  return results;
}

async function predictSample(sample) {
  let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
  var maxValue = 0;
  var predictedPitch = 7;
  for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
    if (result[0][i] > maxValue) {
      predictedPitch = i;
      maxValue = result[0][i];
    }
  }
  return pitchFromClassNum(predictedPitch);
}

// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
  // Output has 7 different class values for each pitch, offset based on
  // which pitch class (ordered by i)
  let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
  let total = 0;
  for (let i = 0; i < classSize; i++) {
    total += values[index];
    index += NUM_PITCH_CLASSES;
  }
  return total / classSize;
}

// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
  switch (classNum) {
    case 0:
      return 'Fastball (2-seam)';
    case 1:
      return 'Fastball (4-seam)';
    case 2:
      return 'Fastball (sinker)';
    case 3:
      return 'Fastball (cutter)';
    case 4:
      return 'Slider';
    case 5:
      return 'Changeup';
    case 6:
      return 'Curveball';
    default:
      return 'Unknown';
  }
}

module.exports = {
  evaluate,
  model,
  pitchFromClassNum,
  predictSample,
  testValidationData,
  trainingData,
  TEST_DATA_LENGTH
}

6. Sunucuda modeli eğitme

Model eğitimi ve değerlendirmesi gerçekleştirmek için server.js adlı yeni bir dosyada sunucu kodunu yazın. Öncelikle bir HTTP sunucusu oluşturun ve socket.io API'sini kullanarak çift yönlü bir yuva bağlantısı açın. Ardından, model.fitDataset API'sini kullanarak model eğitimini gerçekleştirin ve daha önce yazdığınız pitch_type.evaluate() yöntemini kullanarak model doğruluğunu değerlendirin. 10 yineleme boyunca eğitip değerlendirin ve metrikleri konsola yazdırın.

Aşağıdaki kodu server.js dosyasına kopyalayın:

require('@tensorflow/tfjs-node');

const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');

const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;

// util function to sleep for a given ms
function sleep(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
  const port = process.env.PORT || PORT;
  const server = http.createServer();
  const io = socketio(server);

  server.listen(port, () => {
    console.log(`  > Running socket on port: ${port}`);
  });

  io.on('connection', (socket) => {
    socket.on('predictSample', async (sample) => {
      io.emit('predictResult', await pitch_type.predictSample(sample));
    });
  });

  let numTrainingIterations = 10;
  for (var i = 0; i < numTrainingIterations; i++) {
    console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
    await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
    console.log('accuracyPerClass', await pitch_type.evaluate(true));
    await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
  }

  io.emit('trainingComplete', true);
}

run();

Bu noktada sunucuyu çalıştırmaya ve test etmeye hazırsınız. Sunucunun her yinelemede bir dönem eğitim verdiği aşağıdaki gibi bir çıktı görmeniz gerekir (Tek bir çağrıyla birden fazla dönem eğitmek için model.fitDataset API'sini de kullanabilirsiniz). Bu noktada herhangi bir hatayla karşılaşırsanız lütfen düğüm ve npm yüklemenizi kontrol edin.

$ npm run start-server
...
  > Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49

Çalışan sunucuyu durdurmak için Ctrl-C yazın. Bu testi bir sonraki adımda tekrar çalıştıracağız.

7. İstemci sayfası oluşturma ve kodu görüntüleme

Sunucu hazır olduğuna göre, sonraki adım istemci kodunu yazmaktır. Bu kod, tarayıcıda çalışır. Sunucuda model tahmini çağırmak ve sonucu görüntülemek için basit bir sayfa oluşturun. Bu, istemci/sunucu iletişimi için socket.io'yu kullanır.

Öncelikle, baseball/ klasöründe index.html dosyasını oluşturun:

<!doctype html>
<html>
  <head>
    <title>Pitch Training Accuracy</title>
  </head>
  <body>
    <h3 id="waiting-msg">Waiting for server...</h3>
    <p>
    <span style="font-size:16px" id="trainingStatus"></span>
    <p>
    <div id="predictContainer" style="font-size:16px;display:none">
      Sensor data: <span id="predictSample"></span>
      <button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
      Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
    </div>
    <script src="dist/bundle.js"></script>
    <style>
      html,
      body {
        font-family: Roboto, sans-serif;
        color: #5f6368;
      }
      body {
        background-color: rgb(248, 249, 250);
      }
    </style>
  </body>
</html>

Ardından, baseball/ klasöründe aşağıdaki kodla yeni bir client.js dosyası oluşturun:

import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');

const socket =
    io('http://localhost:8001',
       {reconnectionDelay: 300, reconnectionDelayMax: 300});

const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball

predictButton.onclick = () => {
  predictButton.disabled = true;
  socket.emit('predictSample', testSample);
};

// functions to handle socket events
socket.on('connect', () => {
    document.getElementById('waiting-msg').style.display = 'none';
    document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});

socket.on('trainingComplete', () => {
  document.getElementById('trainingStatus').innerHTML = 'Training Complete';
  document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
  predictContainer.style.display = 'block';
});

socket.on('predictResult', (result) => {
  plotPredictResult(result);
});

socket.on('disconnect', () => {
  document.getElementById('trainingStatus').innerHTML = '';
  predictContainer.style.display = 'none';
  document.getElementById('waiting-msg').style.display = 'block';
});

function plotPredictResult(result) {
  predictButton.disabled = false;
  document.getElementById('predictResult').innerHTML = result;
  console.log(result);
}

İstemci, bir tahmin düğmesi görüntülemek için trainingComplete soket mesajını işler. Bu düğme tıklandığında istemci, örnek sensör verileri içeren bir soket mesajı gönderir. predictResult mesajı alındığında tahmini sayfada gösterir.

8. Uygulamayı çalıştırma

Uygulamanın tamamını çalışırken görmek için hem sunucuyu hem de istemciyi çalıştırın:

[In one terminal, run this first]
$ npm run start-client

[In another terminal, run this next]
$ npm run start-server

Tarayıcınızda istemci sayfasını açın ( http://localhost:8080). Model eğitimi tamamlandığında Örnek Tahmin Et düğmesini tıklayın. Tarayıcıda bir tahmin sonucu gösterilir. Örnek sensör verilerini test CSV dosyasındaki bazı örneklerle değiştirebilir ve modelin ne kadar doğru tahmin yaptığını görebilirsiniz.

9. Öğrendikleriniz

Bu Codelab'de, TensorFlow.js kullanarak basit bir makine öğrenimi web uygulaması uyguladınız. Beyzbol atış türlerini sensör verilerinden sınıflandırmak için özel bir model eğittiniz. Sunucuda eğitimi yürütmek ve istemciden gönderilen verileri kullanarak eğitilmiş modelde çıkarım çağırmak için Node.js kodu yazdınız.

TensorFlow.js'yi uygulamalarınızda nasıl kullanabileceğinizi görmek için daha fazla örnek ve kod içeren demolar için tensorflow.org/js adresini ziyaret edin.