1. Prima di iniziare
L'utilizzo del modello TensorFlow.js è aumentato in modo esponenziale negli ultimi anni e molti sviluppatori JavaScript stanno ora cercando di convertire i modelli all'avanguardia esistenti in modo che utilizzino dati personalizzati unici per il loro settore. L'atto di prendere un modello esistente (spesso indicato come modello di base) e di utilizzarlo su un dominio simile ma diverso è noto come Transfer Learning.
Transfer Learning ha molti vantaggi rispetto a un modello completamente vuoto. Puoi riutilizzare le conoscenze già apprese da un modello addestrato precedente e hai bisogno di meno esempi del nuovo elemento che vuoi classificare. Inoltre, l'addestramento è spesso significativamente più rapido poiché è necessario riaddestrare solo gli ultimi livelli dell'architettura del modello anziché l'intera rete. Per questo motivo, il transfer learning è molto adatto all'ambiente del browser web in cui le risorse possono variare in base al dispositivo di esecuzione, ma ha anche accesso diretto ai sensori per una facile acquisizione dei dati.
Questo codelab ti mostra come creare un'app web da una tela vuota, ricreando il popolare " Teachable Machine" sito web. Il sito web ti consente di creare un'app web funzionale che qualsiasi utente può utilizzare per riconoscere un oggetto personalizzato con poche immagini di esempio prese dalla webcam. Il sito web è volutamente ridotto al minimo, in modo che tu possa concentrarti sugli aspetti di questo codelab sul machine learning. Tuttavia, come per il sito web originale di Teachable Machine, esistono ampi margini per applicare la tua attuale esperienza di sviluppatore web al fine di migliorare la UX.
Prerequisiti
Questo codelab è scritto per gli sviluppatori web che hanno una certa familiarità con i modelli predefiniti TensorFlow.js e l'utilizzo di base dell'API e che vogliono iniziare a utilizzare Transfer Learning inTensorFlow.js.
- Per questo lab è necessaria una conoscenza di base di TensorFlow.js, HTML5, CSS e JavaScript.
Se è la prima volta che utilizzi Tensorflow.js, valuta la possibilità di seguire questo corso senza costi da zero to hero, che non richiede alcuna esperienza con il machine learning o TensorFlow.js, e ti insegnerà tutto ciò che devi sapere in passaggi più piccoli.
Cosa imparerai a fare
- Che cos'è TensorFlow.js e perché dovresti usarlo nella tua prossima app web.
- Come creare una pagina web HTML/CSS /JS semplificata che replica l'esperienza utente di Teachable Machine.
- Come utilizzare TensorFlow.js per caricare un modello di base preaddestrato, in particolare MobileNet, per generare caratteristiche di immagine da utilizzare nel Transfer Learning.
- Come raccogliere i dati dalla webcam di un utente per più classi di dati che vuoi riconoscere.
- Creare e definire un perceptron multi-strato che prenda le caratteristiche dell'immagine e impari a classificare nuovi oggetti utilizzandole.
Cominciamo con l'hack...
Che cosa ti serve
- È preferibile seguire un account Glitch.com oppure puoi utilizzare un ambiente di pubblicazione web che hai dimestichezza con l'editing e l'esecuzione in autonomia.
2. Che cos'è TensorFlow.js?
TensorFlow.js è una libreria open source di machine learning che può essere eseguita ovunque JavaScript possa essere eseguito. Si basa sulla libreria TensorFlow originale scritta in Python e mira a ricreare questa esperienza di sviluppo e un insieme di API per l'ecosistema JavaScript.
Dove può essere utilizzato?
Data la portabilità di JavaScript, ora puoi scrivere in un solo linguaggio ed eseguire facilmente il machine learning su tutte le seguenti piattaforme:
- Lato client nel browser web con JavaScript Vanilla
- Lato server e persino dispositivi IoT come Raspberry Pi che utilizzano Node.js
- App desktop che utilizzano Electron
- App native per dispositivi mobili che utilizzano React Native
TensorFlow.js supporta anche più backend all'interno di ciascuno di questi ambienti (ad esempio gli ambienti basati su hardware effettivi che può eseguire all'interno, come CPU o WebGL. Un "backend" in questo contesto non significa un ambiente lato server: il backend per l'esecuzione potrebbe essere sul lato client, ad esempio, in WebGL) per garantire la compatibilità e garantire anche la velocità. Attualmente TensorFlow.js supporta:
- Esecuzione WebGL sulla scheda grafica del dispositivo (GPU): è il modo più veloce per eseguire modelli più grandi (superiori a 3 MB) con accelerazione GPU.
- Esecuzione di Web Assembly (WASM) sulla CPU: per migliorare le prestazioni della CPU su tutti i dispositivi, inclusi, ad esempio, telefoni cellulari meno recenti. È più adatto a modelli più piccoli (di dimensioni inferiori a 3 MB) che possono essere eseguiti più velocemente sulla CPU con WASM rispetto a WebGL a causa dell'overhead associato al caricamento dei contenuti su un processore grafico.
- Esecuzione CPU: il fallback non dovrebbe essere disponibile in nessuno degli altri ambienti. È il più lento dei tre, ma è sempre disponibile.
Nota: puoi scegliere di forzare uno di questi backend se sai su quale dispositivo eseguirai l'esecuzione oppure puoi lasciare che sia TensorFlow.js a decidere per te, se non lo specifichi.
Superpoteri lato client
L'esecuzione di TensorFlow.js nel browser web sul computer client può portare a numerosi vantaggi che vale la pena prendere in considerazione.
Privacy
È possibile addestrare e classificare i dati sul computer client senza mai inviarli a un server web di terze parti. In alcuni casi potrebbe essere necessario rispettare le leggi locali, come ad esempio il GDPR, o durante l'elaborazione di dati che l'utente potrebbe voler conservare sul proprio computer e non essere inviati a terze parti.
Velocità
Poiché non devi inviare i dati a un server remoto, l'inferenza (l'atto di classificare i dati) può essere più veloce. Inoltre, avrai accesso diretto ai sensori del dispositivo, come fotocamera, microfono, GPS, accelerometro e altri, se l'utente ti concede l'accesso.
Copertura e scalabilità
Con un solo clic chiunque nel mondo può fare clic su un link che hai inviato, aprire la pagina web nel proprio browser e utilizzare ciò che hai creato. Non è necessaria una configurazione Linux lato server complessa con driver CUDA e molto altro solo per utilizzare il sistema di machine learning.
Costo
Nessun server significa che l'unica cosa che devi pagare è una CDN per ospitare i tuoi file HTML, CSS, JS e modello. Il costo di una CDN è molto più economico che tenere un server (potenzialmente con una scheda grafica collegata) in esecuzione 24/7.
Funzionalità lato server
L'implementazione di Node.js di TensorFlow.js attiva le seguenti funzionalità.
Supporto CUDA completo
Sul lato server, per l'accelerazione della scheda grafica, devi installare i driver NVIDIA CUDA per consentire a TensorFlow di funzionare con la scheda grafica (a differenza del browser che utilizza WebGL - nessuna installazione necessaria). Tuttavia, con il supporto CUDA completo è possibile sfruttare appieno le capacità di livello inferiore della scheda grafica, con tempi di addestramento e inferenza più rapidi. Le prestazioni sono uguali a quelle dell'implementazione di TensorFlow Python, in quanto entrambi condividono lo stesso backend C++.
Dimensioni modello
Nel caso di modelli all'avanguardia provenienti dalla ricerca, potresti lavorare con modelli molto grandi, ad esempio di gigabyte. Questi modelli al momento non possono essere eseguiti nel browser web a causa delle limitazioni di utilizzo della memoria per scheda del browser. Per eseguire questi modelli di dimensioni maggiori puoi utilizzare Node.js sul tuo server con le specifiche hardware necessarie per eseguire questo modello in modo efficiente.
IoT
Node.js è supportato su computer a scheda singola popolari come Raspberry Pi, il che a sua volta significa che puoi eseguire i modelli TensorFlow.js anche su questi dispositivi.
Velocità
Node.js è scritto in JavaScript, il che significa che trae vantaggio dalla compilazione just-in-time. Ciò significa che spesso potresti notare un miglioramento delle prestazioni quando utilizzi Node.js, in quanto verrà ottimizzato in fase di runtime, in particolare per qualsiasi pre-elaborazione. Un ottimo esempio di ciò può essere visto in questo case study, che mostra come Hugging Face ha utilizzato Node.js per ottenere un aumento delle prestazioni di 2 volte per il proprio modello di elaborazione del linguaggio naturale.
Ora che conosci le nozioni di base di TensorFlow.js, dove può essere eseguito e alcuni dei vantaggi, iniziamo a utilizzarlo.
3. Transfer learning
Che cos'è esattamente il Transfer Learning?
Il transfer learning implica l'acquisizione di conoscenze che si sono già apprese per aiutare ad apprendere qualcosa di diverso ma simile.
Noi umani lo facciamo sempre. Hai una vita vissuta nel tuo cervello di esperienze che puoi usare per riconoscere cose nuove che non hai mai visto prima. Prendiamo ad esempio questo salice:
A seconda della località in cui ti trovi nel mondo, è possibile che tu non abbia mai visto questo tipo di albero prima d'ora.
Tuttavia, se vi chiedo di dirmi se ci sono salici nella nuova immagine qui sotto, probabilmente li potrete avvistare abbastanza velocemente, anche se hanno un'angolazione diversa e leggermente diversa da quella originale che vi ho mostrato.
Hai già un gruppo di neuroni nel tuo cervello che sanno come identificare oggetti simili a alberi, e altri neuroni che sono bravi a trovare lunghe linee rette. Puoi riutilizzare queste conoscenze per classificare rapidamente un salice, un oggetto simile a un albero che ha molti lunghi rami verticali dritti.
Analogamente, se disponi di un modello di machine learning già addestrato su un dominio, ad esempio che riconosce le immagini, puoi riutilizzarlo per eseguire un'attività diversa ma correlata.
Puoi fare lo stesso con un modello avanzato come MobileNet, un modello di ricerca molto diffuso in grado di eseguire il riconoscimento di immagini su 1000 diversi tipi di oggetti. Dai cani alle auto, è stato addestrato su un enorme set di dati noto come ImageNet che contiene milioni di immagini etichettate.
In questa animazione puoi vedere l'enorme numero di livelli presenti nel modello MobileNet V1:
Durante l'addestramento, il modello ha imparato a estrarre caratteristiche comuni che contano per tutti i 1000 oggetti, e molte delle caratteristiche di livello inferiore che utilizza per identificare questi oggetti possono essere utili anche per rilevare nuovi oggetti che non ha mai visto prima. Dopotutto, tutto è solo una combinazione di linee, texture e forme.
Diamo un'occhiata a una tradizionale architettura CNN (Convolutional Neural Network) (simile a MobileNet) e vediamo come il Transfer Learning può sfruttare questa rete addestrata per apprendere qualcosa di nuovo. L'immagine seguente mostra la tipica architettura del modello di una CNN che in questo caso è stata addestrata per riconoscere cifre scritte a mano libera da 0 a 9:
Se fossi in grado di separare i livelli di livello inferiore preaddestrati di un modello addestrato esistente come questo, mostrato a sinistra, dai livelli di classificazione vicino alla fine del modello mostrato a destra (a volte indicati come test di classificazione del modello), potresti utilizzare i livelli di livello inferiore per produrre caratteristiche di output per una determinata immagine sulla base dei dati originali su cui è stata addestrata. Ecco la stessa rete con la testa di classificazione rimossa:
Supponendo che la nuova cosa che stai cercando di riconoscere possa anche utilizzare queste caratteristiche di output che il modello precedente ha imparato, c'è una buona probabilità che possano essere riutilizzate per un nuovo scopo.
Nel diagramma qui sopra, questo modello ipotetico è stato addestrato sulle cifre, quindi forse ciò che è stato imparato sulle cifre può essere applicato anche a lettere come a, b e c.
Ora puoi aggiungere una nuova intestazione di classificazione che provi a prevedere a, b o c, come mostrato di seguito:
Qui gli strati di livello inferiore sono congelati e non addestrati, solo la nuova testa di classificazione si aggiornerà per apprendere dalle caratteristiche fornite dal modello suddiviso preaddestrato a sinistra.
Questo processo è noto come transfer learning ed è ciò che fa Teachable Machine dietro le quinte.
Si può anche vedere che, perché è necessario addestrare il perceptron a più strati solo alla fine della rete, l'addestramento avviene molto più velocemente di quanto accadrebbe se si dovesse addestrare l'intera rete da zero.
Ma come puoi mettere in pratica le sottoparti di un modello? Vai alla sezione successiva per scoprirlo.
4. TensorFlow Hub - Modelli di base
Trovare un modello di base adatto da utilizzare
Per modelli di ricerca più avanzati e popolari come MobileNet, puoi accedere all'hub TensorFlow e quindi filtrare i modelli adatti a TensorFlow.js che utilizzano l'architettura MobileNet v3 per trovare risultati come quelli mostrati qui:
Tieni presente che alcuni di questi risultati sono di tipo "classificazione immagini" (dettagli in alto a sinistra dei risultati della scheda di ciascun modello) e altri sono di tipo "vettore di caratteristiche immagine".
Questi risultati relativi ai vettori di caratteristiche delle immagini sono essenzialmente le versioni pre-triturate di MobileNet che puoi utilizzare per ottenere i vettori delle caratteristiche dell'immagine invece della classificazione finale.
I modelli come questo sono spesso chiamati "modelli di base", che puoi quindi utilizzare per eseguire il Transfer Learning come mostrato nella sezione precedente, aggiungendo un nuovo intestazione di classificazione e addestrandolo con i tuoi dati.
Il passaggio successivo da verificare è la presenza di un determinato modello di base di interesse per il formato TensorFlow.js in cui viene rilasciato il modello. Se apri la pagina per uno di questi modelli MobileNet v3 di caratteristiche vettoriali, puoi vedere dalla documentazione di JS che si tratta di un modello grafico basato sullo snippet di codice di esempio nella documentazione che utilizza tf.loadGraphModel()
.
Tieni presente, inoltre, che se trovi un modello nel formato dei livelli anziché nel formato del grafico, puoi scegliere quali livelli bloccare e quali sbloccare per l'addestramento. Questo può essere molto efficace quando si crea un modello per una nuova attività, che è spesso indicato come "modello di trasferimento". Per ora, tuttavia, per questo tutorial utilizzerai il tipo di modello grafico predefinito, in base al quale viene eseguito il deployment della maggior parte dei modelli TF Hub. Per saperne di più sull'utilizzo dei modelli Livelli, consulta il corso TensorFlow.js da zero to hero.
Vantaggi del Transfer Learning
Quali sono i vantaggi di utilizzare il transfer learning invece di addestrare l'intera architettura del modello da zero?
In primo luogo, il tempo di addestramento è un vantaggio fondamentale dell'utilizzo di un approccio di transfer learning, in quanto disponi già di un modello di base addestrato su cui sviluppare.
In secondo luogo, potete riuscire a mostrare molti meno esempi del nuovo elemento che state cercando di classificare a causa dell'addestramento.
Questo è davvero ottimo se hai tempo e risorse limitati per raccogliere dati di esempio di ciò che vuoi classificare e devi creare rapidamente un prototipo prima di raccogliere altri dati di addestramento per renderlo più solido.
Data la necessità di ridurre i dati e la velocità di addestramento di una rete più piccola, il Transfer Learning richiede meno risorse. Questo lo rende molto adatto all'ambiente browser, impiegando solo decine di secondi su una macchina moderna anziché ore, giorni o settimane per l'addestramento completo del modello.
Perfetto. Ora che conosci in sostanza cos'è Transfer Learning, è il momento di creare la tua versione personale di Teachable Machine. Iniziamo.
5. Configura per programmare
Che cosa ti serve
- Un browser web moderno.
- Conoscenza di base di HTML, CSS, JavaScript e Chrome DevTools (visualizzazione dell'output della console).
Iniziamo a programmare
Sono stati creati modelli boilerplate da cui iniziare per Glitch.com o Codepen.io. Puoi semplicemente clonare uno dei modelli come stato base per questo lab di codice con un solo clic.
Su Glitch, fai clic sul pulsante "Esegui il remix" per creare un fork e creare un nuovo set di file che potrai modificare.
In alternativa, su Codepen, fai clic su" fork" nella parte inferiore destra dello schermo.
Questo scheletro molto semplice fornisce i seguenti file:
- Pagina HTML (index.html)
- Foglio di stile (style.css)
- File per scrivere il nostro codice JavaScript (script.js)
Per comodità, è disponibile un'importazione aggiunta nel file HTML per la libreria TensorFlow.js. Ha questo aspetto:
index.html
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
Alternativa: utilizza il tuo editor web preferito o lavora localmente
Se vuoi scaricare il codice e lavorare in locale o su un altro editor online, crea semplicemente i 3 file denominati sopra nella stessa directory e copia e incolla in ciascuno di essi il codice dal nostro boilerplate Glitch.
6. Boilerplate HTML dell'app
Da dove inizio?
Tutti i prototipi richiedono alcuni scaffold HTML di base su cui eseguire il rendering dei risultati. Configuralo ora. Stai per aggiungere:
- Un titolo per la pagina.
- Un po' di testo descrittivo.
- Un paragrafo sullo stato.
- Un video per conservare il feed della webcam quando è pronto.
- Diversi pulsanti per avviare la videocamera, raccogliere dati o reimpostare l'esperienza.
- Importazioni per file TensorFlow.js e JS da scrivere in un secondo momento.
Apri index.html
e incolla il codice esistente con il codice seguente per configurare le funzionalità riportate sopra:
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Transfer Learning - TensorFlow.js</title>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Import the webpage's stylesheet -->
<link rel="stylesheet" href="/style.css">
</head>
<body>
<h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
<p id="status">Awaiting TF.js load</p>
<video id="webcam" autoplay muted></video>
<button id="enableCam">Enable Webcam</button>
<button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
<button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
<button id="train">Train & Predict!</button>
<button id="reset">Reset</button>
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>
<!-- Import the page's JavaScript to do some stuff -->
<script type="module" src="/script.js"></script>
</body>
</html>
Suddividi
Analizziamo parte del codice HTML riportato sopra per evidenziare alcuni elementi chiave che hai aggiunto.
- Hai aggiunto un tag
<h1>
per il titolo della pagina insieme a un tag<p>
con l'ID "stato", che è la modalità di stampa delle informazioni, poiché utilizzi parti diverse del sistema per visualizzare gli output. - Hai aggiunto un elemento
<video>
con l'ID "webcam", su cui eseguirai il rendering dello stream con la webcam in un secondo momento. - Hai aggiunto 5 elementi
<button>
. La prima, con l'ID "enableCam", attiva la fotocamera. I due pulsanti successivi hanno una classe "dataCollector", che ti consente di raccogliere immagini di esempio per gli oggetti che vuoi riconoscere. Il codice che scriverai in seguito sarà progettato in modo da poter aggiungere il numero desiderato di pulsanti, che funzioneranno automaticamente come previsto.
Tieni presente che questi pulsanti hanno anche uno speciale attributo definito dall'utente chiamato data-1hot, con un valore intero che parte da 0 per la prima classe. Si tratta dell'indice numerico che utilizzerai per rappresentare i dati di una determinata classe. L'indice verrà utilizzato per codificare correttamente le classi di output con una rappresentazione numerica anziché una stringa, poiché i modelli ML possono funzionare solo con numeri.
Esiste anche un attributo data-name che contiene il nome leggibile da una persona che vuoi utilizzare per questa classe, che ti consente di fornire all'utente un nome più significativo invece di un valore di indice numerico della codifica a caldo 1.
Infine, hai a disposizione un pulsante di addestramento e un pulsante di reimpostazione per avviare il processo di addestramento una volta raccolti i dati o per reimpostare l'app.
- Hai anche aggiunto 2 importazioni
<script>
. uno per TensorFlow.js e l'altro per script.js che definirai a breve.
7. Aggiungi stile
Valori predefiniti elemento
Aggiungi stili per gli elementi HTML appena aggiunti per assicurarti che vengano visualizzati correttamente. Di seguito sono riportati alcuni stili che vengono aggiunti correttamente agli elementi di posizione e dimensione. Niente di troppo speciale. Potresti sicuramente aggiungerlo in un secondo momento per migliorare l'esperienza utente, come hai visto nel video Teachable Machine.
style.css
body {
font-family: helvetica, arial, sans-serif;
margin: 2em;
}
h1 {
font-style: italic;
color: #FF6F00;
}
video {
clear: both;
display: block;
margin: 10px;
background: #000000;
width: 640px;
height: 480px;
}
button {
padding: 10px;
float: left;
margin: 5px 3px 5px 10px;
}
.removed {
display: none;
}
#status {
font-size:150%;
}
Bene. Non serve altro. Se adesso visualizzi l'anteprima dell'output, dovrebbe avere un aspetto simile al seguente:
8. JavaScript: costanti chiave e listener
Definire le costanti chiave
Innanzitutto, aggiungi alcune costanti chiave che utilizzerai nell'app. Inizia sostituendo i contenuti di script.js
con queste costanti:
script.js
const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];
Analizziamo in dettaglio a cosa servono:
STATUS
contiene semplicemente un riferimento al tag paragrafo in cui scriverai gli aggiornamenti di stato.VIDEO
contiene un riferimento all'elemento video HTML che visualizzerà il feed della webcam.ENABLE_CAM_BUTTON
,RESET_BUTTON
eTRAIN_BUTTON
recuperano i riferimenti DOM a tutti i pulsanti chiave dalla pagina HTML.MOBILE_NET_INPUT_WIDTH
eMOBILE_NET_INPUT_HEIGHT
definiscono rispettivamente la larghezza e l'altezza di input previste del modello MobileNet. Archiviandoli in una costante nella parte superiore del file come questa, se in un secondo momento decidi di utilizzare una versione diversa, risulta più facile aggiornare i valori una volta anziché sostituirli in molti punti diversi.STOP_DATA_GATHER
è impostato su - 1. In questo modo viene memorizzato un valore relativo allo stato che ti consente di sapere quando l'utente ha smesso di fare clic su un pulsante per raccogliere i dati dal feed della webcam. Assegnando a questo numero un nome più significativo, il codice sarà più leggibile in un secondo momento.CLASS_NAMES
funge da ricerca e contiene i nomi leggibili da una persona per le possibili previsioni della classe. Questo array verrà compilato in un secondo momento.
Ora che hai riferimenti agli elementi chiave, è il momento di associare alcuni listener di eventi.
Aggiungere listener di eventi chiave
Per iniziare, aggiungi i gestori di eventi di clic ai pulsanti chiave, come mostrato di seguito:
script.js
ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);
function enableCam() {
// TODO: Fill this out later in the codelab!
}
function trainAndPredict() {
// TODO: Fill this out later in the codelab!
}
function reset() {
// TODO: Fill this out later in the codelab!
}
ENABLE_CAM_BUTTON
: richiama la funzione EnableCam quando viene fatto clic su di esso.
TRAIN_BUTTON
: chiama trainAndPredict quando viene fatto clic.
RESET_BUTTON
: le chiamate vengono reimpostate quando viene fatto clic sopra.
Infine, in questa sezione puoi trovare tutti i pulsanti
che hanno una classe "dataCollector" utilizzando document.querySelectorAll()
. Viene restituito un array di elementi trovati nel documento che corrispondono:
script.js
let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
// Populate the human readable names for classes.
CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}
function gatherDataForClass() {
// TODO: Fill this out later in the codelab!
}
Spiegazione del codice:
Puoi quindi eseguire l'iterazione dei pulsanti trovati e associare a ciascuno due listener di eventi. Una per "mousedown" e una per "mouseup". In questo modo puoi continuare a registrare i campioni finché il pulsante viene premuto, il che è utile per la raccolta dei dati.
Entrambi gli eventi chiamano una funzione gatherDataForClass
che definirai in seguito.
A questo punto, puoi anche inviare i nomi delle classi leggibili da una persona dall'attributo data-name del pulsante HTML all'array CLASS_NAMES
.
Poi, aggiungi alcune variabili per archiviare gli elementi chiave che verranno utilizzati in seguito.
script.js
let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;
Vediamole nel dettaglio.
Innanzitutto, hai una variabile mobilenet
per archiviare il modello mobilenet caricato. Impostalo inizialmente su non definito.
Poi abbiamo una variabile chiamata gatherDataState
. Se un oggetto "dataCollector" , questo diventa l'1 hot ID di quel pulsante, come definito nel codice HTML, per consentirti di sapere quale classe di dati stai raccogliendo in quel momento. Inizialmente, il valore è impostato su STOP_DATA_GATHER
, in modo che il loop di raccolta dei dati che scrivi in seguito non raccolga dati quando non viene premuto alcun pulsante.
videoPlaying
consente di sapere se lo stream con la webcam viene caricato e in riproduzione correttamente e se è disponibile per l'uso. Inizialmente l'opzione è impostata su false
perché la webcam non è accesa finché non premi il pulsante ENABLE_CAM_BUTTON.
Poi definisci due array, trainingDataInputs
e trainingDataOutputs
. Queste memorizzano i valori dei dati di addestramento raccolti, quando fai clic su "dataCollector" i pulsanti per le caratteristiche di input generate dal modello base MobileNet e la classe di output campionata rispettivamente.
Viene quindi definito un array finale, examplesCount,
, per tenere traccia di quanti esempi sono contenuti per ogni classe una volta che inizi ad aggiungerli.
Infine, abbiamo una variabile denominata predict
che controlla il ciclo di previsione. Il valore iniziale è false
. Non è possibile fare previsioni finché questo valore non viene impostato su true
in un secondo momento.
Ora che tutte le variabili chiave sono state definite, carichiamo il modello base MobileNet v3 pre-triturato, che fornisce i vettori delle caratteristiche delle immagini anziché le classificazioni.
9. Carica il modello di base MobileNet
Innanzitutto, definisci una nuova funzione denominata loadMobileNetFeatureModel
come mostrato di seguito. Deve essere una funzione asincrona, poiché l'atto di caricamento del modello è asincrono:
script.js
/**
* Loads the MobileNet model and warms it up so ready for use.
**/
async function loadMobileNetFeatureModel() {
const URL =
'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
STATUS.innerText = 'MobileNet v3 loaded successfully!';
// Warm up the model by passing zeros through it once.
tf.tidy(function () {
let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
console.log(answer.shape);
});
}
// Call the function immediately to start loading.
loadMobileNetFeatureModel();
In questo codice definisci il URL
in cui si trova il modello da caricare dalla documentazione di TFHub.
Potrai quindi caricare il modello utilizzando await tf.loadGraphModel()
, ricordando di impostare la proprietà speciale fromTFHub
su true
mentre carichi un modello da questo sito web Google. Si tratta di un caso speciale solo per l'utilizzo di modelli ospitati su TF Hub in cui deve essere impostata questa proprietà aggiuntiva.
Una volta completato il caricamento, puoi impostare l'elemento innerText
dell'elemento STATUS
con un messaggio, in modo da poter vedere che è stato caricato correttamente e iniziare a raccogliere i dati.
A questo punto rimane solo il riscaldamento del modello. Con modelli più grandi come questo, la prima volta che utilizzi il modello, la configurazione può richiedere alcuni istanti. Di conseguenza, è utile passare gli zeri nel modello per evitare attese in futuro, in cui il tempismo può essere più importante.
Puoi utilizzare tf.zeros()
aggregato in un tf.tidy()
per assicurarti che i tensori vengano smaltiti correttamente, con una dimensione del batch pari a 1 e l'altezza e la larghezza corrette che hai definito nelle costanti all'inizio. Infine, devi specificare anche i canali di colore, che in questo caso sono tre, in quanto il modello si aspetta immagini RGB.
Poi, registra la forma risultante del tensore restituito utilizzando answer.shape()
per aiutarti a comprendere le dimensioni delle caratteristiche dell'immagine prodotte da questo modello.
Dopo aver definito questa funzione, puoi chiamarla immediatamente per avviare il download del modello al caricamento pagina.
Se visualizzi l'anteprima in tempo reale adesso, dopo qualche istante il testo dello stato cambierà da "In attesa di caricamento di TF.js". per diventare "MobileNet v3 caricato correttamente!". come mostrato di seguito. Prima di continuare, assicurati che funzioni.
Puoi anche controllare l'output della console per vedere le dimensioni stampate delle caratteristiche di output prodotte da questo modello. Dopo aver eseguito gli zeri nel modello MobileNet, verrà visualizzata la forma [1, 1024]
stampata. Il primo elemento ha solo la dimensione del batch 1 e puoi vedere che restituisce effettivamente 1024 caratteristiche che possono essere utilizzate per aiutarti a classificare nuovi oggetti.
10. Definisci l'intestazione del nuovo modello
Ora è il momento di definire l'intestazione del modello, che è essenzialmente un perceptrone a più strati molto minima.
script.js
let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));
model.summary();
// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
// Adam changes the learning rate over time which is useful.
optimizer: 'adam',
// Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
// Else categoricalCrossentropy is used if more than 2 classes.
loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
// As this is a classification problem you can record accuracy in the logs too!
metrics: ['accuracy']
});
Esaminiamo questo codice. Inizierai definendo un modello tf.sequenziale a cui aggiungere gli strati del modello.
Quindi, aggiungi uno strato denso come strato di input a questo modello. Ha una forma di input 1024
perché gli output delle funzionalità MobileNet v3 sono di queste dimensioni. L'hai scoperto nel passaggio precedente, dopo averle passati attraverso il modello. Questo strato ha 128 neuroni che utilizzano la funzione di attivazione ReLU.
Se non hai mai utilizzato le funzioni di attivazione e i livelli dei modelli, valuta la possibilità di seguire il corso descritto all'inizio di questo workshop per capire cosa fanno queste proprietà dietro le quinte.
Il livello successivo da aggiungere è il livello di output. Il numero di neuroni deve essere uguale al numero di classi che stai cercando di prevedere. A questo scopo, puoi utilizzare CLASS_NAMES.length
per trovare quanti corsi vuoi classificare, che è uguale al numero di pulsanti per la raccolta dei dati presenti nell'interfaccia utente. Poiché si tratta di un problema di classificazione, su questo livello di output viene usata l'attivazione softmax
, che deve essere utilizzata quando si cerca di creare un modello per risolvere problemi di classificazione anziché la regressione.
Ora stampa un model.summary()
per stampare sulla console la panoramica del modello appena definito.
Infine, compila il modello in modo che sia pronto per essere addestrato. In questo caso, lo strumento di ottimizzazione è impostato su adam
e la perdita sarà binaryCrossentropy
se CLASS_NAMES.length
è uguale a 2
oppure utilizzerà categoricalCrossentropy
se sono presenti 3 o più classi da classificare. Vengono richieste anche metriche di accuratezza in modo che possano essere monitorate nei log in un secondo momento a scopo di debug.
Nella console dovresti vedere qualcosa di simile a questo:
Tieni presente che questo ha oltre 130.000 parametri addestrabili. Ma poiché si tratta di un semplice strato denso di neuroni regolari, l'addestramento avviene in tempi piuttosto rapidi.
Come attività da eseguire una volta completato il progetto, potresti provare a modificare il numero di neuroni nel primo strato per vedere quanto puoi ridurlo fino a ottenere prestazioni soddisfacenti. Spesso con il machine learning sono previsti dei tentativi per trovare i valori dei parametri ottimali e offrire il miglior compromesso tra utilizzo delle risorse e velocità.
11. Attiva la webcam
A questo punto, puoi arricchire la funzione enableCam()
che hai definito in precedenza. Aggiungi una nuova funzione denominata hasGetUserMedia()
come mostrato di seguito, quindi sostituisci i contenuti della funzione enableCam()
definita in precedenza con il codice corrispondente riportato di seguito.
script.js
function hasGetUserMedia() {
return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}
function enableCam() {
if (hasGetUserMedia()) {
// getUsermedia parameters.
const constraints = {
video: true,
width: 640,
height: 480
};
// Activate the webcam stream.
navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
VIDEO.srcObject = stream;
VIDEO.addEventListener('loadeddata', function() {
videoPlaying = true;
ENABLE_CAM_BUTTON.classList.add('removed');
});
});
} else {
console.warn('getUserMedia() is not supported by your browser');
}
}
In primo luogo, crea una funzione denominata hasGetUserMedia()
per verificare se il browser supporta getUserMedia()
controllando l'esistenza di proprietà chiave delle API del browser.
Nella funzione enableCam()
, utilizza la funzione hasGetUserMedia()
appena definita sopra per verificare se è supportata. In caso contrario, stampa un avviso sulla console.
Se supporta questa opzione, definisci alcuni vincoli per la chiamata a getUserMedia()
, ad esempio, se vuoi che sia il solo video stream e che width
del video sia di 640
pixel e height
sia di 480
pixel. Perché? Beh, non ha molto senso ottenere un video più grande di questo perché dovrebbe essere ridimensionato a 224 x 224 pixel per essere inserito nel modello MobileNet. Potresti anche risparmiare alcune risorse di calcolo richiedendo una risoluzione inferiore. La maggior parte delle fotocamere supporta una risoluzione di queste dimensioni.
Successivamente, chiama navigator.mediaDevices.getUserMedia()
specificando il constraints
indicato sopra e attendi che venga restituito il stream
. Una volta restituito l'elemento stream
, puoi fare in modo che l'elemento VIDEO
riproduca stream
impostandolo come valore srcObject
.
Devi anche aggiungere un eventListener sull'elemento VIDEO
per sapere quando stream
si è caricato e viene riprodotto correttamente.
Una volta caricato lo stream, puoi impostare videoPlaying
su true e rimuovere ENABLE_CAM_BUTTON
per evitare che venga ripreso a fare clic impostando la relativa classe su "removed
".
Ora esegui il codice, fai clic sul pulsante Attiva fotocamera e consenti l'accesso alla webcam. Se è la prima volta che esegui questa operazione, dovresti vederti visualizzato nell'elemento video della pagina, come mostrato di seguito:
Ok, ora è il momento di aggiungere una funzione per gestire i clic sul pulsante dataCollector
.
12. Gestore di eventi del pulsante Raccolta dati
Ora è il momento di compilare la funzione attualmente vuota denominata gatherDataForClass().
. Questa è quella che hai assegnato come funzione di gestore di eventi per i pulsanti dataCollector
all'inizio del codelab.
script.js
/**
* Handle Data Gather for button mouseup/mousedown.
**/
function gatherDataForClass() {
let classNumber = parseInt(this.getAttribute('data-1hot'));
gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
dataGatherLoop();
}
Innanzitutto, controlla l'attributo data-1hot
sul pulsante attualmente selezionato chiamando this.getAttribute()
con il nome dell'attributo, in questo caso data-1hot
come parametro. Poiché questa è una stringa, puoi quindi usare parseInt()
per trasmetterla a un numero intero e assegnare questo risultato a una variabile denominata classNumber.
A questo punto, imposta la variabile gatherDataState
di conseguenza. Se il valore attuale di gatherDataState
è uguale a STOP_DATA_GATHER
(che hai impostato su -1), significa che al momento non stai raccogliendo dati ed è stato attivato un evento mousedown
. Imposta gatherDataState
in modo che sia il classNumber
che hai appena trovato.
In caso contrario, significa che stai raccogliendo dati e l'evento che è stato attivato era un evento mouseup
, perciò vuoi interrompere la raccolta dei dati per quella classe. È sufficiente reimpostarlo sullo stato STOP_DATA_GATHER
per terminare il loop di raccolta dei dati che definirai a breve.
Infine, avvia la chiamata al numero dataGatherLoop(),
che esegue effettivamente la registrazione dei dati del corso.
13. Raccolta dei dati
Ora definisci la funzione dataGatherLoop()
. Questa funzione si occupa del campionamento delle immagini dal video della webcam, del loro passaggio attraverso il modello MobileNet e dell'acquisizione degli output di tale modello (vettori di caratteristiche 1024).
Quindi le memorizza insieme all'ID gatherDataState
del pulsante attualmente premuto, in modo che tu possa sapere quale classe rappresenta questi dati.
Esaminiamolo:
script.js
function dataGatherLoop() {
if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
let imageFeatures = tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let normalizedTensorFrame = resizedTensorFrame.div(255);
return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
});
trainingDataInputs.push(imageFeatures);
trainingDataOutputs.push(gatherDataState);
// Intialize array index element if currently undefined.
if (examplesCount[gatherDataState] === undefined) {
examplesCount[gatherDataState] = 0;
}
examplesCount[gatherDataState]++;
STATUS.innerText = '';
for (let n = 0; n < CLASS_NAMES.length; n++) {
STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
}
window.requestAnimationFrame(dataGatherLoop);
}
}
Continuerai l'esecuzione di questa funzione solo se videoPlaying
è true, il che significa che la webcam è attiva, e gatherDataState
è diverso da STOP_DATA_GATHER
e al momento è premuto un pulsante per la raccolta dei dati della classe.
Aggrega il codice in un tf.tidy()
per eliminare eventuali tensori creati nel codice seguente. Il risultato di questa esecuzione di codice di tf.tidy()
è archiviato in una variabile denominata imageFeatures
.
Ora puoi scattare un'inquadratura della webcam VIDEO
utilizzando tf.browser.fromPixels()
. Il tensore risultante contenente i dati dell'immagine viene archiviato in una variabile denominata videoFrameAsTensor
.
A questo punto, ridimensiona la variabile videoFrameAsTensor
in modo che abbia la forma corretta per l'input del modello MobileNet. Usa una chiamata tf.image.resizeBilinear()
con il tensore che vuoi rimodellare come primo parametro, quindi una forma che definisce la nuova altezza e larghezza come definite dalle costanti già create in precedenza. Infine, imposta Allinea gli angoli su true passando il terzo parametro per evitare problemi di allineamento durante il ridimensionamento. Il risultato di questo ridimensionamento viene archiviato in una variabile denominata resizedTensorFrame
.
Tieni presente che questo ridimensionamento primitivo estende l'immagine, poiché l'immagine della webcam ha una dimensione di 640 x 480 pixel e il modello ha bisogno di un'immagine quadrata di 224 x 224 pixel.
Ai fini di questa demo dovrebbe funzionare correttamente. Tuttavia, una volta completato questo codelab, ti consigliamo di provare a ritagliare un quadrato da questa immagine per ottenere risultati ancora migliori per qualsiasi sistema di produzione che potresti creare in un secondo momento.
Poi, normalizza i dati dell'immagine. I dati dell'immagine sono sempre compresi tra 0 e 255 quando utilizzi tf.browser.frompixels()
, quindi puoi semplicemente dividere ridimensionatoTensorFrame per 255 per assicurarti che tutti i valori siano compresi tra 0 e 1, che è ciò che il modello MobileNet prevede come input.
Infine, nella sezione tf.tidy()
del codice, esegui il push di questo tensore normalizzato attraverso il modello caricato chiamando mobilenet.predict()
, a cui passi la versione espansa di normalizedTensorFrame
utilizzando expandDims()
, in modo che sia un batch di 1, poiché il modello prevede un batch di input per l'elaborazione.
Una volta restituito il risultato, puoi chiamare immediatamente squeeze()
su quel risultato restituito per ridurlo di nuovo a un tensore monodimensionale, che poi restituisci e assegni alla variabile imageFeatures
che acquisisce il risultato da tf.tidy()
.
Ora che hai imageFeatures
dal modello MobileNet, puoi registrarli trasferendoli all'array trainingDataInputs
che hai definito in precedenza.
Puoi anche registrare ciò che questo input rappresenta eseguendo il push dell'elemento gatherDataState
corrente anche all'array trainingDataOutputs
.
Tieni presente che la variabile gatherDataState
sarebbe stata impostata sull'ID numerico della classe corrente per cui stai registrando i dati quando è stato fatto clic sul pulsante nella funzione gatherDataForClass()
definita in precedenza.
A questo punto puoi anche incrementare il numero di esempi a tua disposizione per una determinata classe. Per farlo, verifica innanzitutto se l'indice all'interno dell'array examplesCount
è stato inizializzato in precedenza o meno. Se non è definito, impostalo su 0 per inizializzare il contatore per l'ID numerico di una determinata classe, quindi puoi incrementare il valore examplesCount
per il gatherDataState
corrente.
Ora aggiorna il testo dell'elemento STATUS
sulla pagina web per mostrare i conteggi attuali per ogni corso man mano che vengono acquisiti. Per farlo, analizza l'array CLASS_NAMES
e stampa il nome leggibile combinato con il conteggio dei dati per lo stesso indice in examplesCount
.
Infine, richiama window.requestAnimationFrame()
con dataGatherLoop
passato come parametro, per richiamare di nuovo questa funzione in modo ricorsivo. I fotogrammi di campionamento del video verranno mantenuti fino al rilevamento del valore mouseup
del pulsante. gatherDataState
è impostato su STOP_DATA_GATHER,
; a quel punto, il loop di raccolta dei dati terminerà.
Se esegui il codice ora, dovresti essere in grado di fare clic sul pulsante Abilita fotocamera, attendere il caricamento della webcam, quindi fare clic e tenere premuto ogni pulsante per la raccolta dei dati per raccogliere esempi per ogni classe di dati. Qui vediamo mentre raccolgo dati, rispettivamente, per il cellulare e la mano.
Dovresti vedere il testo dello stato aggiornato man mano che tutti i tensori vengono memorizzati nella memoria, come mostrato nell'acquisizione schermo qui sopra.
14. Addestra e prevedi
Il passaggio successivo consiste nell'implementare il codice per la funzione trainAndPredict()
attualmente vuota, dove avviene il Transfer Learning. Diamo un'occhiata al codice:
script.js
async function trainAndPredict() {
predict = false;
tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
let inputsAsTensor = tf.stack(trainingDataInputs);
let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
callbacks: {onEpochEnd: logProgress} });
outputsAsTensor.dispose();
oneHotOutputs.dispose();
inputsAsTensor.dispose();
predict = true;
predictLoop();
}
function logProgress(epoch, logs) {
console.log('Data for epoch ' + epoch, logs);
}
Innanzitutto, assicurati di interrompere l'esecuzione di eventuali previsioni correnti impostando predict
su false
.
Poi, esegui lo shuffling degli array di input e output utilizzando tf.util.shuffleCombo()
per assicurarti che l'ordine non causi problemi durante l'addestramento.
Converti l'array di output, trainingDataOutputs,
, in un tensor1d di tipo int32, in modo che sia pronto per essere utilizzato in una codifica one-hot. Questo viene memorizzato in una variabile denominata outputsAsTensor
.
Usa la funzione tf.oneHot()
con questa variabile outputsAsTensor
insieme al numero massimo di classi da codificare, che è solo il CLASS_NAMES.length
. I tuoi output con codifica a caldo vengono ora memorizzati in un nuovo tensore chiamato oneHotOutputs
.
Tieni presente che attualmente trainingDataInputs
è un array di tensori registrati. Per utilizzarli per l'addestramento, dovrai convertire l'array di tensori in un normale tensore 2D.
A questo scopo, nella libreria TensorFlow.js c'è un'ottima funzione chiamata tf.stack()
,
che prende un array di tensori e li impila per produrre un tensore di dimensione maggiore come output. In questo caso viene restituito un tensore 2D, ovvero un batch di input dimensionali 1 di lunghezza ciascuno contenente le caratteristiche registrate, che è ciò di cui hai bisogno per l'addestramento.
Poi, await model.fit()
per addestrare l'intestazione del modello personalizzato. Qui passi la variabile inputsAsTensor
insieme a oneHotOutputs
per rappresentare i dati di addestramento da utilizzare, rispettivamente, come input e output target. Nell'oggetto di configurazione del terzo parametro, imposta shuffle
su true
, utilizza batchSize
di 5
, con epochs
impostato su 10
, quindi specifica un valore callback
per onEpochEnd
nella funzione logProgress
che definirai a breve.
Infine, puoi eliminare i tensori creati quando il modello viene addestrato. Puoi quindi impostare nuovamente predict
su true
per consentire nuovamente l'esecuzione delle previsioni, quindi chiamare la funzione predictLoop()
per iniziare a prevedere le immagini della webcam in diretta.
Puoi anche definire la funzione logProcess()
per registrare lo stato dell'addestramento, che viene utilizzato in model.fit()
sopra e che stampa i risultati nella console dopo ogni ciclo di addestramento.
Ci sei quasi. È il momento di aggiungere la funzione predictLoop()
per effettuare previsioni.
Ciclo di previsione principale
Qui implementi il ciclo di previsione principale, che campiona i frame da una webcam e prevede continuamente cosa c'è in ogni frame con risultati in tempo reale nel browser.
Controlliamo il codice:
script.js
function predictLoop() {
if (predict) {
tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
let prediction = model.predict(imageFeatures).squeeze();
let highestIndex = prediction.argMax().arraySync();
let predictionArray = prediction.arraySync();
STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
});
window.requestAnimationFrame(predictLoop);
}
}
Per prima cosa, verifica che predict
sia vero, in modo che le previsioni vengano effettuate solo dopo che il modello è stato addestrato ed è disponibile per l'uso.
In seguito, puoi ottenere le caratteristiche dell'immagine per l'immagine corrente proprio come hai fatto nella funzione dataGatherLoop()
. In sostanza, si estrae un fotogramma dalla webcam utilizzando tf.browser.from pixels()
, lo normalizzi, lo ridimensiona in modo che abbia una dimensione di 224 x 224 pixel e quindi passi i dati attraverso il modello MobileNet per ottenere le caratteristiche dell'immagine risultanti.
Ora, tuttavia, puoi utilizzare l'intestazione del modello appena addestrato per eseguire effettivamente una previsione passando il valore imageFeatures
risultante appena trovato tramite la funzione predict()
del modello addestrato. Puoi quindi comprimere il tensore risultante per renderlo di nuovo unidimensionale e assegnarlo a una variabile chiamata prediction
.
Con questo prediction
puoi trovare l'indice con il valore più alto utilizzando argMax()
e quindi convertirlo in un array utilizzando arraySync()
per ottenere i dati sottostanti in JavaScript e scoprire la posizione dell'elemento di valore più elevato. Questo valore viene memorizzato nella variabile denominata highestIndex
.
Puoi anche ottenere i punteggi effettivi di confidenza della previsione allo stesso modo chiamando direttamente arraySync()
sul tensore prediction
.
Ora hai tutto il necessario per aggiornare il testo STATUS
con i dati di prediction
. Per ottenere la stringa leggibile per la classe, puoi cercare highestIndex
nell'array CLASS_NAMES
, quindi recuperare il valore di confidenza da predictionArray
. Per rendere il testo più leggibile in termini percentuali, è sufficiente moltiplicare per 100 e math.floor()
il risultato.
Infine, puoi utilizzare window.requestAnimationFrame()
per chiamare predictionLoop()
di nuovo quando è tutto pronto, in modo da ottenere la classificazione in tempo reale sul tuo video stream. Questo procedimento continua fino a quando predict
non viene impostato su false
se scegli di addestrare un nuovo modello con nuovi dati.
Questo ti porta all'ultimo pezzo del puzzle. Implementazione del pulsante di ripristino.
15. Implementare il pulsante di ripristino
Hai quasi finito. L'ultimo pezzo del puzzle consiste nell'implementare un pulsante di ripristino per ricominciare. Di seguito è riportato il codice per la funzione reset()
attualmente vuota. Procedi e aggiornala come segue:
script.js
/**
* Purge data and start over. Note this does not dispose of the loaded
* MobileNet model and MLP head tensors as you will need to reuse
* them to train a new model.
**/
function reset() {
predict = false;
examplesCount.length = 0;
for (let i = 0; i < trainingDataInputs.length; i++) {
trainingDataInputs[i].dispose();
}
trainingDataInputs.length = 0;
trainingDataOutputs.length = 0;
STATUS.innerText = 'No data collected';
console.log('Tensors in memory: ' + tf.memory().numTensors);
}
Innanzitutto, interrompi eventuali loop di previsione in esecuzione impostando predict
su false
. Successivamente, elimina tutti i contenuti nell'array examplesCount
impostandone la lunghezza su 0, in modo da poter cancellare tutti i contenuti da un array.
Ora esamina tutti i trainingDataInputs
registrati attuali e assicurati di dispose()
di ogni tensore contenuto al suo interno per liberare di nuovo la memoria, poiché i tensori non vengono ripuliti dal garbage collector di JavaScript.
Al termine, puoi tranquillamente impostare la lunghezza dell'array su 0 su entrambi gli array trainingDataInputs
e trainingDataOutputs
per cancellarli anche questi.
Infine, imposta il testo STATUS
su qualcosa di ragionevole e stampa i tensori rimasti in memoria per verificare lo stato di integrità.
Tieni presente che ci saranno ancora alcune centinaia di tensori in memoria, dato che sia il modello MobileNet che il perceptrone multistrato che hai definito non sono stati eliminati. Se decidi di ripetere l'addestramento dopo il ripristino, dovrai riutilizzarli con i nuovi dati di addestramento.
16. Proviamo
È il momento di provare la tua versione di Teachable Machine.
Vai all’anteprima live, attiva la webcam, raccogli almeno 30 campioni per la classe 1 di un oggetto nella tua stanza, quindi ripeti lo stesso per un oggetto diverso per la classe 2, fai clic su Addestra e controlla il registro della console per vedere i progressi. Dovrebbe essere addestrato abbastanza rapidamente:
Al termine dell'addestramento, mostra gli oggetti alla videocamera per ottenere le previsioni in tempo reale che verranno stampate nell'area di testo dello stato sulla pagina web in alto. Se hai difficoltà, controlla il codice funzionante completato per vedere se hai perso qualche copia.
17. Complimenti
Complimenti! Hai appena completato il tuo primo esempio di transfer learning utilizzando TensorFlow.js live nel browser.
Provala, testala su una varietà di oggetti e potresti notare che alcune cose sono più difficili da riconoscere di altre, soprattutto se sono simili ad altre. Potresti dover aggiungere altri corsi o dati di addestramento per poterli distinguere.
Riepilogo
In questo codelab hai appreso:
- Che cos'è il Transfer Learning e i suoi vantaggi rispetto all'addestramento di un modello completo.
- Come recuperare modelli da riutilizzare da TensorFlow Hub.
- Come configurare un'app web adatta al Transfer Learning.
- Come caricare e utilizzare un modello di base per generare caratteristiche immagine.
- Come addestrare una nuova testa di previsione in grado di riconoscere gli oggetti personalizzati dalle immagini della webcam.
- Come utilizzare i modelli risultanti per classificare i dati in tempo reale.
Passaggi successivi
Ora che hai una base operativa, quali idee creative puoi trovare per ampliare questo modello boilerplate di machine learning per un caso d'uso reale a cui potresti lavorare? Forse potresti rivoluzionare il settore in cui lavori per aiutare le persone della tua azienda ad addestrare i modelli a classificare gli aspetti importanti nel loro lavoro quotidiano? Le possibilità sono infinite.
Per andare oltre, prendi in considerazione di seguire questo corso completo senza costi, che ti mostra come combinare i due modelli attualmente presenti in questo codelab in un unico modello per ottenere maggiore efficienza.
Inoltre, se vuoi saperne di più sulla teoria alla base dell'applicazione Teachable Machine originale, segui questo tutorial.
Condividi con noi ciò che crei
Puoi facilmente estendere ciò che hai realizzato oggi anche per altri casi d'uso creativi. Ti incoraggiamo a pensare in modo originale e a continuare ad hackerare.
Ricordati di taggarci sui social media utilizzando l'hashtag #MadeWithTFJS per avere la possibilità che il tuo progetto venga inserito nel blog di TensorFlow o addirittura eventi futuri. Ci piacerebbe vedere cosa realizzi.
Siti web da controllare
- Sito web ufficiale TensorFlow.js
- Modelli predefiniti TensorFlow.js
- API TensorFlow.js
- TensorFlow.js Show & Racconta: lasciati ispirare e scopri cosa hanno realizzato gli altri.