TensorFlow.js: créez votre propre Teachable Machine à l'aide de l'apprentissage par transfert avec TensorFlow.js.

1. Avant de commencer

L'utilisation du modèle TensorFlow.js a développé de manière exponentielle au cours des dernières années, et de nombreux développeurs JavaScript cherchent désormais à réinventer les modèles de pointe pour qu'ils utilisent des données personnalisées uniques. de leur secteur. Utiliser un modèle existant (souvent appelé "modèle de base") et l'utiliser sur un domaine similaire, mais différent, s'appelle l'apprentissage par transfert.

L'apprentissage par transfert présente de nombreux avantages par rapport à un modèle complètement vide. Vous pouvez réutiliser les connaissances déjà apprises d'un modèle entraîné précédemment, et vous aurez besoin de moins d'exemples sur le nouvel élément que vous souhaitez classer. En outre, l'entraînement est souvent beaucoup plus rapide, car il suffit de réentraîner les dernières couches de l'architecture du modèle au lieu de réinventer l'ensemble du réseau. C'est pourquoi l'apprentissage par transfert est particulièrement adapté à un environnement de navigateur Web dans lequel les ressources peuvent varier selon l'appareil utilisé. Il offre également un accès direct aux capteurs et facilite l'acquisition de données.

Cet atelier de programmation vous explique comment créer une application Web à partir d'une toile vierge en recréant le célèbre site Web Teachable Machine de Google. Il permet de créer une application Web fonctionnelle permettant à n'importe quel utilisateur de reconnaître un objet personnalisé avec quelques exemples d'images de sa webcam. Le site Web est volontairement réduit afin que vous puissiez vous concentrer sur les aspects du machine learning de cet atelier de programmation. Cependant, comme pour le site Web original de Teachable Machine, vous pouvez appliquer votre expérience aux développeurs Web afin d'améliorer l'expérience utilisateur.

Conditions préalables

Cet atelier de programmation s'adresse aux développeurs Web qui connaissent déjà les modèles prédéfinis de TensorFlow.js et l'utilisation de base des API, et qui souhaitent se familiariser avec l'apprentissage par transfert dans TensorFlow.js.

  • Dans cet atelier, nous partons du principe que vous connaissez les bases de TensorFlow.js, HTML5, CSS et JavaScript.

Si vous débutez avec TensorFlow.js, envisagez de suivre ce cours gratuit du héros d'abord, qui n'a aucune expérience en machine learning ou en TensorFlow.js, et qui vous apprend tout ce qu'il faut savoir en quelques étapes.

Points abordés

  • En quoi consiste TensorFlow.js et pourquoi l'utiliser dans votre prochaine application Web
  • Découvrez comment créer une page Web HTML/CSS /JS simplifiée qui reproduit l'expérience utilisateur de Teachable Machine.
  • Utiliser TensorFlow.js pour charger un modèle de base pré-entraîné, en particulier MobileNet, afin de générer des caractéristiques d'images à utiliser dans l'apprentissage par transfert.
  • Comment collecter des données à partir de la webcam d'un utilisateur pour plusieurs classes de données que vous souhaitez reconnaître
  • Créer et définir un perceptron multicouche qui prend les caractéristiques de l'image et apprend à classer de nouveaux objets à l'aide de ces éléments.

C'est parti !

Prérequis

  • Un compte Glitch.com est préférable si vous le souhaitez, ou vous pouvez utiliser un environnement de diffusion Web qui vous convient pour le montage et la gestion.

2. Qu'est-ce que TensorFlow.js ?

54e81d02971f53e8.png

TensorFlow.js est une bibliothèque de machine learning Open Source qui peut s'exécuter partout où JavaScript est disponible. Il est basé sur la bibliothèque TensorFlow d'origine écrite en Python et vise à recréer cette expérience de développement et cet ensemble d'API pour l'écosystème JavaScript.

Où puis-je l'utiliser ?

Compte tenu de la portabilité de JavaScript, vous pouvez désormais écrire dans un langage et effectuer des opérations de machine learning sur toutes les plates-formes suivantes en toute simplicité:

  • Côté client dans le navigateur Web à l'aide de vanilla JavaScript
  • Côté serveur, et même les appareils IoT comme Raspberry Pi à l'aide de Node.js
  • les applications de bureau qui utilisent Electron ;
  • Les applications mobiles natives utilisant React Native

TensorFlow.js est également compatible avec plusieurs backends dans chacun de ces environnements (les environnements matériels réels qu'il peut exécuter, comme le processeur ou WebGL). Dans ce contexte, un "backend" ne signifie pas un environnement côté serveur. Le backend pour l'exécution peut, par exemple, du côté client dans WebGL, pour assurer la compatibilité et s'exécuter rapidement. Actuellement, TensorFlow.js est compatible avec les points suivants:

  • Exécution WebGL sur la carte graphique de l'appareil : il s'agit du moyen le plus rapide d'exécuter des modèles plus volumineux (plus de 3 Mo) avec accélération GPU.
  • l'exécution WASM (Web Assembly) sur le processeur pour améliorer les performances du processeur sur tous les appareils, y compris les téléphones mobiles d'ancienne génération. Cette approche convient mieux aux modèles plus petits (moins de 3 Mo) qui peuvent s'exécuter plus rapidement sur le processeur avec WASM qu'avec WebGL, en raison de la surcharge d'importation de contenu dans un processeur graphique.
  • CPU utilization (Exécution du processeur) : les autres environnements ne doivent pas être disponibles. C'est le plus lent des trois, mais vous pouvez toujours y accéder.

Remarque:Vous pouvez choisir de forcer l'un de ces backends si vous savez sur quel appareil vous allez vous lancer, ou laisser le système TensorFlow.js décider à votre place si vous le faites. ne la spécifiez pas.

Super-utilisateurs côté client

L'exécution de TensorFlow.js dans le navigateur Web de la machine cliente peut présenter plusieurs avantages à prendre en compte.

Confidentialité

Vous pouvez entraîner et classer des données sur la machine cliente sans jamais envoyer de données à un serveur Web tiers. Dans certains cas, vous pouvez être tenu de respecter les lois locales en vigueur, comme le RGPD, par exemple, ou de traiter les données que l'utilisateur souhaite conserver sur son ordinateur et ne pas envoyer à un tiers.

Vitesse

Comme vous n'avez pas besoin d'envoyer de données à un serveur distant, l'inférence (la classification des données) peut être plus rapide. En outre, vous bénéficiez d'un accès direct aux capteurs de l'appareil, tels que la caméra, le micro, le GPS, l'accéléromètre, etc., si l'utilisateur vous y donne accès.

Couverture et scaling

En un clic, les internautes du monde entier peuvent cliquer sur un lien que vous leur avez envoyé, ouvrir la page Web dans leur navigateur et utiliser votre création. Inutile de configurer Linux côté serveur avec des pilotes CUDA et bien plus encore pour utiliser le système de machine learning.

Coût

Aucun serveur ne signifie que la seule chose à payer est un CDN pour héberger vos fichiers HTML, CSS, JS et de modèle. Le coût d'un CDN est bien moins cher que de laisser un serveur (éventuellement associé à une carte graphique) fonctionner 24h/24, 7j/7.

Fonctionnalités côté serveur

L'implémentation de Node.js de TensorFlow.js active les fonctionnalités suivantes.

Prise en charge complète du CUDA

Du côté du serveur, pour l'accélération de la carte graphique, vous devez installer les pilotes NVIDIA CUDA pour que TensorFlow fonctionne avec la carte graphique (contrairement au navigateur qui utilise WebGL, sans installation). Toutefois, la compatibilité totale avec CUDA vous permet d'exploiter pleinement les capacités de niveau inférieur de la carte graphique, ce qui accélère les entraînements et les inférences. Les performances sont semblables à celles de l'implémentation Python pour Python, car elles partagent le même backend C++.

Taille du modèle

Pour les modèles de pointe de la recherche, vous travaillez peut-être avec des modèles très volumineux, de la taille d'un gigaoctet par exemple. Pour le moment, ces modèles ne peuvent pas être exécutés dans le navigateur Web en raison des limitations de l'utilisation de la mémoire par onglet du navigateur. Pour exécuter ces modèles plus volumineux, vous pouvez utiliser Node.js sur votre propre serveur, avec les caractéristiques matérielles dont vous avez besoin pour exécuter ces modèles efficacement.

OI

Node.js est compatible avec les ordinateurs à carte unique populaires comme Raspberry Pi, ce qui signifie que vous pouvez également exécuter des modèles TensorFlow.js sur ces appareils.

Vitesse

Node.js est écrit en JavaScript, ce qui signifie qu'il ne présente qu'un temps de compilation. Cela signifie que vous pouvez souvent constater des gains de performances lorsque vous utilisez Node.js, car il sera optimisé au moment de l'exécution, en particulier pour tout prétraitement que vous effectuez. Prenons l'exemple de cette étude de cas. Elle montre comment Hugging Face a utilisé Node.js pour multiplier par deux les performances de son modèle de traitement du langage naturel.

Maintenant que vous connaissez les principes de base de TensorFlow.js, où il peut s'exécuter et certains des avantages qu'il offre, nous allons commencer à l'utiliser.

3. Apprentissage par transfert

Qu'est-ce que l'apprentissage par transfert ?

L'apprentissage par transfert implique d'utiliser des connaissances déjà acquises pour apprendre quelque chose de différent, mais similaire.

Nous faisons cela tous les jours. Votre cerveau vital vous offre une expérience permanente. Vous pouvez l'utiliser pour reconnaître de nouvelles choses que vous n'avez jamais vues auparavant. Prenons l'exemple de cet arbre.

e28070392cd4afb9.png

Selon votre localisation, il est possible que vous n'ayez encore jamais vu ce type d'arbre.

Pourtant, si vous me disiez s'il y a des sauvets dans la nouvelle image ci-dessous, vous les remarquerez probablement assez rapidement, même s'ils sont sous un angle différent et légèrement différent de l'original que je vous ai montré.

D9073a0d5df27222.png

Votre cerveau contient déjà un ensemble de neurones qui savent identifier les objets semblables à des arbres et d'autres neurones capables de trouver de longues lignes droites. Vous pouvez réutiliser ces connaissances pour classer rapidement un Saule, un objet en forme d'arbre comportant de nombreuses longues branches verticales.

De même, si vous disposez d'un modèle de machine learning déjà entraîné sur un domaine, par exemple pour reconnaître des images, vous pouvez le réutiliser pour effectuer une autre tâche.

Vous pouvez faire de même avec un modèle avancé comme MobileNet, un modèle de recherche très populaire capable de reconnaître des images sur 1 000 types d'objets différents. Des chiens aux voitures, l'application a été entraînée dans un immense ensemble de données appelé ImageNet, qui contient des millions d'images étiquetées.

Dans cette animation, vous pouvez voir le nombre important de calques qu'il contient dans ce modèle MobileNet V1:

7d4e1e35c1a89715.gif

Pendant son entraînement, le modèle a appris comment extraire des caractéristiques courantes importantes pour ces 1 000 objets. De nombreuses caractéristiques de niveau inférieur utilisées pour identifier ces objets peuvent également être utiles pour détecter de nouveaux objets qu'il n'a jamais vus auparavant. Après tout, tout est en quelque sorte une combinaison de lignes, de textures et de formes.

Nous allons examiner une architecture de réseau de neurones convolutif traditionnel (CNN, MobileNet Neural Network) semblable à MobileNet. Nous verrons comment l'apprentissage par transfert peut exploiter ce réseau entraîné pour apprendre de nouvelles choses. L'image ci-dessous illustre l'architecture de modèle d'un réseau de neurones convolutif : dans ce cas, elle a été entraînée à reconnaître des chiffres manuscrits compris entre 0 et 9 :

Baf4e3d434576106.png

Si vous pouviez séparer les couches de bas de niveau pré-entraînées d'un modèle existant comme sur la gauche, et celles de la fin de modèle, affichées à droite (parfois appelées "tête de classification du modèle"). vous pouvez utiliser les couches de niveau inférieur pour produire des caractéristiques de sortie pour n'importe quelle image en fonction des données d'origine avec lesquelles elle a été entraînée. Voici le même réseau avec la tête de classification supprimée:

369a8a9041c6909d.png

En supposant que la nouvelle chose que vous essayez de reconnaître puisse également exploiter ces caractéristiques de sortie que le modèle précédent a apprises, il est fort probable qu'elles soient réutilisées dans un nouveau but.

Dans le schéma ci-dessus, ce modèle hypothétique a été entraîné sur des chiffres. Peut-être que ce qui a été appris sur les chiffres peut-il également être appliqué aux lettres telles que "b" et "c".

Maintenant, vous pouvez ajouter une tête de classification qui tente de prédire à la place a, b ou c, comme illustré ci-dessous:

Db97e5e60ae73bbd.png

Dans cet exemple, les couches de niveau inférieur sont figées et ne sont pas entraînées. Seule la nouvelle tête de classification se met à jour pour tenir compte des caractéristiques fournies à gauche du modèle pré-entraîné haché.

C'est ce que l'on appelle l'apprentissage par transfert, c'est ce que fait Teachable Machine dans les coulisses.

Vous pouvez également constater qu'en n'ayant besoin d'entraîner le perceptron multicouche qu'à l'extrémité du réseau, l'entraînement est beaucoup plus rapide que si vous deviez entraîner entièrement le réseau.

Mais comment passer aux sous-parties d'un modèle ? Pour le savoir, passez à la section suivante.

4. TensorFlow Hub - Modèles de base

Trouver le modèle de base à utiliser

Pour les modèles de recherche plus avancés et populaires comme MobileNet, vous pouvez accéder à TensorFlow Hub, puis filtrer les modèles adaptés à TensorFlow.js qui utilisent l'architecture MobileNet v3 pour obtenir des résultats comme celles affichées ici:

C5dc1420c6238c14.png

Notez que certains de ces résultats sont de type "classification d'images" (détaillés en haut à gauche de chaque résultat de fiche de modèle), tandis que d'autres sont de type "vecteur de caractéristiques de l'image".

Ces résultats vectoriels correspondent essentiellement aux versions pré-hachées de MobileNet que vous pouvez utiliser pour obtenir les vecteurs de caractéristiques d'image au lieu de la classification finale.

On appelle souvent ces modèles des "modèles de base", que vous pouvez ensuite utiliser pour effectuer l'apprentissage par transfert de la même manière que dans la section précédente, en ajoutant une nouvelle tête de classification et en l'entraînant avec vos propres données.

La prochaine étape consiste à vérifier si un modèle de base intéressant est le format TensorFlow.js qui sera lancé par ce modèle. Si vous ouvrez la page correspondant à l'un de ces modèles MobileNet v3, vous pouvez voir dans la documentation JavaScript que le modèle est au format graphique basé sur l'exemple d'extrait de code de la documentation qui utilise tf.loadGraphModel().

F97d903d2e46924b.png

En outre, notez que si vous trouvez un modèle au format des calques plutôt que le format du graphique, vous pouvez choisir les calques à figer et ceux à figer pour l'entraînement. Cela peut s'avérer très utile pour créer un modèle pour une nouvelle tâche, souvent appelé "modèle de transfert". Pour le moment, vous allez utiliser le type de modèle de graphique par défaut pour ce tutoriel, dans lequel la plupart des modèles TF Hub sont déployés. Pour en savoir plus sur l'utilisation des modèles de calques, consultez le cours De zéro à héros avec TensorFlow.js.

Avantages de l'apprentissage par transfert

Pourquoi utiliser l'apprentissage par transfert au lieu d'entraîner entièrement l'architecture de modèle ?

Tout d'abord, le temps d'entraînement est un avantage clé. Il s'agit d'une méthode d'apprentissage par transfert, car vous disposez déjà d'un modèle de base entraîné.

Ensuite, évitez de présenter beaucoup moins d'exemples de la nouvelle chose que vous essayez de classifier en raison de la formation déjà effectuée.

C'est très utile si vous disposez de suffisamment de temps et de ressources pour rassembler des exemples de données que vous souhaitez classer, et que vous devez créer un prototype rapidement avant de collecter davantage de données d'entraînement pour les renforcer.

L'apprentissage par transfert consomme moins de ressources, car il nécessite moins de données et la vitesse d'entraînement d'un réseau plus petit. Ce service est parfaitement adapté à l'environnement du navigateur. Il suffit de quelques dizaines de secondes sur une machine moderne au lieu de plusieurs heures, jours ou semaines pour l'entraînement complet du modèle.

Voilà ! Maintenant que vous connaissez l'essentiel de l'apprentissage par transfert, il est temps de créer votre propre version de Teachable Machine. C'est parti !

5. Configurer le code

Prérequis

C'est parti

Pour commencer, vous avez créé des modèles récurrents pour Glitch.com ou Codepen.io. Il vous suffit de cloner l'un de ces modèles comme état de base pour cet atelier de programmation, d'un simple clic.

Dans Glitch, cliquez sur le bouton Remix this (Remixer) pour dupliquer le groupe et créer un ensemble de fichiers que vous pouvez modifier.

Vous pouvez également cliquer sur dupliquer en bas à droite de l'écran dans Codepen.

Ce squelette très simple vous fournit les fichiers suivants:

  • Page HTML (index.html)
  • Feuille de style (style.css)
  • Fichier d'écriture de notre code JavaScript (script.js)

Pour vous faciliter la tâche, une importation supplémentaire est disponible dans le fichier HTML de la bibliothèque TensorFlow.js. Elle se présente comme suit :

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

Alternative: Utiliser l'éditeur Web de votre choix ou travailler en local

Si vous souhaitez télécharger le code et travailler en local ou dans un autre éditeur en ligne, créez simplement les trois fichiers nommés ci-dessus dans le même répertoire, puis copiez et collez le code de notre code récurrent Glitch dans chacun d'entre eux.

6. Texte récurrent de l'application

Par où commencer ?

Tous les prototypes nécessitent un échafaudage HTML de base sur lequel vous pouvez afficher vos résultats. Configurez-le maintenant. Vous allez ajouter les éléments suivants:

  • Titre de la page.
  • Texte descriptif.
  • Un paragraphe d'état.
  • Une vidéo qui contient le flux de la webcam lorsqu'elle est prête.
  • Plusieurs boutons permettent de démarrer l'appareil photo, de collecter des données ou de réinitialiser l'expérience.
  • Importations pour TensorFlow.js et les fichiers JS que vous allez coder plus tard

Ouvrez index.html et collez le code existant avec les éléments suivants pour configurer les fonctionnalités ci-dessus:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>

    <p id="status">Awaiting TF.js load</p>

    <video id="webcam" autoplay muted></video>

    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

Analysez vos données

Examinons une partie du code HTML ci-dessus afin de mettre en évidence les éléments clés que vous avez ajoutés.

  • Vous avez ajouté une balise <h1> pour le titre de la page, ainsi qu'une balise <p> avec un ID "status", qui vous permettra d'imprimer des informations à mesure que vous utilisez différentes parties du système pour afficher les résultats.
  • Vous avez ajouté un élément <video> ayant l'ID "webcam" pour afficher le flux de votre webcam.
  • Vous avez ajouté 5 éléments <button>. La première, dont l'ID est "enableCam", active la caméra. Les deux boutons suivants proposent une classe de "dataCollector", qui vous permet de collecter des exemples d'images pour les objets que vous souhaitez reconnaître. Le code que vous rédigerez ultérieurement sera conçu de sorte que vous puissiez ajouter autant de boutons que nécessaire. Ils fonctionneront alors automatiquement.

Notez que ces boutons disposent également d'un attribut spécial défini par l'utilisateur, data-1hot, dont la valeur entière commence par 0 pour la première classe. Il s'agit de l'index numérique que vous utiliserez pour représenter les données d'une classe donnée. Il utilisera l'index pour encoder correctement les classes de sortie avec une représentation numérique au lieu d'une chaîne, car les modèles de ML ne peuvent fonctionner qu'avec des nombres.

Vous pouvez également utiliser un attribut de nom de données qui contient le nom lisible que vous souhaitez utiliser pour cette classe. Il vous permet d'attribuer un nom plus explicite à l'utilisateur, au lieu d'une valeur d'index numérique provenant de l'encodage chaud 1.

Enfin, vous disposez d'un bouton d'entraînement et de réinitialisation pour lancer le processus d'entraînement une fois les données collectées, ou pour réinitialiser l'application, respectivement.

  • Vous avez également ajouté deux importations <script>. un pour TensorFlow.js et un autre pour script.js que vous définirez bientôt.

7. Ajouter style

Éléments par défaut de l'élément

Ajoutez des styles pour les éléments HTML que vous venez d'ajouter afin de vous assurer qu'ils s'affichent correctement. Voici quelques styles qui ont été ajoutés correctement aux éléments de position et de taille. Rien de spécial. Vous pourriez l'ajouter ultérieurement pour améliorer votre expérience utilisateur, comme vous l'avez vu dans la vidéo Teachable Machine.

style.css.

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}

video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

Parfait ! C'est tout ce dont vous avez besoin. Si vous prévisualisez le résultat maintenant, il devrait ressembler à ceci:

81909685d7566dcb.png

8. JavaScript: constantes clés et écouteurs

Définir des constantes clés

Tout d'abord, ajoutez des constantes clés que vous utiliserez dans l'application. Commencez par remplacer le contenu du fichier script.js par les constantes suivantes:

script

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

Examinons ces fonctionnalités plus en détail:

  • STATUS contient simplement une référence à la balise de paragraphe dans laquelle vous souhaitez modifier l'état.
  • VIDEO contient une référence à l'élément vidéo HTML qui permet d'afficher le flux de la webcam.
  • ENABLE_CAM_BUTTON, RESET_BUTTON et TRAIN_BUTTON récupèrent des références DOM à tous les boutons de la page HTML.
  • MOBILE_NET_INPUT_WIDTH et MOBILE_NET_INPUT_HEIGHT définissent respectivement la largeur et la hauteur d'entrée prévues pour le modèle MobileNet. Si vous stockez cette version dans une constante vers la fin du fichier, si vous décidez d'utiliser une autre version plus tard, il sera plus facile de mettre à jour les valeurs une seule fois au lieu de devoir les remplacer dans de nombreux endroits différents.
  • STOP_DATA_GATHER est défini sur -1. Enregistre une valeur d'état pour vous indiquer quand l'utilisateur a cessé de cliquer sur un bouton pour recueillir des données depuis le flux de la webcam. Ainsi, le code sera plus lisible par la suite.
  • CLASS_NAMES joue le rôle d'une recherche et contient les noms lisibles pour les prédictions de classes possibles. Ce tableau sera renseigné plus tard.

Maintenant que vous avez des références à des éléments clés, il est temps d'associer des écouteurs d'événements.

Ajouter des écouteurs d'événements clés

Commencez par ajouter des gestionnaires d'événements de clic aux boutons de clé, comme indiqué ci-dessous:

script

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);

function enableCam() {
  // TODO: Fill this out later in the codelab!
}

function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}

function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON : appelle la fonction EnableCam en cas de clic.

TRAIN_BUTTON : appelle trainAndPredict en cas de clic.

RESET_BUTTON : les appels sont réinitialisés en cas de clic.

Enfin, cette section vous permet de trouver tous les boutons dont la classe est "dataCollector" avec document.querySelectorAll(). Cela renvoie un tableau d'éléments trouvés dans le document qui correspondent:

script

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}

function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

Explication du code:

Vous allez ensuite parcourir les boutons trouvés et associer deux écouteurs d'événements à chacun d'eux. un pour "souris" et un pour "souris". Cela vous permet de continuer à enregistrer des échantillons tant que vous appuyez sur le bouton, ce qui est utile pour la collecte de données.

Les deux événements appellent une fonction gatherDataForClass que vous définirez ultérieurement.

À ce stade, vous pouvez également transférer les noms de classes lisibles depuis l'attribut de nom de données du bouton HTML vers le tableau CLASS_NAMES.

Ensuite, ajoutez des variables pour stocker les éléments clés qui seront utilisés ultérieurement.

script

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

Examinons-les de plus près.

Tout d'abord, vous disposez d'une variable mobilenet pour stocker le modèle mobilenet chargé. Initialement défini sur non défini.

Vous disposez ensuite d'une variable appelée gatherDataState. Lorsque l'utilisateur appuie sur un bouton "dataCollector", il s'agit plutôt d'un ID chaud, tel que défini dans le code HTML, pour que vous puissiez savoir quelle classe de données vous collectez à ce moment-là. Dans un premier temps, il est défini sur STOP_DATA_GATHER afin que la boucle de collecte de données que vous écrivez ultérieurement ne collecte pas de données lorsqu'aucun bouton n'est appuyé.

videoPlaying vérifie si le flux de la webcam est chargé et en cours de lecture, et s'il est disponible. Ce paramètre est initialement défini sur false, car la webcam n'est activée que lorsque vous appuyez sur ENABLE_CAM_BUTTON..

Ensuite, définissez deux tableaux, trainingDataInputs et trainingDataOutputs. Ils stockent les valeurs des données d'entraînement collectées lorsque vous cliquez sur les boutons "dataCollector" pour les caractéristiques d'entrée générées par le modèle de base MobileNet, puis avec la classe de sortie échantillonnée.

Un dernier tableau, examplesCount,, est ensuite défini pour suivre le nombre d'exemples contenus pour chaque classe lorsque vous commencez à les ajouter.

Enfin, une variable appelée predict contrôle votre boucle de prédiction. Il est initialement défini sur false. Aucune prédiction ne peut être effectuée tant que la valeur définie n'est pas true.

Maintenant que toutes les variables clés ont été définies, nous allons charger le modèle de base MobileNet v3 découpé qui fournit des vecteurs de caractéristiques d'images au lieu de classifications.

9. Charger le modèle de base MobileNet

Commencez par définir une nouvelle fonction appelée loadMobileNetFeatureModel, comme indiqué ci-dessous. Il doit s'agir d'une fonction asynchrone, car le chargement d'un modèle est asynchrone:

script

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL =
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';

  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';

  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

Dans ce code, vous définissez le URL où le modèle à charger se trouve dans la documentation TFHub.

Vous pouvez ensuite charger le modèle à l'aide de await tf.loadGraphModel(). N'oubliez pas de définir la propriété spéciale fromTFHub sur true lorsque vous chargez un modèle à partir de ce site Web Google. C'est un cas particulier uniquement pour l'utilisation de modèles hébergés sur TF Hub dans lesquels cette propriété supplémentaire doit être définie.

Une fois le chargement terminé, vous pouvez définir innerText dans l'élément STATUS avec un message pour vous assurer qu'il est chargé correctement et que vous pouvez commencer à collecter des données.

Il ne reste plus qu'à échauffer le modèle. La première fois que vous utilisez un modèle de ce type, sa configuration peut prendre un certain temps. Il est donc utile de transmettre les zéros dans le modèle pour éviter toute attente à l'avenir, là où le moment pourrait être plus important.

Vous pouvez utiliser tf.zeros() encapsulé dans un tf.tidy() pour vous assurer que les Tensors sont éliminés correctement, avec une taille de lot de 1, et que la hauteur et la largeur que vous avez définies dans vos constantes sont correctes au début. Enfin, vous spécifiez également les canaux couleur, qui dans le cas présent sont 3, car le modèle attend des images RVB.

Ensuite, enregistrez la forme obtenue du Tensor renvoyé avec answer.shape() pour vous aider à comprendre la taille des caractéristiques d'image générées par ce modèle.

Après avoir défini cette fonction, vous pouvez l'appeler immédiatement pour lancer le téléchargement du modèle lors du chargement de la page.

Si vous voyez votre aperçu en direct maintenant, le texte d'état passe de "En attente de chargement TF.js" à "MobileNet v3 chargé". comme illustré ci-dessous. Assurez-vous que cela fonctionne avant de continuer.

A28b734e190afff.png

Vous pouvez également consulter la sortie de la console pour connaître la taille imprimée des caractéristiques de sortie générées par ce modèle. Après avoir exécuté des zéros dans le modèle MobileNet, une forme de [1, 1024] s'affiche. Le premier élément est simplement une taille de lot de 1, et vous pouvez voir qu'il renvoie 1 024 caractéristiques qui peuvent ensuite être utilisées pour vous aider à classer de nouveaux objets.

10. Définir l'en-tête du nouveau modèle

Vous allez à présent définir votre tête de modèle, c'est-à-dire un perceptron très multicouche.

script

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']
});

Examinons ce code. Vous allez commencer par définir un modèle tf.séquentielle auquel vous ajouterez des couches de modèle.

Ajoutez ensuite une couche dense en tant que couche d'entrée à ce modèle. Il présente la forme d'entrée 1024, car les sorties des fonctionnalités MobileNet v3 sont de cette taille. Vous l'avez découvert à l'étape précédente après les avoir transmises via le modèle. Cette couche contient 128 neurons qui utilisent la fonction d'activation ReLU.

Si vous débutez avec les fonctions d'activation et les couches de modèle, suivez le cours détaillé au début de cet atelier pour comprendre ce qu'il se passe en arrière-plan.

La couche suivante à ajouter est la couche de sortie. Le nombre de neurones doit être égal au nombre de classes que vous essayez de prédire. Pour ce faire, utilisez CLASS_NAMES.length pour trouver le nombre de classes que vous prévoyez de classer, soit le nombre de boutons de collecte de données trouvé dans l'interface utilisateur. Comme il s'agit d'un problème de classification, vous devez activer l'softmax sur cette couche de sortie, qui doit être utilisée lors de la création d'un modèle pour résoudre les problèmes de classification plutôt que la régression.

Imprimez maintenant un model.summary() pour imprimer la vue d'ensemble du modèle défini depuis la console.

Enfin, compilez le modèle pour qu'il soit prêt à être entraîné. Ici, l'optimiseur est défini sur adam et la perte sera binaryCrossentropy si CLASS_NAMES.length est égal à 2 ou categoricalCrossentropy si au moins trois classes sont utilisées pour : classer. Des métriques de précision sont également demandées afin qu'elles puissent être surveillées plus tard dans les journaux à des fins de débogage.

Dans la console, vous devriez obtenir un écran semblable à celui-ci:

22eaf32286fea4bb.png

Notez que plus de 130 000 paramètres peuvent être entraînés. Comme il s'agit d'une simple couche dense de neurones réguliers, l'entraînement sera assez rapide.

Une fois le projet terminé, vous pouvez essayer de modifier le nombre de neurones dans la première couche pour voir à quel point vous pouvez le réduire tout en atteignant des performances correctes. Souvent, avec le machine learning, il existe un certain niveau d'essai et d'erreur pour identifier les valeurs de paramètre optimales afin de faire le meilleur compromis entre utilisation des ressources et vitesse.

11. Activer la webcam

Il est maintenant temps d'étoffer la fonction enableCam() que vous avez définie précédemment. Ajoutez une fonction nommée hasGetUserMedia() comme indiqué ci-dessous, puis remplacez le contenu de la fonction enableCam() précédemment définie par le code correspondant ci-dessous.

script

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640,
      height: 480
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

Commencez par créer une fonction nommée hasGetUserMedia() pour vérifier si le navigateur est compatible avec getUserMedia() en vérifiant l'existence de propriétés clés de l'API du navigateur.

Dans la fonction enableCam(), utilisez la fonction hasGetUserMedia() que vous venez de définir pour vérifier si elle est compatible. Si ce n'est pas le cas, imprimez un avertissement dans la console.

S'il est compatible, définissez des contraintes pour votregetUserMedia() (par exemple, si vous souhaitez diffuser le flux vidéo uniquement, et si vous préférezwidth de la vidéo pour être640 en pixels et la tailleheight CANNOT TRANSLATE480 pixels. Pourquoi ? Il n'y a pas grand-chose à avoir à augmenter la taille d'une vidéo pour qu'elle soit redimensionnée au format 224 x 224 pixels et alimenter le modèle MobileNet. Vous pouvez également économiser des ressources informatiques en demandant une résolution inférieure. La plupart des caméras sont compatibles avec une résolution de cette taille.

Appelez ensuite navigator.mediaDevices.getUserMedia() avec le constraints détaillé ci-dessus, puis attendez que le stream soit renvoyé. Une fois que la valeur stream est renvoyée, vous pouvez obtenir que votre élément VIDEO lance la lecture de stream en la définissant comme valeur srcObject.

Vous devez également ajouter un écouteur d'événements pour l'élément VIDEO afin de savoir quand stream s'est chargé et s'est lu correctement.

Une fois que le hammam se charge, vous pouvez définir videoPlaying sur "true" et supprimer le ENABLE_CAM_BUTTON pour éviter qu'il ne soit cliqué à nouveau en définissant sa classe sur "removed".

Maintenant, exécutez votre code, cliquez sur le bouton d'activation de la caméra et autorisez l'accès à la webcam. Si vous effectuez cette opération pour la première fois, l'élément vidéo de la page doit s'afficher comme suit:

Bb77eb1affa9b883.png

Il est maintenant temps d'ajouter une fonction pour gérer les clics sur le bouton dataCollector.

12. Gestionnaire d'événements du bouton de collecte de données

Vous devez maintenant remplir votre fonction actuellement vide appelée gatherDataForClass().. Il s'agit de la fonction de gestionnaire d'événements que vous avez attribuée pour les boutons dataCollector au début de l'atelier de programmation.

script

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

Vérifiez d'abord l'attribut data-1hot sur le bouton sur lequel l'utilisateur clique pour appeler this.getAttribute() en utilisant le nom de l'attribut, dans notre cas data-1hot. Comme il s'agit d'une chaîne, vous pouvez utiliser parseInt() pour la convertir en entier et attribuer ce résultat à une variable nommée classNumber..

Ensuite, définissez la variable gatherDataState en conséquence. Si la valeur gatherDataState actuelle est égale à STOP_DATA_GATHER (valeur que vous avez définie sur -1), cela signifie que vous ne collectez pas de données actuellement et qu'un événement mousedown s'est déclenché. Définissez la gatherDataState comme classNumber.

Cela signifie que vous recueillez actuellement des données et que l'événement déclenché est un événement mouseup. Vous souhaitez maintenant arrêter de collecter des données pour cette classe. Il vous suffit de le redéfinir sur l'état STOP_DATA_GATHER pour mettre fin à la boucle de collecte de données que vous définirez bientôt.

Enfin, démarrez l'appel de dataGatherLoop(), qui effectue l'enregistrement des données de classe.

13. Collecte des données

Définissez maintenant la fonction dataGatherLoop(). Cette fonction sert à échantillonner des images à partir de la vidéo de la webcam, à les transmettre au modèle MobileNet et à capturer les sorties de ce modèle (les vecteurs de caractéristiques 1024).

Il stocke ensuite ces données avec l'ID gatherDataState du bouton sur lequel vous appuyez, afin que vous sachiez à quelle classe ces données correspondent.

Voyons cela de plus près :

script

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);

    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

Vous ne allez poursuivre l'exécution de cette fonction que si la valeur de videoPlaying est "true", ce qui signifie que la webcam est active. De plus, gatherDataState n'est pas égal à STOP_DATA_GATHER et l'utilisateur appuie sur un bouton de collecte de données de classe.

Encapsulez ensuite votre code dans un tf.tidy() pour supprimer les Tensors créés dans le code suivant. Le résultat de cette exécution du code tf.tidy() est stocké dans une variable appelée imageFeatures.

Vous pouvez désormais capturer une image de la webcam VIDEO à l'aide de tf.browser.fromPixels(). Le Tensor obtenu contenant les données d'image est stocké dans une variable appelée videoFrameAsTensor.

Redimensionnez ensuite la variable videoFrameAsTensor de façon à obtenir la forme correcte de l'entrée du modèle MobileNet. Utilisez un appel tf.image.resizeBilinear() avec le Tensor que vous souhaitez reformater en tant que premier paramètre, puis une forme définissant la nouvelle hauteur et largeur comme défini par les constantes que vous avez déjà créées. Enfin, définissez l'alignement des angles sur "true" en transmettant le troisième paramètre afin d'éviter tout problème d'alignement lors du redimensionnement. Le résultat de ce redimensionnement est stocké dans une variable appelée resizedTensorFrame.

Notez que ce redimensionnement primitif étire l'image, car la webcam fait 640 x 480 pixels, et que le modèle a besoin d'une image carrée de 224 x 224 pixels.

Pour les besoins de cette démonstration, cela devrait fonctionner. Cependant, une fois cet atelier de programmation terminé, vous pouvez essayer de recadrer un carré à partir de cette image afin d'améliorer davantage les résultats pour tout système de production que vous créerez ultérieurement.

Ensuite, normalisez les données d'image. Les données d'image sont toujours comprises entre 0 et 255 lorsque vous utilisez tf.browser.frompixels(). Vous pouvez donc simplement diviser redimensionnerdTensorFrame par 255 pour vous assurer que toutes les valeurs sont comprises entre 0 et 1, ce qui correspond au résultat du modèle MobileNet.

Enfin, dans la section tf.tidy() du code, transférez ce Tensor normalisé via le modèle chargé en appelant mobilenet.predict(), auquel vous transmettez la version développée de normalizedTensorFrame à l'aide de expandDims(), de sorte qu'il soit un lot de 1, car le modèle s'attend à recevoir un lot d'entrées pour traitement.

Une fois le résultat obtenu, vous pouvez immédiatement appeler squeeze() sur ce résultat renvoyé pour le reconduire dans un Tensor 1D, que vous pouvez ensuite renvoyer et attribuer à la variable imageFeatures qui capture le résultat de tf.tidy(). }.

Maintenant que vous disposez du imageFeatures du modèle MobileNet, vous pouvez l'enregistrer en le transférant sur le tableau trainingDataInputs que vous avez défini précédemment.

Vous pouvez également enregistrer ce que représente cette entrée en transmettant également le paramètre gatherDataState actuel au tableau trainingDataOutputs.

Notez que la variable gatherDataState aurait été définie sur l'ID numérique de la classe actuelle dans lequel vous enregistrez des données lorsque l'utilisateur a cliqué sur le bouton dans la fonction gatherDataForClass() définie précédemment.

À ce stade, vous pouvez également augmenter le nombre d'exemples pour une classe donnée. Pour ce faire, commencez par vérifier si l'index a été initialisé dans le tableau examplesCount. S'il n'est pas défini, définissez-le sur 0 pour initialiser le compteur pour l'ID numérique d'une classe donnée, puis vous pouvez augmenter sa valeur examplesCount pour le gatherDataState actuel.

À présent, mettez à jour le texte de l'élément STATUS sur la page Web pour afficher le nombre actuel de chaque classe à mesure qu'elles sont capturées. Pour ce faire, effectuez une boucle sur le tableau CLASS_NAMES et imprimez le nom lisible par le nombre de données associé au même index dans examplesCount.

Enfin, appelez window.requestAnimationFrame() avec dataGatherLoop transmis en tant que paramètre pour appeler de nouveau cette fonction de manière récursive. Cette opération continue à échantillonner les images de la vidéo jusqu'à ce que l'élément mouseup du bouton soit détecté. gatherDataState est alors défini sur STOP_DATA_GATHER,, auquel la boucle de collecte de données s'arrête.

Si vous exécutez votre code maintenant, vous devriez pouvoir cliquer sur le bouton d'activation de la caméra, attendre le chargement de la webcam, puis cliquer de manière prolongée sur chacun des boutons de collecte de données pour obtenir des exemples pour chaque classe de données. Je trouve ici des données pour mon téléphone mobile et ma main.

541051644a45131f.gif

Le texte d'état devrait être mis à jour, car il stocke tous les Tensors en mémoire, comme illustré dans la capture d'écran ci-dessus.

14. Entraîner et prédire

L'étape suivante consiste à implémenter le code de votre fonction trainAndPredict() actuellement vide, qui sert à effectuer l'apprentissage par transfert. Examinons le code:

script

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);

  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
      callbacks: {onEpochEnd: logProgress} });

  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

Tout d'abord, assurez-vous d'arrêter la prédiction en définissant predict sur false.

Ensuite, brassez vos tableaux d'entrée et de sortie à l'aide de tf.util.shuffleCombo() pour vous assurer que l'ordre ne cause aucun problème d'entraînement.

Convertissez votre tableau de sortie, trainingDataOutputs,, en un Tensor de 1re type de type int32 afin qu'il soit prêt à être utilisé dans un encodage à chaud. Il est stocké dans une variable nommée outputsAsTensor.

Utilisez la fonction tf.oneHot() avec cette variable outputsAsTensor ainsi que le nombre maximal de classes à encoder, c'est-à-dire uniquement CLASS_NAMES.length. Vos sorties encodées à chaud sont maintenant stockées dans un nouveau Tensor appelé oneHotOutputs.

Notez qu'actuellement trainingDataInputs est un tableau de Tensors enregistrés. Pour les utiliser pour l'entraînement, vous devez convertir le tableau de Tensors en Tensor 2D standard.

Pour ce faire, il existe une fonction intéressante dans la bibliothèque TensorFlow.js, appelée tf.stack() :

qui utilise un tableau de Tensors et les empilent pour générer un Tensor de dimension supérieure. Dans ce cas, un Tensor 2D est renvoyé, c'est-à-dire un lot d'entrées dimensionnelles de 1 024 degrés contenant les caractéristiques enregistrées, ce qui est nécessaire pour l'entraînement.

Ensuite, await model.fit() pour entraîner l'en-tête du modèle personnalisé. Il vous suffit de transmettre votre variable inputsAsTensor avec le oneHotOutputs pour représenter les données d'entraînement à utiliser, par exemple, pour les entrées et les sorties cibles. Dans l'objet de configuration du troisième paramètre, définissezshuffle jusqu'autrue, utilisezbatchSize sur5, avecepochs Définir sur10, puis spécifiez une valeurcallback sionEpochEnd à lalogProgress que vous définirez bientôt.

Enfin, vous pouvez supprimer les Tensors créés pendant l'entraînement du modèle. Vous pouvez ensuite redéfinir predict sur true pour que les prédictions se produisent à nouveau, puis appeler la fonction predictLoop() pour commencer à prédire les images de webcams en direct.

Vous pouvez également définir la fonction logProcess() pour consigner l'état de l'entraînement, utilisée dans model.fit() ci-dessus, qui affiche les résultats dans la console après chaque cycle d'entraînement.

Vous y êtes presque ! Le moment est venu d'ajouter la fonction predictLoop() pour effectuer des prédictions.

Boucle de prédiction principale

Vous mettez en œuvre la boucle de prédiction principale qui échantillonne les images d'une webcam et prédit en continu ce qui est dans chaque image avec des résultats en temps réel dans le navigateur.

Vérifions le code:

script

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT,
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

Commencez par vérifier que predict est vrai, afin que les prédictions ne soient effectuées qu'après l'entraînement d'un modèle et que celui-ci soit disponible pour utilisation.

Vous pouvez ensuite obtenir les caractéristiques de l'image actuelle, comme vous l'avez fait pour la fonction dataGatherLoop(). Vous venez de récupérer une image de la webcam à l'aide de tf.browser.from pixels(), de la normaliser, de la redimensionner au format 224 x 224 pixels, puis de transmettre ces données via le modèle MobileNet pour obtenir les caractéristiques d'image obtenues.

Cependant, vous pouvez désormais utiliser la tête de modèle que vous venez d'entraîner pour effectuer une prédiction en transmettant la imageFeatures obtenue qui vient d'être trouvée via la fonction predict() du modèle entraîné. Vous pouvez ensuite compresser le Tensor obtenu pour le convertir en une dimension et l'attribuer à une variable appelée prediction.

Avec ce prediction, vous pouvez trouver l'index qui a la valeur la plus élevée à l'aide de argMax(), puis convertir ce Tensor obtenu en tableau à l'aide de arraySync() afin d'obtenir les données sous-jacentes en JavaScript pour découvrir la position de l'index l'élément le plus performant. Cette valeur est stockée dans la variable appelée highestIndex.

Vous pouvez également obtenir les scores de confiance des prédictions de la même manière en appelant arraySync() directement sur le Tensor prediction.

Vous disposez maintenant de tout ce dont vous avez besoin pour mettre à jour le texte de STATUS avec les données prediction. Pour obtenir la chaîne lisible par l'utilisateur pour la classe, recherchez simplement highestIndex dans le tableau CLASS_NAMES, puis récupérez la valeur de confiance de predictionArray. Pour le rendre plus lisible sous forme de pourcentage, multipliez-le par 100, puis math.floor().

Enfin, vous pouvez utiliser window.requestAnimationFrame() pour appeler à nouveau predictionLoop() une fois qu'elle est prête et ainsi obtenir une classification en temps réel sur votre flux vidéo. Le processus se poursuit jusqu'à ce que predict soit défini sur false si vous choisissez d'entraîner un nouveau modèle avec de nouvelles données.

Ce qui vous mène au dernier point du puzzle. Implémenter le bouton de réinitialisation

15. Mettre en œuvre le bouton de réinitialisation

Vous avez presque terminé La dernière partie du puzzle consiste à implémenter un bouton de réinitialisation pour recommencer. Vous trouverez ci-dessous le code de votre fonction reset() actuellement vide. Mettez-la à jour comme suit:

script

/**
 * Purge data and start over. Note this does not dispose of the loaded
 * MobileNet model and MLP head tensors as you will need to reuse
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';

  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

Commencez par arrêter les boucles de prédiction en cours d'exécution en définissant predict sur false. Supprimez ensuite tout le contenu du tableau examplesCount en définissant sa longueur sur 0, ce qui est pratique pour effacer tout le contenu d'un tableau.

Parcourez maintenant le trainingDataInputs enregistré actuel et veillez à ce que dispose() de chaque Tensor qu'il contient libère à nouveau de la mémoire, car les Tensors ne sont pas nettoyés par le collecteur de mémoire JavaScript.

Vous pouvez maintenant définir la longueur de 0 de manière sécurisée sur les tableaux trainingDataInputs et trainingDataOutputs pour effacer également ces tableaux.

Enfin, définissez le texte STATUS sur une valeur sensible et imprimez les Tensors laissés en mémoire pour effectuer un contrôle de l'intégrité.

Notez que quelques centaines de Tensors seront encore en mémoire, car le modèle MobileNet et le perceptron multicouche que vous avez défini ne sont pas jetés. Vous devrez les réutiliser avec de nouvelles données d'entraînement si vous décidez de vous entraîner à nouveau après cette réinitialisation.

16. Faisons un essai

Il est temps de tester votre propre version de Teachable Machine.

Accédez à l'aperçu en direct, activez la webcam, collectez au moins 30 échantillons pour la classe 1 pour un objet de votre salle, puis faites de même pour la classe 2 pour un autre objet, cliquez sur "Entraîner" et consultez le journal de la console pour suivre vos progrès. L'entraînement devrait être assez rapide:

bf1ac3cc5b15736.gif

Une fois l'entraînement terminé, affichez les objets devant l'appareil photo pour obtenir des prédictions en direct qui s'afficheront dans la zone de texte de l'état de la page Web vers le haut. Si vous rencontrez des difficultés, vérifiez mon code de travail terminé pour savoir si vous avez raté quelque chose.

17. Félicitations

Félicitations ! Vous venez de terminer votre tout premier exemple d'apprentissage par transfert en utilisant TensorFlow.js dans le navigateur.

Testez-le sur plusieurs objets. Vous remarquerez peut-être que certains éléments sont plus difficiles à reconnaître que d'autres, surtout s'ils sont similaires à un autre. Vous devrez peut-être ajouter d'autres classes ou données d'entraînement pour les différencier.

Résumé

Dans cet atelier de programmation, vous avez appris ce qui suit:

  1. En quoi consiste l'apprentissage par transfert et quels sont ses avantages par rapport à l'entraînement d'un modèle complet.
  2. Obtenez des modèles à réutiliser avec TensorFlow Hub.
  3. Comment configurer une application Web adaptée à l'apprentissage par transfert
  4. Charger et utiliser un modèle de base pour générer des caractéristiques d'image
  5. Entraînement d'une nouvelle tête de prédiction capable de reconnaître des objets personnalisés à partir d'images de webcams
  6. Utilisation des modèles obtenus pour classer les données en temps réel

Et ensuite ?

Maintenant que vous disposez d'une base de travail, quelles idées créatives pouvez-vous mettre en place pour étendre ce modèle de machine learning à un cas d'utilisation concret ? Peut-être pourriez-vous révolutionner le secteur dans lequel vous travaillez actuellement pour aider les collaborateurs de votre entreprise à entraîner des modèles afin qu'ils puissent classer les éléments importants dans leur travail quotidien ? Les possibilités sont innombrables.

Pour aller plus loin, suivez ce cours complet sans frais, qui explique comment combiner les deux modèles actuels dans un seul et même modèle pour plus d'efficacité.

Si vous souhaitez en savoir plus sur la théorie derrière l'application Teachable Machine d'origine, consultez ce tutoriel.

Partagez vos créations

Vous pouvez facilement étendre vos créations pour d'autres cas d'utilisation créatifs, et nous vous encourageons à sortir du lot et à pirater votre chaîne.

N'oubliez pas de nous taguer sur les réseaux sociaux à l'aide du hashtag #MadeWithTFJS, et votre projet sera peut-être mis en avant sur notre blog TensorFlow ou même événements futurs. Nous serions ravis de découvrir vos créations.

Sites Web à consulter