TensorFlow.js: créer votre propre machine Teachable Machine à l'aide de l'apprentissage par transfert avec TensorFlow.js

1. Avant de commencer

L'utilisation des modèles TensorFlow.js a augmenté de manière exponentielle au cours des dernières années. De nombreux développeurs JavaScript cherchent désormais à prendre des modèles de pointe existants et à les réentraîner pour qu'ils fonctionnent avec des données personnalisées propres à leur secteur. L'apprentissage par transfert consiste à prendre un modèle existant (souvent appelé modèle de base) et à l'utiliser dans un domaine similaire, mais différent.

L'apprentissage par transfert présente de nombreux avantages par rapport à la création d'un modèle entièrement vierge. Vous pouvez réutiliser les connaissances déjà acquises à partir d'un modèle entraîné précédemment et vous avez besoin de moins d'exemples du nouvel élément que vous souhaitez classer. De plus, l'entraînement est souvent beaucoup plus rapide, car il ne faut réentraîner que les dernières couches de l'architecture du modèle au lieu de l'ensemble du réseau. C'est pourquoi l'apprentissage par transfert est particulièrement adapté à l'environnement du navigateur Web, où les ressources peuvent varier en fonction de l'appareil d'exécution, mais où l'accès direct aux capteurs permet d'acquérir facilement des données.

Cet atelier de programmation vous explique comment créer une application Web à partir d'une page vierge, en recréant le site Web populaire Teachable Machine de Google. Le site Web vous permet de créer une application Web fonctionnelle que n'importe quel utilisateur peut utiliser pour reconnaître un objet personnalisé avec seulement quelques exemples d'images provenant de sa webcam. Le site Web est volontairement minimaliste pour que vous puissiez vous concentrer sur les aspects de machine learning de cet atelier de programmation. Toutefois, comme pour le site Web Teachable Machine d'origine, vous pouvez mettre à profit votre expérience de développeur Web pour améliorer l'UX.

Prérequis

Cet atelier de programmation s'adresse aux développeurs Web qui connaissent déjà un peu les modèles prédéfinis TensorFlow.js et l'utilisation de base de l'API, et qui souhaitent se lancer dans l'apprentissage par transfert dans TensorFlow.js.

  • Cet atelier suppose que vous connaissez les bases de TensorFlow.js, HTML5, CSS et JavaScript.

Si vous débutez avec TensorFlow.js, suivez d'abord ce cours sans frais pour débutants. Il ne nécessite aucune connaissance préalable du machine learning ni de TensorFlow.js, et vous apprend tout ce que vous devez savoir en plusieurs étapes.

Points abordés

  • Découvrez ce qu'est TensorFlow.js et pourquoi vous devriez l'utiliser dans votre prochaine application Web.
  • Découvrez comment créer une page Web HTML/CSS /JS simplifiée qui reproduit l'expérience utilisateur de Teachable Machine.
  • Utiliser TensorFlow.js pour charger un modèle de base pré-entraîné, en particulier MobileNet, afin de générer des caractéristiques d'image pouvant être utilisées dans l'apprentissage par transfert.
  • Comment collecter des données à partir de la webcam d'un utilisateur pour plusieurs classes de données que vous souhaitez reconnaître.
  • Découvrez comment créer et définir un perceptron multicouche qui prend les caractéristiques de l'image et apprend à classer de nouveaux objets à l'aide de celles-ci.

C'est parti pour le piratage !

Prérequis

  • Pour suivre ce tutoriel, nous vous recommandons d'utiliser un compte Glitch.com. Vous pouvez également utiliser un environnement de service Web que vous pouvez modifier et exécuter vous-même.

2. Qu'est-ce que TensorFlow.js ?

54e81d02971f53e8.png

TensorFlow.js est une bibliothèque de machine learning Open Source qui peut s'exécuter partout où JavaScript le peut. Elle est basée sur la bibliothèque TensorFlow d'origine écrite en Python et vise à recréer cette expérience de développement et cet ensemble d'API pour l'écosystème JavaScript.

Où peut-on l'utiliser ?

Grâce à la portabilité de JavaScript, vous pouvez désormais écrire dans une seule langue et effectuer du machine learning sur toutes les plates-formes suivantes en toute simplicité :

  • Côté client dans le navigateur Web à l'aide de JavaScript vanilla
  • Côté serveur et même sur des appareils IoT comme Raspberry Pi à l'aide de Node.js
  • Applications de bureau utilisant Electron
  • Applications mobiles natives utilisant React Native

TensorFlow.js est également compatible avec plusieurs backends dans chacun de ces environnements (les environnements matériels réels dans lesquels il peut s'exécuter, tels que le processeur ou WebGL, par exemple). Dans ce contexte, le terme "backend" ne désigne pas un environnement côté serveur (le backend d'exécution peut être côté client dans WebGL, par exemple) pour assurer la compatibilité et maintenir la rapidité d'exécution. TensorFlow.js est actuellement compatible avec :

  • Exécution WebGL sur la carte graphique (GPU) de l'appareil : il s'agit du moyen le plus rapide d'exécuter des modèles plus volumineux (de plus de 3 Mo) avec l'accélération GPU.
  • Exécution WebAssembly (WASM) sur le processeur : pour améliorer les performances du processeur sur les appareils, y compris les téléphones mobiles de l'ancienne génération, par exemple. Cette approche est mieux adaptée aux modèles plus petits (moins de 3 Mo), qui peuvent en fait s'exécuter plus rapidement sur le processeur avec WASM qu'avec WebGL en raison de la surcharge liée à l'importation de contenu dans un processeur graphique.
  • Exécution du CPU : solution de repli si aucun autre environnement n'est disponible. Il s'agit de la plus lente des trois, mais elle est toujours disponible.

Remarque : Vous pouvez choisir d'imposer l'un de ces backends si vous savez sur quel appareil vous allez exécuter le code, ou vous pouvez simplement laisser TensorFlow.js décider pour vous si vous ne le spécifiez pas.

Super-pouvoirs côté client

L'exécution de TensorFlow.js dans le navigateur Web sur la machine cliente peut présenter plusieurs avantages à prendre en compte.

Confidentialité

Vous pouvez entraîner et classer des données sur la machine cliente sans jamais envoyer de données à un serveur Web tiers. Il peut arriver que cela soit nécessaire pour se conformer aux lois locales, comme le RGPD par exemple, ou lors du traitement de données que l'utilisateur peut souhaiter conserver sur sa machine et ne pas envoyer à un tiers.

Speed

Comme vous n'avez pas à envoyer de données à un serveur distant, l'inférence (l'acte de classification des données) peut être plus rapide. Mieux encore, vous avez un accès direct aux capteurs de l'appareil (caméra, micro, GPS, accéléromètre, etc.) si l'utilisateur vous y autorise.

Couverture et évolutivité

En un clic, n'importe qui dans le monde peut cliquer sur un lien que vous lui envoyez, ouvrir la page Web dans son navigateur et utiliser ce que vous avez créé. Vous n'avez pas besoin d'une configuration Linux complexe côté serveur avec des pilotes CUDA et bien plus encore pour utiliser le système de machine learning.

Coût

Sans serveur, vous n'avez besoin que d'un CDN pour héberger vos fichiers HTML, CSS, JS et de modèle. Le coût d'un CDN est beaucoup moins élevé que celui d'un serveur fonctionnant 24h/24 et 7j/7 (potentiellement avec une carte graphique).

Fonctionnalités côté serveur

L'utilisation de l'implémentation Node.js de TensorFlow.js permet les fonctionnalités suivantes.

Prise en charge complète de CUDA

Côté serveur, pour l'accélération de la carte graphique, vous devez installer les pilotes NVIDIA CUDA pour permettre à TensorFlow de fonctionner avec la carte graphique (contrairement au navigateur qui utilise WebGL et ne nécessite aucune installation). Toutefois, avec la prise en charge complète de CUDA, vous pouvez exploiter pleinement les capacités de bas niveau de la carte graphique, ce qui permet d'accélérer les temps d'entraînement et d'inférence. Les performances sont équivalentes à celles de l'implémentation Python de TensorFlow, car elles partagent le même backend C++.

Taille du modèle

Pour les modèles de pointe issus de la recherche, vous pouvez travailler avec des modèles très volumineux, peut-être de plusieurs gigaoctets. Ces modèles ne peuvent pas être exécutés dans le navigateur Web pour le moment en raison des limites d'utilisation de la mémoire par onglet de navigateur. Pour exécuter ces modèles plus volumineux, vous pouvez utiliser Node.js sur votre propre serveur avec les spécifications matérielles requises pour exécuter efficacement un tel modèle.

IOT

Node.js est compatible avec les ordinateurs monocartes populaires tels que Raspberry Pi, ce qui signifie que vous pouvez également exécuter des modèles TensorFlow.js sur ces appareils.

Speed

Node.js est écrit en JavaScript, ce qui signifie qu'il bénéficie de la compilation juste-à-temps. Cela signifie que vous pouvez souvent constater des améliorations de performances lorsque vous utilisez Node.js, car il sera optimisé au moment de l'exécution, en particulier pour tout prétraitement que vous pourriez effectuer. Un excellent exemple est présenté dans cette étude de cas, qui montre comment Hugging Face a utilisé Node.js pour doubler les performances de son modèle de traitement du langage naturel.

Maintenant que vous connaissez les bases de TensorFlow.js, où il peut s'exécuter et certains de ses avantages, commençons à faire des choses utiles avec !

3. Apprentissage par transfert

Qu'est-ce que l'apprentissage par transfert ?

L'apprentissage par transfert consiste à utiliser des connaissances déjà acquises pour en apprendre d'autres, différentes mais similaires.

Nous, les humains, le faisons tout le temps. Votre cerveau contient une vie entière d'expériences que vous pouvez utiliser pour vous aider à reconnaître de nouvelles choses que vous n'avez jamais vues auparavant. Prenons cet exemple de saule pleureur :

e28070392cd4afb9.png

Selon l'endroit où vous vous trouvez dans le monde, il est possible que vous n'ayez jamais vu ce type d'arbre auparavant.

Pourtant, si je vous demande de me dire s'il y a des saules dans la nouvelle image ci-dessous, vous pourrez probablement les repérer assez rapidement, même s'ils sont sous un angle différent et légèrement différents de ceux de l'image d'origine que je vous ai montrée.

d9073a0d5df27222.png

Vous avez déjà un tas de neurones dans votre cerveau qui savent identifier les objets en forme d'arbre, et d'autres qui sont doués pour trouver les longues lignes droites. Vous pouvez réutiliser ces connaissances pour classer rapidement un saule pleureur, qui est un objet en forme d'arbre avec de nombreuses branches verticales longues et droites.

De même, si vous disposez d'un modèle de machine learning déjà entraîné sur un domaine, comme la reconnaissance d'images, vous pouvez le réutiliser pour effectuer une tâche différente, mais connexe.

Vous pouvez faire de même avec un modèle avancé comme MobileNet, qui est un modèle de recherche très populaire capable de reconnaître des images de 1 000 types d'objets différents. Il a été entraîné sur un vaste ensemble de données appelé ImageNet, qui contient des millions d'images étiquetées (de chiens, de voitures, etc.).

Dans cette animation, vous pouvez voir le grand nombre de couches qu'il contient dans ce modèle MobileNet V1 :

7d4e1e35c1a89715.gif

Au cours de son entraînement, ce modèle a appris à extraire les caractéristiques communes importantes pour l'ensemble de ces 1 000 objets. De nombreuses caractéristiques de niveau inférieur qu'il utilise pour identifier ces objets peuvent également être utiles pour détecter de nouveaux objets qu'il n'a jamais vus auparavant. Après tout, tout n'est qu'une combinaison de lignes, de textures et de formes.

Examinons une architecture de réseau de neurones convolutifs (CNN, Convolutional Neural Network) traditionnelle (semblable à MobileNet) et voyons comment l'apprentissage par transfert peut tirer parti de ce réseau entraîné pour apprendre quelque chose de nouveau. L'image ci-dessous montre l'architecture de modèle typique d'un CNN qui, dans ce cas, a été entraîné à reconnaître les chiffres manuscrits de 0 à 9 :

baf4e3d434576106.png

Si vous pouviez séparer les couches de niveau inférieur pré-entraînées d'un modèle entraîné existant, comme indiqué à gauche, des couches de classification proches de la fin du modèle, comme indiqué à droite (parfois appelées "tête de classification" du modèle), vous pourriez utiliser les couches de niveau inférieur pour générer des caractéristiques de sortie pour n'importe quelle image en fonction des données d'origine sur lesquelles elle a été entraînée. Voici le même réseau sans l'en-tête de classification :

369a8a9041c6917d.png

Si la nouvelle chose que vous essayez de reconnaître peut également utiliser les caractéristiques de sortie que le modèle a apprises, il y a de fortes chances qu'elles puissent être réutilisées pour un nouvel objectif.

Dans le diagramme ci-dessus, ce modèle hypothétique a été entraîné sur des chiffres. Il est donc possible que ce qui a été appris sur les chiffres puisse également être appliqué à des lettres comme a, b et c.

Vous pouvez donc ajouter un nouvel en-tête de classification qui tente de prédire a, b ou c, comme indiqué ci-dessous :

db97e5e60ae73bbd.png

Ici, les couches de niveau inférieur sont figées et ne sont pas entraînées. Seule la nouvelle tête de classification se mettra à jour pour apprendre à partir des caractéristiques fournies par le modèle pré-entraîné découpé à gauche.

Cette action est appelée "apprentissage par transfert" et c'est ce que fait Teachable Machine en arrière-plan.

Vous pouvez également constater qu'en n'ayant à entraîner le perceptron multicouche qu'à la toute fin du réseau, l'entraînement est beaucoup plus rapide que si vous deviez entraîner l'ensemble du réseau à partir de zéro.

Mais comment obtenir des sous-parties d'un modèle ? Pour le savoir, passez à la section suivante.

4. TensorFlow Hub : modèles de base

Trouver un modèle de base adapté à utiliser

Pour les modèles de recherche plus avancés et populaires comme MobileNet, vous pouvez accéder à TensorFlow Hub, puis filtrer les modèles adaptés à TensorFlow.js qui utilisent l'architecture MobileNet v3 pour trouver des résultats comme ceux affichés ici :

c5dc1420c6238c14.png

Notez que certains de ces résultats sont de type "classification d'images" (détaillés en haut à gauche de chaque fiche de modèle), tandis que d'autres sont de type "vecteur de caractéristiques d'image".

Ces résultats de vecteur de caractéristiques d'image sont essentiellement les versions pré-découpées de MobileNet que vous pouvez utiliser pour obtenir les vecteurs de caractéristiques d'image au lieu de la classification finale.

Les modèles de ce type sont souvent appelés "modèles de base". Vous pouvez ensuite les utiliser pour effectuer un apprentissage par transfert de la même manière que dans la section précédente, en ajoutant une nouvelle tête de classification et en l'entraînant avec vos propres données.

Ensuite, vérifiez le format TensorFlow.js dans lequel le modèle de base qui vous intéresse est publié. Si vous ouvrez la page de l'un de ces modèles MobileNet v3 de vecteur de caractéristiques, vous pouvez voir dans la documentation JS qu'il s'agit d'un modèle de graphique basé sur l'exemple d'extrait de code de la documentation qui utilise tf.loadGraphModel().

f97d903d2e46924b.png

Il convient également de noter que si vous trouvez un modèle au format de calques au lieu du format de graphique, vous pouvez choisir les calques à figer et ceux à dégeler pour l'entraînement. Cela peut être très utile lors de la création d'un modèle pour une nouvelle tâche, souvent appelée "modèle de transfert". Pour l'instant, vous allez utiliser le type de modèle de graphique par défaut pour ce tutoriel, qui est celui utilisé pour la plupart des modèles TF Hub. Pour en savoir plus sur l'utilisation des modèles de calques, consultez le cours Zero to Hero TensorFlow.js.

Avantages de l'apprentissage par transfert

Quels sont les avantages de l'apprentissage par transfert par rapport à l'entraînement de l'ensemble de l'architecture du modèle à partir de zéro ?

Tout d'abord, le temps d'entraînement est un avantage clé de l'approche d'apprentissage par transfert, car vous disposez déjà d'un modèle de base entraîné sur lequel vous appuyer.

Deuxièmement, vous pouvez vous contenter de montrer beaucoup moins d'exemples de la nouvelle chose que vous essayez de classer, grâce à l'entraînement qui a déjà eu lieu.

C'est très utile si vous disposez de peu de temps et de ressources pour collecter des exemples de données de l'élément que vous souhaitez classer, et que vous devez créer un prototype rapidement avant de collecter davantage de données d'entraînement pour le rendre plus robuste.

Étant donné que l'apprentissage par transfert nécessite moins de données et que l'entraînement d'un réseau plus petit est plus rapide, il est moins gourmand en ressources. Il est donc très adapté à l'environnement du navigateur, ne prenant que quelques dizaines de secondes sur une machine moderne au lieu de plusieurs heures, jours ou semaines pour l'entraînement complet du modèle.

Très bien ! Maintenant que vous connaissez l'essence de l'apprentissage par transfert, il est temps de créer votre propre version de Teachable Machine. C'est parti !

5. Se préparer à coder

Prérequis

Commençons à coder

Des modèles récurrents ont été créés pour Glitch.com ou Codepen.io. Vous pouvez simplement cloner l'un des modèles comme état de base pour cet atelier de programmation, en un seul clic.

Sur Glitch, cliquez sur le bouton Remix this (Remixer) pour créer un fork et un nouvel ensemble de fichiers que vous pouvez modifier.

Vous pouvez également cliquer sur fork en bas à droite de l'écran sur CodePen.

Ce squelette très simple vous fournit les fichiers suivants :

  • Page HTML (index.html)
  • Feuille de style (style.css)
  • Fichier pour écrire notre code JavaScript (script.js)

Pour plus de commodité, une importation a été ajoutée au fichier HTML pour la bibliothèque TensorFlow.js. Elle se présente comme suit :

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

Autre option : utiliser l'éditeur Web de votre choix ou travailler en local

Si vous souhaitez télécharger le code et travailler en local ou sur un autre éditeur en ligne, il vous suffit de créer les trois fichiers mentionnés ci-dessus dans le même répertoire, puis de copier et coller le code de notre boilerplate Glitch dans chacun d'eux.

6. Boilerplate HTML de l'application

Par où commencer ?

Tous les prototypes nécessitent une structure HTML de base sur laquelle vous pouvez afficher vos résultats. Configurez-le maintenant. Vous allez ajouter :

  • Titre de la page.
  • Texte descriptif.
  • Paragraphe d'état.
  • Une vidéo pour contenir le flux de la webcam une fois qu'il est prêt.
  • Plusieurs boutons pour démarrer la caméra, collecter des données ou réinitialiser l'expérience.
  • Importations pour TensorFlow.js et les fichiers JS que vous coderez plus tard.

Ouvrez index.html et collez le code suivant sur le code existant pour configurer les fonctionnalités ci-dessus :

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

Ça va breaker !

Décomposons une partie du code HTML ci-dessus pour mettre en évidence certains éléments clés que vous avez ajoutés.

  • Vous avez ajouté une balise <h1> pour le titre de la page, ainsi qu'une balise <p> avec l'ID "status", qui est l'endroit où vous imprimerez les informations lorsque vous utiliserez différentes parties du système pour afficher les résultats.
  • Vous avez ajouté un élément <video> avec l'ID "webcam", dans lequel vous afficherez le flux de votre webcam ultérieurement.
  • Vous avez ajouté cinq éléments <button>. La première, dont l'ID est "enableCam", active la caméra. Les deux boutons suivants ont une classe "dataCollector", qui vous permet de collecter des exemples d'images pour les objets que vous souhaitez reconnaître. Le code que vous écrirez plus tard sera conçu de manière à ce que vous puissiez ajouter n'importe quel nombre de ces boutons et qu'ils fonctionnent automatiquement comme prévu.

Notez que ces boutons disposent également d'un attribut spécial défini par l'utilisateur appelé "data-1hot", avec une valeur entière commençant par 0 pour la première classe. Il s'agit de l'index numérique que vous utiliserez pour représenter les données d'une classe donnée. L'index sera utilisé pour encoder correctement les classes de sortie avec une représentation numérique au lieu d'une chaîne, car les modèles de ML ne peuvent fonctionner qu'avec des nombres.

Il existe également un attribut data-name qui contient le nom lisible que vous souhaitez utiliser pour cette classe. Cela vous permet de fournir à l'utilisateur un nom plus significatif au lieu d'une valeur d'index numérique issue de l'encodage one-hot.

Enfin, vous disposez d'un bouton d'entraînement et d'un bouton de réinitialisation pour lancer le processus d'entraînement une fois les données collectées ou pour réinitialiser l'application, respectivement.

  • Vous avez également ajouté deux importations <script>. L'un pour TensorFlow.js et l'autre pour script.js, que vous définirez sous peu.

7. Ajouter style

Valeurs par défaut des éléments

Ajoutez des styles pour les éléments HTML que vous venez d'ajouter afin de vous assurer qu'ils s'affichent correctement. Voici quelques styles ajoutés pour positionner et dimensionner correctement les éléments. Rien de bien spécial. Vous pourrez certainement ajouter des éléments plus tard pour améliorer encore l'UX, comme vous l'avez vu dans la vidéo sur Teachable Machine.

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

Parfait ! C'est tout ce dont vous avez besoin. Si vous prévisualisez le résultat maintenant, il devrait ressembler à ceci :

81909685d7566dcb.png

8. JavaScript : constantes et écouteurs de touches

Définir les constantes clés

Commencez par ajouter des constantes clés que vous utiliserez dans toute l'application. Pour ce faire, remplacez le contenu de script.js par ces constantes :

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

Voici à quoi servent ces éléments :

  • STATUS contient simplement une référence à la balise de paragraphe dans laquelle vous allez écrire les informations sur l'état.
  • VIDEO contient une référence à l'élément vidéo HTML qui affichera le flux de la webcam.
  • ENABLE_CAM_BUTTON, RESET_BUTTON et TRAIN_BUTTON récupèrent les références DOM de tous les boutons clés de la page HTML.
  • MOBILE_NET_INPUT_WIDTH et MOBILE_NET_INPUT_HEIGHT définissent respectivement la largeur et la hauteur d'entrée attendues du modèle MobileNet. En stockant cette valeur dans une constante en haut du fichier, vous pourrez la modifier plus facilement si vous décidez d'utiliser une autre version ultérieurement.
  • STOP_DATA_GATHER est défini sur -1. Cela permet de stocker une valeur d'état pour savoir quand l'utilisateur a cessé de cliquer sur un bouton afin de collecter des données à partir du flux de la webcam. En donnant à ce nombre un nom plus explicite, le code sera plus lisible par la suite.
  • CLASS_NAMES sert de référence et contient les noms lisibles des prédictions de classe possibles. Ce tableau sera rempli ultérieurement.

Maintenant que vous avez des références aux éléments clés, il est temps d'associer des écouteurs d'événements à ces éléments.

Ajouter des écouteurs d'événements clés

Commencez par ajouter des gestionnaires d'événements de clic aux boutons clés, comme indiqué ci-dessous :

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON appelle la fonction enableCam lorsqu'il est cliqué.

TRAIN_BUTTON : appelle trainAndPredict en cas de clic.

RESET_BUTTON : les appels sont réinitialisés lorsque l'utilisateur clique dessus.

Enfin, dans cette section, vous pouvez trouver tous les boutons dont la classe est "dataCollector" à l'aide de document.querySelectorAll(). Cela renvoie un tableau d'éléments trouvés dans le document qui correspondent à :

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

Explication du code :

Vous parcourez ensuite les boutons trouvés et associez deux écouteurs d'événements à chacun d'eux. Une pour "mousedown" et une pour "mouseup". Cela vous permet de continuer à enregistrer des échantillons tant que le bouton est enfoncé, ce qui est utile pour la collecte de données.

Les deux événements appellent une fonction gatherDataForClass que vous définirez ultérieurement.

À ce stade, vous pouvez également transférer les noms de classe lisibles trouvés à partir de l'attribut data-name du bouton HTML vers le tableau CLASS_NAMES.

Ensuite, ajoutez des variables pour stocker les éléments clés qui seront utilisés ultérieurement.

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

Examinons-les.

Tout d'abord, vous disposez d'une variable mobilenet pour stocker le modèle Mobilenet chargé. Définissez initialement cette valeur sur "undefined".

Ensuite, vous avez une variable appelée gatherDataState. Si un bouton "dataCollector" est enfoncé, l'ID devient l'ID "one-hot" de ce bouton, tel que défini dans le code HTML. Vous savez ainsi quelle classe de données vous collectez à ce moment-là. Au départ, cette valeur est définie sur STOP_DATA_GATHER afin que la boucle de collecte de données que vous écrirez plus tard ne collecte aucune donnée lorsqu'aucun bouton n'est enfoncé.

videoPlaying permet de savoir si le flux de la webcam est correctement chargé et en cours de lecture, et s'il est disponible. Au départ, cette valeur est définie sur false, car la webcam n'est pas activée tant que vous n'appuyez pas sur ENABLE_CAM_BUTTON..

Ensuite, définissez deux tableaux, trainingDataInputs et trainingDataOutputs. Ils stockent les valeurs des données d'entraînement collectées lorsque vous cliquez sur les boutons "dataCollector" pour les caractéristiques d'entrée générées par le modèle de base MobileNet et la classe de sortie échantillonnée respectivement.

Un dernier tableau, examplesCount,, est ensuite défini pour suivre le nombre d'exemples contenus dans chaque classe une fois que vous commencez à les ajouter.

Enfin, vous disposez d'une variable appelée predict qui contrôle votre boucle de prédiction. Cette valeur est initialement définie sur false. Aucune prédiction ne peut avoir lieu tant que cette valeur n'est pas définie sur true ultérieurement.

Maintenant que toutes les variables clés ont été définies, chargeons le modèle de base MobileNet v3 prédécoupé qui fournit des vecteurs de caractéristiques d'image au lieu de classifications.

9. Charger le modèle de base MobileNet

Tout d'abord, définissez une nouvelle fonction appelée loadMobileNetFeatureModel, comme indiqué ci-dessous. Il doit s'agir d'une fonction asynchrone, car le chargement d'un modèle est asynchrone :

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

Dans ce code, vous définissez URL où se trouve le modèle à charger à partir de la documentation TFHub.

Vous pouvez ensuite charger le modèle à l'aide de await tf.loadGraphModel(), en veillant à définir la propriété spéciale fromTFHub sur true, car vous chargez un modèle à partir de ce site Web Google. Il s'agit d'un cas particulier uniquement pour l'utilisation de modèles hébergés sur TF Hub, où cette propriété supplémentaire doit être définie.

Une fois le chargement terminé, vous pouvez définir le innerText de l'élément STATUS avec un message pour vérifier visuellement qu'il a été chargé correctement et que vous êtes prêt à commencer à collecter des données.

Il ne reste plus qu'à réchauffer le modèle. Avec les modèles plus volumineux comme celui-ci, la première fois que vous l'utilisez, il peut falloir un certain temps pour tout configurer. Il est donc utile de transmettre des zéros au modèle pour éviter toute attente à l'avenir, lorsque le timing peut être plus critique.

Vous pouvez utiliser tf.zeros() enveloppé dans un tf.tidy() pour vous assurer que les Tensors sont correctement éliminés, avec une taille de lot de 1 et la hauteur et la largeur correctes que vous avez définies dans vos constantes au début. Enfin, vous spécifiez également les canaux de couleur, qui sont au nombre de trois dans ce cas, car le modèle attend des images RVB.

Ensuite, enregistrez la forme du Tensor renvoyé à l'aide de answer.shape() pour vous aider à comprendre la taille des caractéristiques d'image produites par ce modèle.

Après avoir défini cette fonction, vous pouvez l'appeler immédiatement pour lancer le téléchargement du modèle lors du chargement de la page.

Si vous affichez votre aperçu en direct maintenant, vous verrez au bout de quelques instants le texte d'état passer de "Awaiting TF.js load" (En attente du chargement de TF.js) à "MobileNet v3 loaded successfully!" (MobileNet v3 chargé avec succès !) comme indiqué ci-dessous. Assurez-vous que cela fonctionne avant de continuer.

a28b734e190afff.png

Vous pouvez également consulter la sortie de la console pour connaître la taille d'impression des entités de sortie générées par ce modèle. Après avoir exécuté des zéros dans le modèle MobileNet, vous verrez une forme [1, 1024] s'afficher. Le premier élément correspond simplement à la taille du lot (1). Vous pouvez constater qu'il renvoie en fait 1 024 caractéristiques qui peuvent ensuite vous aider à classer de nouveaux objets.

10. Définir la nouvelle tête du modèle

Il est maintenant temps de définir la tête de votre modèle, qui est essentiellement un perceptron multicouche très minimal.

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

Examinons ce code. Vous commencez par définir un modèle tf.sequential auquel vous ajouterez des couches de modèle.

Ensuite, ajoutez une couche dense en tant que couche d'entrée à ce modèle. La forme d'entrée est 1024, car les sorties des caractéristiques MobileNet v3 sont de cette taille. Vous l'avez découvert à l'étape précédente après avoir transmis des uns au modèle. Cette couche comporte 128 neurones qui utilisent la fonction d'activation ReLU.

Si vous n'avez jamais utilisé de fonctions d'activation ni de couches de modèle, nous vous conseillons de suivre le cours présenté au début de cet atelier pour comprendre ce que font ces propriétés en coulisses.

La prochaine couche à ajouter est la couche de sortie. Le nombre de neurones doit être égal au nombre de classes que vous essayez de prédire. Pour ce faire, vous pouvez utiliser CLASS_NAMES.length pour déterminer le nombre de classes que vous prévoyez de classer, ce qui correspond au nombre de boutons de collecte de données présents dans l'interface utilisateur. Comme il s'agit d'un problème de classification, vous utilisez l'activation softmax sur cette couche de sortie, qui doit être utilisée lorsque vous essayez de créer un modèle pour résoudre des problèmes de classification au lieu de problèmes de régression.

Imprimez maintenant un model.summary() pour imprimer l'aperçu du modèle nouvellement défini dans la console.

Enfin, compilez le modèle pour qu'il soit prêt à être entraîné. Ici, l'optimiseur est défini sur adam. La perte sera binaryCrossentropy si CLASS_NAMES.length est égal à 2, ou categoricalCrossentropy s'il y a trois classes ou plus à classer. Des métriques de précision sont également demandées afin de pouvoir les surveiller ultérieurement dans les journaux à des fins de débogage.

Dans la console, vous devriez obtenir un résultat semblable à celui-ci :

22eaf32286fea4bb.png

Notez que ce modèle comporte plus de 130 000 paramètres entraînables. Toutefois, comme il s'agit d'une simple couche dense de neurones classiques, l'entraînement sera assez rapide.

Une fois le projet terminé, vous pouvez essayer de modifier le nombre de neurones de la première couche pour voir jusqu'à quel point vous pouvez le réduire tout en conservant des performances correctes. Le machine learning implique souvent un certain nombre d'essais et d'erreurs pour trouver les valeurs de paramètres optimales qui vous offrent le meilleur compromis entre utilisation des ressources et vitesse.

11. Activer la webcam

Il est maintenant temps de développer la fonction enableCam() que vous avez définie précédemment. Ajoutez une fonction nommée hasGetUserMedia() comme indiqué ci-dessous, puis remplacez le contenu de la fonction enableCam() précédemment définie par le code correspondant ci-dessous.

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

Commencez par créer une fonction nommée hasGetUserMedia() pour vérifier si le navigateur est compatible avec getUserMedia() en vérifiant l'existence des propriétés clés des API du navigateur.

Dans la fonction enableCam(), utilisez la fonction hasGetUserMedia() que vous venez de définir ci-dessus pour vérifier si elle est compatible. Si ce n'est pas le cas, affichez un avertissement dans la console.

Si votre navigateur le permet, définissez des contraintes pour votre appel getUserMedia(), par exemple en indiquant que vous souhaitez uniquement le flux vidéo, et que vous préférez que la width de la vidéo soit de 640 pixels et la height de 480 pixels. Pourquoi ? Il n'y a pas grand intérêt à obtenir une vidéo plus grande, car elle devrait être redimensionnée à 224 x 224 pixels pour être intégrée au modèle MobileNet. Vous pouvez également économiser des ressources de calcul en demandant une résolution plus petite. La plupart des caméras prennent en charge cette résolution.

Appelez ensuite navigator.mediaDevices.getUserMedia() avec le constraints détaillé ci-dessus, puis attendez que le stream soit renvoyé. Une fois le stream renvoyé, vous pouvez obtenir votre élément VIDEO pour lire le stream en le définissant comme valeur srcObject.

Vous devez également ajouter un eventListener sur l'élément VIDEO pour savoir quand l'élément stream a été chargé et est en cours de lecture.

Une fois le flux chargé, vous pouvez définir videoPlaying sur "true" et supprimer ENABLE_CAM_BUTTON pour l'empêcher d'être cliqué à nouveau en définissant sa classe sur "removed".

Exécutez maintenant votre code, cliquez sur le bouton "Enable camera" (Activer la caméra), puis autorisez l'accès à la webcam. Si vous effectuez cette opération pour la première fois, vous devriez vous voir dans l'élément vidéo de la page, comme indiqué ci-dessous :

b378eb1affa9b883.png

OK, il est maintenant temps d'ajouter une fonction pour gérer les clics sur le bouton dataCollector.

12. Gestionnaire d'événements du bouton de collecte de données

Il est maintenant temps de remplir votre fonction actuellement vide appelée gatherDataForClass().. C'est ce que vous avez attribué comme fonction de gestionnaire d'événements pour les boutons dataCollector au début de l'atelier de programmation.

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

Tout d'abord, vérifiez l'attribut data-1hot du bouton sur lequel l'utilisateur a cliqué en appelant this.getAttribute() avec le nom de l'attribut (data-1hot dans ce cas) comme paramètre. Comme il s'agit d'une chaîne, vous pouvez ensuite utiliser parseInt() pour la convertir en entier et attribuer ce résultat à une variable nommée classNumber..

Définissez ensuite la variable gatherDataState en conséquence. Si la valeur actuelle de gatherDataState est égale à STOP_DATA_GATHER (que vous avez définie sur -1), cela signifie que vous ne collectez actuellement aucune donnée et qu'un événement mousedown a été déclenché. Définissez gatherDataState sur le classNumber que vous venez de trouver.

Sinon, cela signifie que vous collectez actuellement des données et que l'événement déclenché était un événement mouseup. Vous souhaitez maintenant arrêter de collecter des données pour cette classe. Il vous suffira de le remettre à l'état STOP_DATA_GATHER pour mettre fin à la boucle de collecte de données que vous définirez sous peu.

Enfin, lancez l'appel à dataGatherLoop(),, qui enregistre réellement les données de classe.

13. Collecte des données

Définissez maintenant la fonction dataGatherLoop(). Cette fonction est chargée d'échantillonner les images de la vidéo de la webcam, de les transmettre au modèle MobileNet et de capturer les sorties de ce modèle (les 1 024 vecteurs de caractéristiques).

Il les stocke ensuite avec l'ID gatherDataState du bouton actuellement enfoncé afin que vous sachiez à quelle classe ces données appartiennent.

Voyons cela de plus près :

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n< = 0; n  CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

Vous ne poursuivrez l'exécution de cette fonction que si videoPlaying est défini sur "true", ce qui signifie que la webcam est active, et si gatherDataState n'est pas égal à STOP_DATA_GATHER et qu'un bouton de collecte de données de classe est actuellement enfoncé.

Ensuite, enveloppez votre code dans un tf.tidy() pour éliminer tous les Tensors créés dans le code qui suit. Le résultat de l'exécution de ce code tf.tidy() est stocké dans une variable appelée imageFeatures.

Vous pouvez désormais capturer une image de la webcam VIDEO à l'aide de tf.browser.fromPixels(). Le Tensor résultant contenant les données d'image est stocké dans une variable appelée videoFrameAsTensor.

Ensuite, redimensionnez la variable videoFrameAsTensor pour qu'elle ait la forme appropriée pour l'entrée du modèle MobileNet. Utilisez un appel tf.image.resizeBilinear() avec le Tensor que vous souhaitez remodeler comme premier paramètre, puis une forme qui définit la nouvelle hauteur et la nouvelle largeur, telles que définies par les constantes que vous avez déjà créées. Enfin, définissez "align corners" (aligner les coins) sur "true" en transmettant le troisième paramètre pour éviter tout problème d'alignement lors du redimensionnement. Le résultat de ce redimensionnement est stocké dans une variable appelée resizedTensorFrame.

Notez que ce redimensionnement primitif étire l'image, car l'image de votre webcam est de 640 x 480 pixels, et le modèle a besoin d'une image carrée de 224 x 224 pixels.

Pour les besoins de cette démonstration, cela devrait fonctionner correctement. Toutefois, une fois cet atelier de programmation terminé, vous pouvez essayer de recadrer une image carrée à partir de cette image pour obtenir de meilleurs résultats pour tout système de production que vous pourriez créer ultérieurement.

Ensuite, normalisez les données d'image. Les données d'image sont toujours comprises entre 0 et 255 lorsque vous utilisez tf.browser.frompixels(). Vous pouvez donc simplement diviser resizedTensorFrame par 255 pour vous assurer que toutes les valeurs sont comprises entre 0 et 1, ce qui correspond aux entrées attendues par le modèle MobileNet.

Enfin, dans la section tf.tidy() du code, transmettez ce Tensor normalisé via le modèle chargé en appelant mobilenet.predict(), auquel vous transmettez la version développée de normalizedTensorFrame à l'aide de expandDims() afin qu'il s'agisse d'un lot de 1, car le modèle s'attend à un lot d'entrées pour le traitement.

Une fois le résultat renvoyé, vous pouvez immédiatement appeler squeeze() sur ce résultat pour le réduire à un Tensor 1D, que vous renvoyez ensuite et affectez à la variable imageFeatures qui capture le résultat de tf.tidy().

Maintenant que vous avez les imageFeatures du modèle MobileNet, vous pouvez les enregistrer en les insérant dans le tableau trainingDataInputs que vous avez défini précédemment.

Vous pouvez également enregistrer ce que représente cette entrée en transférant le gatherDataState actuel vers le tableau trainingDataOutputs.

Notez que la variable gatherDataState aurait été définie sur l'ID numérique de la classe actuelle pour laquelle vous enregistrez des données lorsque l'utilisateur a cliqué sur le bouton dans la fonction gatherDataForClass() définie précédemment.

À ce stade, vous pouvez également augmenter le nombre d'exemples dont vous disposez pour une classe donnée. Pour ce faire, vérifiez d'abord si l'index du tableau examplesCount a déjà été initialisé. Si elle n'est pas définie, définissez-la sur 0 pour initialiser le compteur pour l'ID numérique d'une classe donnée. Vous pouvez ensuite incrémenter examplesCount pour le gatherDataState actuel.

Mettez à jour le texte de l'élément STATUS sur la page Web pour afficher le nombre actuel de chaque classe à mesure qu'elles sont capturées. Pour ce faire, parcourez le tableau CLASS_NAMES et affichez le nom lisible par l'utilisateur combiné au nombre de données au même index dans examplesCount.

Enfin, appelez window.requestAnimationFrame() avec dataGatherLoop transmis en tant que paramètre, pour appeler à nouveau cette fonction de manière récursive. L'échantillonnage des images de la vidéo se poursuit jusqu'à ce que le mouseup du bouton soit détecté et que gatherDataState soit défini sur STOP_DATA_GATHER,, auquel cas la boucle de collecte de données se termine.

Si vous exécutez votre code maintenant, vous devriez pouvoir cliquer sur le bouton "Enable camera" (Activer la caméra), attendre que la webcam se charge, puis cliquer de manière prolongée sur chacun des boutons de collecte de données pour collecter des exemples pour chaque classe de données. Vous me voyez ici collecter des données pour mon téléphone mobile et ma main, respectivement.

541051644a45131f.gif

Le texte d'état devrait être mis à jour à mesure que tous les Tensors sont stockés en mémoire, comme indiqué dans la capture d'écran ci-dessus.

14. Entraîner et prédire

L'étape suivante consiste à implémenter le code de votre fonction trainAndPredict() actuellement vide, qui est l'endroit où se déroule l'apprentissage par transfert. Examinons le code :

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

Commencez par vous assurer d'arrêter toute prédiction en cours en définissant predict sur false.

Ensuite, mélangez vos tableaux d'entrée et de sortie à l'aide de tf.util.shuffleCombo() pour vous assurer que l'ordre ne pose pas de problème lors de l'entraînement.

Convertissez votre tableau de sortie, trainingDataOutputs,, en Tensor1d de type int32 afin qu'il soit prêt à être utilisé dans un encodage one-hot. Cette valeur est stockée dans une variable nommée outputsAsTensor.

Utilisez la fonction tf.oneHot() avec cette variable outputsAsTensor, ainsi que le nombre maximal de classes à encoder, qui est simplement CLASS_NAMES.length. Vos sorties encodées en one-hot sont désormais stockées dans un nouveau Tensor appelé oneHotOutputs.

Notez que trainingDataInputs est actuellement un tableau de Tensors enregistrés. Pour les utiliser pour l'entraînement, vous devrez convertir le tableau de Tensors en un Tensor 2D normal.

Pour ce faire, la bibliothèque TensorFlow.js propose une excellente fonction appelée tf.stack().

qui prend un tableau de Tensors et les empile pour produire un Tensor de dimension supérieure en sortie. Dans ce cas, un Tensor 2D est renvoyé. Il s'agit d'un lot d'entrées à une dimension, chacune d'une longueur de 1 024, contenant les caractéristiques enregistrées, ce dont vous avez besoin pour l'entraînement.

Ensuite, await model.fit() pour entraîner l'en-tête du modèle personnalisé. Ici, vous transmettez votre variable inputsAsTensor avec oneHotOutputs pour représenter les données d'entraînement à utiliser pour les entrées et sorties cibles, respectivement. Dans l'objet de configuration du troisième paramètre, définissez shuffle sur true, utilisez batchSize de 5, avec epochs défini sur 10, puis spécifiez un callback pour onEpochEnd dans la fonction logProgress que vous définirez sous peu.

Enfin, vous pouvez supprimer les Tensors créés, car le modèle est désormais entraîné. Vous pouvez ensuite redéfinir predict sur true pour autoriser à nouveau les prédictions, puis appeler la fonction predictLoop() pour commencer à prédire les images de la webcam en direct.

Vous pouvez également définir la fonction logProcess() pour enregistrer l'état de l'entraînement, qui est utilisé dans model.fit() ci-dessus et qui affiche les résultats dans la console après chaque cycle d'entraînement.

Vous y êtes presque ! Il est temps d'ajouter la fonction predictLoop() pour faire des prédictions.

Boucle de prédiction principale

Ici, vous implémentez la boucle de prédiction principale qui échantillonne les images d'une webcam et prédit en continu ce qui se trouve dans chaque image avec des résultats en temps réel dans le navigateur.

Examinons le code :

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

Commencez par vérifier que predict est défini sur "true", afin que les prédictions ne soient effectuées qu'une fois qu'un modèle a été entraîné et est disponible.

Vous pouvez ensuite obtenir les caractéristiques de l'image actuelle, comme vous l'avez fait dans la fonction dataGatherLoop(). En substance, vous récupérez un frame de la webcam à l'aide de tf.browser.from pixels(), vous le normalisez, vous le redimensionnez à 224 x 224 pixels, puis vous transmettez ces données au modèle MobileNet pour obtenir les caractéristiques de l'image.

Toutefois, vous pouvez désormais utiliser la tête de modèle que vous venez d'entraîner pour effectuer une prédiction en transmettant le imageFeatures résultant que vous venez de trouver à la fonction predict() du modèle entraîné. Vous pouvez ensuite compresser le Tensor obtenu pour le rendre à nouveau unidimensionnel et l'attribuer à une variable appelée prediction.

Avec ce prediction, vous pouvez trouver l'index qui a la valeur la plus élevée à l'aide de argMax(), puis convertir ce Tensor résultant en tableau à l'aide de arraySync() pour accéder aux données sous-jacentes en JavaScript et découvrir la position de l'élément ayant la valeur la plus élevée. Cette valeur est stockée dans la variable highestIndex.

Vous pouvez également obtenir les scores de confiance de prédiction réels de la même manière en appelant arraySync() directement sur le Tensor prediction.

Vous disposez maintenant de tout ce dont vous avez besoin pour mettre à jour le texte STATUS avec les données prediction. Pour obtenir la chaîne lisible par l'homme pour la classe, il vous suffit de rechercher highestIndex dans le tableau CLASS_NAMES, puis de récupérer la valeur de confiance à partir de predictionArray. Pour le rendre plus lisible en pourcentage, il vous suffit de le multiplier par 100 et de math.floor() le résultat.

Enfin, vous pouvez utiliser window.requestAnimationFrame() pour appeler predictionLoop() à nouveau une fois que vous êtes prêt, afin d'obtenir une classification en temps réel de votre flux vidéo. Ce processus se poursuit jusqu'à ce que predict soit défini sur false si vous choisissez d'entraîner un nouveau modèle avec de nouvelles données.

Ce qui vous amène à la dernière pièce du puzzle. Implémenter le bouton de réinitialisation.

15. Implémenter le bouton de réinitialisation

C'est presque fini ! La dernière pièce du puzzle consiste à implémenter un bouton de réinitialisation pour recommencer. Le code de votre fonction reset() actuellement vide est indiqué ci-dessous. Modifiez-le comme suit :

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

Commencez par arrêter toutes les boucles de prédiction en cours d'exécution en définissant predict sur false. Ensuite, supprimez tout le contenu du tableau examplesCount en définissant sa longueur sur 0. C'est un moyen pratique d'effacer tout le contenu d'un tableau.

Parcourez ensuite tous les trainingDataInputs enregistrés actuels et assurez-vous de dispose() chaque Tensor qu'il contient pour libérer à nouveau de la mémoire, car les Tensors ne sont pas nettoyés par le récupérateur de mémoire JavaScript.

Une fois cette opération effectuée, vous pouvez définir la longueur du tableau sur 0 pour les tableaux trainingDataInputs et trainingDataOutputs afin de les effacer également.

Enfin, définissez le texte STATUS sur une valeur pertinente et affichez les Tensors restants en mémoire pour vérifier que tout est correct.

Notez que quelques centaines de Tensors seront toujours en mémoire, car le modèle MobileNet et le perceptron multicouche que vous avez définis ne sont pas supprimés. Vous devrez les réutiliser avec de nouvelles données d'entraînement si vous décidez de vous entraîner à nouveau après cette réinitialisation.

16. Essayons-la

Il est temps de tester votre propre version de Teachable Machine.

Accédez à l'aperçu en direct, activez la webcam, collectez au moins 30 échantillons pour la classe 1 pour un objet de votre pièce, puis faites de même pour la classe 2 pour un autre objet. Cliquez sur "Train" (Entraîner), puis consultez le journal de la console pour voir la progression. L'entraînement devrait être assez rapide :

bf1ac3cc5b15740.gif

Une fois l'entraînement terminé, montrez les objets à la caméra pour obtenir des prédictions en direct qui seront imprimées dans la zone de texte d'état en haut de la page Web. Si vous rencontrez des difficultés, vérifiez le code fonctionnel que j'ai terminé pour voir si vous avez oublié de copier quelque chose.

17. Félicitations

Félicitations ! Vous venez de terminer votre tout premier exemple d'apprentissage par transfert à l'aide de TensorFlow.js directement dans le navigateur.

Essayez-le, testez-le sur différents objets. Vous remarquerez peut-être que certains sont plus difficiles à reconnaître que d'autres, surtout s'ils ressemblent à autre chose. Vous devrez peut-être ajouter des classes ou des données d'entraînement pour pouvoir les distinguer.

Résumé

Dans cet atelier de programmation, vous avez appris :

  1. Comprendre en quoi consiste l'apprentissage par transfert et ses avantages par rapport à l'entraînement d'un modèle complet
  2. Obtenir des modèles réutilisables à partir de TensorFlow Hub
  3. Configurer une application Web adaptée à l'apprentissage par transfert.
  4. Charger et utiliser un modèle de base pour générer des caractéristiques d'image
  5. Comment entraîner une nouvelle tête de prédiction capable de reconnaître des objets personnalisés à partir d'images de webcam.
  6. Comment utiliser les modèles obtenus pour classer les données en temps réel.

Et ensuite ?

Maintenant que vous disposez d'une base de travail, quelles idées créatives pouvez-vous mettre en place pour étendre ce déploiement de modèle de machine learning pour un cas d'utilisation réel sur lequel vous travaillez peut-être ? Vous pourriez peut-être révolutionner le secteur dans lequel vous travaillez actuellement pour aider les employés de votre entreprise à entraîner des modèles permettant de classer les éléments importants dans leur travail quotidien. Les possibilités sont infinies.

Pour aller plus loin, suivez ce cours complet et sans frais. Vous y apprendrez à combiner les deux modèles que vous avez actuellement dans cet atelier de programmation en un seul modèle pour plus d'efficacité.

Si vous souhaitez en savoir plus sur la théorie qui sous-tend l'application Teachable Machine d'origine, consultez ce tutoriel.

Partagez vos créations

Vous pouvez facilement étendre ce que vous avez créé aujourd'hui à d'autres cas d'utilisation créatifs. Nous vous encourageons à sortir des sentiers battus et à continuer à innover.

N'oubliez pas de nous taguer sur les réseaux sociaux avec le hashtag #MadeWithTFJS. Votre projet sera peut-être mis en avant sur notre blog TensorFlow ou lors de prochains événements. Nous serions ravis de découvrir vos créations.

Sites Web à consulter