1. Giới thiệu
Trong Lớp học lập trình này, bạn sẽ tìm hiểu cách tạo một máy chủ web Node.js để huấn luyện và phân loại các loại cú ném bóng chày ở phía máy chủ bằng TensorFlow.js, một thư viện học máy mạnh mẽ và linh hoạt cho JavaScript. Bạn sẽ tạo một ứng dụng web để huấn luyện mô hình dự đoán loại cú ném từ dữ liệu cảm biến cú ném và để gọi dự đoán từ một ứng dụng web. Một phiên bản hoạt động đầy đủ của Lớp học lập trình này có trong kho lưu trữ tfjs-examples trên GitHub.
Kiến thức bạn sẽ học được
- Cách cài đặt và thiết lập gói npm tensorflow.js để sử dụng với Node.js.
- Cách truy cập vào dữ liệu huấn luyện và dữ liệu kiểm thử trong môi trường Node.js.
- Cách huấn luyện một mô hình bằng TensorFlow.js trong máy chủ Node.js.
- Cách triển khai mô hình đã huấn luyện để suy luận trong ứng dụng máy khách/máy chủ.
Vậy thì hãy cùng bắt đầu!
2. Yêu cầu
Để hoàn tất Lớp học lập trình này, bạn cần:
- Phiên bản mới nhất của Chrome hoặc một trình duyệt hiện đại khác.
- Một trình chỉnh sửa văn bản và thiết bị đầu cuối lệnh chạy cục bộ trên máy của bạn.
- Có kiến thức về HTML, CSS, JavaScript và Chrome DevTools (hoặc công cụ cho nhà phát triển của trình duyệt mà bạn muốn dùng).
- Hiểu biết khái niệm cấp cao về mạng nơron. Nếu bạn cần tìm hiểu hoặc ôn lại, hãy cân nhắc xem video này của 3blue1brown hoặc video này về Học sâu bằng JavaScript của Ashi Krishnan.
3. Thiết lập ứng dụng Node.js
Cài đặt Node.js và npm. Để biết các nền tảng và phần phụ thuộc được hỗ trợ, vui lòng xem hướng dẫn cài đặt tfjs-node.
Tạo một thư mục có tên ./baseball cho ứng dụng Node.js của chúng ta. Sao chép package.json và webpack.config.js được liên kết vào thư mục này để định cấu hình các phần phụ thuộc của gói npm (bao gồm cả gói npm @tensorflow/tfjs-node). Sau đó, hãy chạy lệnh npm install để cài đặt các phần phụ thuộc.
$ cd baseball
$ ls
package.json webpack.config.js
$ npm install
...
$ ls
node_modules package.json package-lock.json webpack.config.js
Giờ đây, bạn đã sẵn sàng viết mã và huấn luyện một mô hình!
4. Thiết lập dữ liệu huấn luyện và kiểm thử
Bạn sẽ sử dụng dữ liệu huấn luyện và kiểm thử dưới dạng tệp CSV từ các đường liên kết bên dưới. Tải xuống và khám phá dữ liệu trong các tệp này:
Hãy xem xét một số dữ liệu huấn luyện mẫu:
vx0,vy0,vz0,ax,ay,az,start_speed,left_handed_pitcher,pitch_code
7.69914900671662,-132.225686405648,-6.58357157666866,-22.5082591074995,28.3119270826735,-16.5850095967027,91.1,0,0
6.68052308575228,-134.215511616881,-6.35565979491619,-19.6602769147989,26.7031848314466,-14.3430602022656,92.4,0,0
2.56546504690782,-135.398673977074,-2.91657310799559,-14.7849950586111,27.8083916890792,-21.5737737390901,93.1,0,0
Có 8 đặc điểm đầu vào mô tả dữ liệu cảm biến độ cao:
- vận tốc của quả bóng (vx0, vy0, vz0)
- gia tốc của bóng (ax, ay, az)
- tốc độ ban đầu của cú ném
- cho biết người ném bóng thuận tay trái hay không
và một nhãn đầu ra:
- pitch_code cho biết một trong 7 loại cú ném:
Fastball (2-seam), Fastball (4-seam), Fastball (sinker), Fastball (cutter), Slider, Changeup, Curveball
Mục tiêu là xây dựng một mô hình có thể dự đoán loại cú ném dựa trên dữ liệu cảm biến cú ném.
Trước khi tạo mô hình, bạn cần chuẩn bị dữ liệu huấn luyện và dữ liệu kiểm thử. Tạo tệp pitch_type.js trong thư mục baseball/ rồi sao chép đoạn mã sau vào tệp đó. Đoạn mã này tải dữ liệu huấn luyện và kiểm thử bằng API tf.data.csv. Thư viện này cũng chuẩn hoá dữ liệu (bạn nên làm như vậy) bằng cách sử dụng thang chuẩn hoá tối thiểu – tối đa.
const tf = require('@tensorflow/tfjs');
// util function to normalize a value between a given range.
function normalize(value, min, max) {
if (min === undefined || max === undefined) {
return value;
}
return (value - min) / (max - min);
}
// data can be loaded from URLs or local file paths when running in Node.js.
const TRAIN_DATA_PATH =
'https://storage.googleapis.com/mlb-pitch-data/pitch_type_training_data.csv';
const TEST_DATA_PATH = 'https://storage.googleapis.com/mlb-pitch-data/pitch_type_test_data.csv';
// Constants from training data
const VX0_MIN = -18.885;
const VX0_MAX = 18.065;
const VY0_MIN = -152.463;
const VY0_MAX = -86.374;
const VZ0_MIN = -15.5146078412997;
const VZ0_MAX = 9.974;
const AX_MIN = -48.0287647107959;
const AX_MAX = 30.592;
const AY_MIN = 9.397;
const AY_MAX = 49.18;
const AZ_MIN = -49.339;
const AZ_MAX = 2.95522851438373;
const START_SPEED_MIN = 59;
const START_SPEED_MAX = 104.4;
const NUM_PITCH_CLASSES = 7;
const TRAINING_DATA_LENGTH = 7000;
const TEST_DATA_LENGTH = 700;
// Converts a row from the CSV into features and labels.
// Each feature field is normalized within training data constants
const csvTransform =
({xs, ys}) => {
const values = [
normalize(xs.vx0, VX0_MIN, VX0_MAX),
normalize(xs.vy0, VY0_MIN, VY0_MAX),
normalize(xs.vz0, VZ0_MIN, VZ0_MAX), normalize(xs.ax, AX_MIN, AX_MAX),
normalize(xs.ay, AY_MIN, AY_MAX), normalize(xs.az, AZ_MIN, AZ_MAX),
normalize(xs.start_speed, START_SPEED_MIN, START_SPEED_MAX),
xs.left_handed_pitcher
];
return {xs: values, ys: ys.pitch_code};
}
const trainingData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.shuffle(TRAINING_DATA_LENGTH)
.batch(100);
// Load all training data in one batch to use for evaluation
const trainingValidationData =
tf.data.csv(TRAIN_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TRAINING_DATA_LENGTH);
// Load all test data in one batch to use for evaluation
const testValidationData =
tf.data.csv(TEST_DATA_PATH, {columnConfigs: {pitch_code: {isLabel: true}}})
.map(csvTransform)
.batch(TEST_DATA_LENGTH);
5. Tạo mô hình để phân loại các kiểu ném bóng
Giờ đây, bạn đã sẵn sàng xây dựng mô hình. Sử dụng API tf.layers để kết nối các đầu vào (hình dạng của 8 giá trị cảm biến độ cao) với 3 lớp ẩn được kết nối hoàn toàn bao gồm các đơn vị kích hoạt ReLU, theo sau là một lớp đầu ra softmax bao gồm 7 đơn vị, mỗi đơn vị đại diện cho một trong các loại độ cao đầu ra.
Huấn luyện mô hình bằng trình tối ưu hoá adam và hàm tổn thất sparseCategoricalCrossentropy. Để biết thêm thông tin về các lựa chọn này, hãy tham khảo hướng dẫn về mô hình huấn luyện.
Thêm đoạn mã sau vào cuối pitch_type.js:
const model = tf.sequential();
model.add(tf.layers.dense({units: 250, activation: 'relu', inputShape: [8]}));
model.add(tf.layers.dense({units: 175, activation: 'relu'}));
model.add(tf.layers.dense({units: 150, activation: 'relu'}));
model.add(tf.layers.dense({units: NUM_PITCH_CLASSES, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(),
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy']
});
Kích hoạt quá trình huấn luyện từ mã máy chủ chính mà bạn sẽ viết sau.
Để hoàn tất mô-đun pitch_type.js, hãy viết một hàm để đánh giá tập dữ liệu xác thực và kiểm thử, dự đoán loại cú ném cho một mẫu duy nhất và tính toán các chỉ số về độ chính xác. Thêm mã này vào cuối pitch_type.js:
// Returns pitch class evaluation percentages for training data
// with an option to include test data
async function evaluate(useTestData) {
let results = {};
await trainingValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TRAINING_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)] = {
training: calcPitchClassEval(i, classSize, values)
};
}
});
if (useTestData) {
await testValidationData.forEachAsync(pitchTypeBatch => {
const values = model.predict(pitchTypeBatch.xs).dataSync();
const classSize = TEST_DATA_LENGTH / NUM_PITCH_CLASSES;
for (let i = 0; i < NUM_PITCH_CLASSES; i++) {
results[pitchFromClassNum(i)].validation =
calcPitchClassEval(i, classSize, values);
}
});
}
return results;
}
async function predictSample(sample) {
let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
var maxValue = 0;
var predictedPitch = 7;
for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
if (result[0][i] > maxValue) {
predictedPitch = i;
maxValue = result[0][i];
}
}
return pitchFromClassNum(predictedPitch);
}
// Determines accuracy evaluation for a given pitch class by index
function calcPitchClassEval(pitchIndex, classSize, values) {
// Output has 7 different class values for each pitch, offset based on
// which pitch class (ordered by i)
let index = (pitchIndex * classSize * NUM_PITCH_CLASSES) + pitchIndex;
let total = 0;
for (let i = 0; i < classSize; i++) {
total += values[index];
index += NUM_PITCH_CLASSES;
}
return total / classSize;
}
// Returns the string value for Baseball pitch labels
function pitchFromClassNum(classNum) {
switch (classNum) {
case 0:
return 'Fastball (2-seam)';
case 1:
return 'Fastball (4-seam)';
case 2:
return 'Fastball (sinker)';
case 3:
return 'Fastball (cutter)';
case 4:
return 'Slider';
case 5:
return 'Changeup';
case 6:
return 'Curveball';
default:
return 'Unknown';
}
}
module.exports = {
evaluate,
model,
pitchFromClassNum,
predictSample,
testValidationData,
trainingData,
TEST_DATA_LENGTH
}
6. Đào tạo mô hình trên máy chủ
Viết mã máy chủ để thực hiện quá trình huấn luyện mô hình và đánh giá mô hình trong một tệp mới có tên là server.js. Trước tiên, hãy tạo một máy chủ HTTP và mở kết nối socket hai chiều bằng API socket.io. Sau đó, hãy thực hiện quá trình huấn luyện mô hình bằng cách sử dụng API model.fitDataset và đánh giá độ chính xác của mô hình bằng phương thức pitch_type.evaluate() mà bạn đã viết trước đó. Huấn luyện và đánh giá trong 10 lần lặp lại, in các chỉ số vào bảng điều khiển.
Sao chép mã bên dưới vào server.js:
require('@tensorflow/tfjs-node');
const http = require('http');
const socketio = require('socket.io');
const pitch_type = require('./pitch_type');
const TIMEOUT_BETWEEN_EPOCHS_MS = 500;
const PORT = 8001;
// util function to sleep for a given ms
function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
// Main function to start server, perform model training, and emit stats via the socket connection
async function run() {
const port = process.env.PORT || PORT;
const server = http.createServer();
const io = socketio(server);
server.listen(port, () => {
console.log(` > Running socket on port: ${port}`);
});
io.on('connection', (socket) => {
socket.on('predictSample', async (sample) => {
io.emit('predictResult', await pitch_type.predictSample(sample));
});
});
let numTrainingIterations = 10;
for (var i = 0; i < numTrainingIterations; i++) {
console.log(`Training iteration : ${i+1} / ${numTrainingIterations}`);
await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
console.log('accuracyPerClass', await pitch_type.evaluate(true));
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
}
io.emit('trainingComplete', true);
}
run();
Đến đây, bạn đã sẵn sàng chạy và kiểm thử máy chủ! Bạn sẽ thấy nội dung tương tự như sau, trong đó máy chủ huấn luyện một số lượng lớn dữ liệu trong mỗi lần lặp lại (bạn cũng có thể sử dụng API model.fitDataset để huấn luyện nhiều số lượng lớn dữ liệu bằng một lệnh gọi). Nếu bạn gặp lỗi tại thời điểm này, vui lòng kiểm tra quá trình cài đặt node và npm.
$ npm run start-server
...
> Running socket on port: 8001
Epoch 1 / 1
eta=0.0 ========================================================================================================>
2432ms 34741us/step - acc=0.429 loss=1.49
Nhập Ctrl-C để dừng máy chủ đang chạy. Chúng ta sẽ chạy lại trong bước tiếp theo.
7. Tạo trang khách hàng và mã hiển thị
Bây giờ, khi máy chủ đã sẵn sàng, bước tiếp theo là viết mã ứng dụng khách và mã đó sẽ chạy trong trình duyệt. Tạo một trang đơn giản để gọi hoạt động dự đoán mô hình trên máy chủ và hiển thị kết quả. Thao tác này sử dụng socket.io để giao tiếp giữa máy khách và máy chủ.
Trước tiên, hãy tạo index.html trong thư mục baseball/:
<!doctype html>
<html>
<head>
<title>Pitch Training Accuracy</title>
</head>
<body>
<h3 id="waiting-msg">Waiting for server...</h3>
<p>
<span style="font-size:16px" id="trainingStatus"></span>
<p>
<div id="predictContainer" style="font-size:16px;display:none">
Sensor data: <span id="predictSample"></span>
<button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button><p>
Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
</div>
<script src="dist/bundle.js"></script>
<style>
html,
body {
font-family: Roboto, sans-serif;
color: #5f6368;
}
body {
background-color: rgb(248, 249, 250);
}
</style>
</body>
</html>
Sau đó, hãy tạo một tệp client.js mới trong thư mục baseball/ bằng mã bên dưới:
import io from 'socket.io-client';
const predictContainer = document.getElementById('predictContainer');
const predictButton = document.getElementById('predict-button');
const socket =
io('http://localhost:8001',
{reconnectionDelay: 300, reconnectionDelayMax: 300});
const testSample = [2.668,-114.333,-1.908,4.786,25.707,-45.21,78,0]; // Curveball
predictButton.onclick = () => {
predictButton.disabled = true;
socket.emit('predictSample', testSample);
};
// functions to handle socket events
socket.on('connect', () => {
document.getElementById('waiting-msg').style.display = 'none';
document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
});
socket.on('trainingComplete', () => {
document.getElementById('trainingStatus').innerHTML = 'Training Complete';
document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
predictContainer.style.display = 'block';
});
socket.on('predictResult', (result) => {
plotPredictResult(result);
});
socket.on('disconnect', () => {
document.getElementById('trainingStatus').innerHTML = '';
predictContainer.style.display = 'none';
document.getElementById('waiting-msg').style.display = 'block';
});
function plotPredictResult(result) {
predictButton.disabled = false;
document.getElementById('predictResult').innerHTML = result;
console.log(result);
}
Ứng dụng xử lý thông báo socket trainingComplete để hiển thị nút dự đoán. Khi người dùng nhấp vào nút này, ứng dụng sẽ gửi một thông báo socket kèm theo dữ liệu mẫu của cảm biến. Sau khi nhận được thông báo predictResult, thông báo này sẽ hiển thị thông tin dự đoán trên trang.
8. Chạy ứng dụng
Chạy cả máy chủ và ứng dụng để xem toàn bộ ứng dụng hoạt động:
[In one terminal, run this first]
$ npm run start-client
[In another terminal, run this next]
$ npm run start-server
Mở trang ứng dụng trong trình duyệt ( http://localhost:8080). Khi quá trình huấn luyện mô hình kết thúc, hãy nhấp vào nút Dự đoán mẫu. Bạn sẽ thấy kết quả dự đoán xuất hiện trong trình duyệt. Bạn có thể sửa đổi dữ liệu cảm biến mẫu bằng một số ví dụ trong tệp CSV kiểm thử và xem mô hình dự đoán chính xác đến mức nào.
9. Kiến thức bạn học được
Trong Lớp học lập trình này, bạn đã triển khai một ứng dụng web học máy đơn giản bằng TensorFlow.js. Bạn đã huấn luyện một mô hình tuỳ chỉnh để phân loại các loại cú ném bóng chày dựa trên dữ liệu cảm biến. Bạn đã viết mã Node.js để thực thi quá trình huấn luyện trên máy chủ và gọi suy luận trên mô hình đã huấn luyện bằng dữ liệu được gửi từ máy khách.
Nhớ truy cập vào tensorflow.org/js để xem thêm các ví dụ và bản minh hoạ có mã nguồn nhằm tìm hiểu cách bạn có thể sử dụng TensorFlow.js trong các ứng dụng của mình.