TensorFlow.js: créer votre propre machine Teachable Machine à l'aide de l'apprentissage par transfert avec TensorFlow.js

1. Avant de commencer

Ces dernières années, l'utilisation du modèle TensorFlow.js a augmenté de manière exponentielle. De nombreux développeurs JavaScript cherchent désormais à réentraîner des modèles de pointe existants afin qu'ils utilisent des données personnalisées propres à leur secteur. L'apprentissage par transfert consiste à prendre un modèle existant (souvent appelé modèle de base) et à l'utiliser sur un domaine similaire, mais différent.

L'apprentissage par transfert présente de nombreux avantages par rapport à partir d'un modèle entièrement vide. Vous pouvez réutiliser des connaissances déjà acquises à partir d'un modèle entraîné précédent et vous aurez besoin de moins d'exemples du nouvel élément à classer. De plus, l'entraînement est souvent beaucoup plus rapide, car il suffit de réentraîner les dernières couches de l'architecture du modèle plutôt que l'ensemble du réseau. Pour cette raison, l'apprentissage par transfert est particulièrement adapté à l'environnement de navigateur Web dans lequel les ressources peuvent varier en fonction de l'appareil exécuté. Il offre également un accès direct aux capteurs pour faciliter l'acquisition de données.

Cet atelier de programmation vous explique comment créer une application Web à partir d'un canevas vierge, en recréant l'application populaire " Teachable Machine" sur votre site Web. Le site Web vous permet de créer une application Web fonctionnelle que n'importe quel utilisateur peut utiliser pour reconnaître un objet personnalisé à l'aide de quelques exemples d'images de sa webcam. Le site Web est volontairement réduit au minimum pour que vous puissiez vous concentrer sur les aspects du machine learning de cet atelier de programmation. Toutefois, comme pour le site Web d'origine de Teachable Machine, vous avez de nombreuses possibilités pour appliquer votre expérience de développeur Web existante afin d'améliorer l'expérience utilisateur.

Prérequis

Cet atelier de programmation s'adresse aux développeurs Web qui connaissent un peu les modèles prédéfinis TensorFlow.js et l'utilisation de base des API, et qui souhaitent commencer à utiliser l'apprentissage par transfert dans TensorFlow.js.

  • Pour cet atelier, vous devez posséder des connaissances de base sur TensorFlow.js, HTML5, CSS et JavaScript.

Si vous débutez avec TensorFlow.js, pensez à suivre d'abord ce cours sans frais de base, qui n'implique aucune connaissance en machine learning ni à TensorFlow.js, et vous apprend tout ce que vous devez savoir en quelques étapes.

Points abordés

  • Présentation de TensorFlow.js et des raisons pour lesquelles vous devriez l'utiliser dans votre prochaine application Web
  • Découvrez comment créer une page Web HTML/CSS /JS simplifiée qui réplique l'expérience utilisateur de Teachable Machine.
  • Comment utiliser TensorFlow.js pour charger un modèle de base pré-entraîné, en particulier MobileNet, afin de générer des caractéristiques d'image utilisables dans l'apprentissage par transfert.
  • Comment collecter des données à partir de la webcam d'un utilisateur pour plusieurs classes de données que vous souhaitez reconnaître.
  • Comment créer et définir un perceptron multicouche qui exploite les caractéristiques de l'image pour apprendre à classer de nouveaux objets à l'aide de ces caractéristiques.

Allons pirater...

Prérequis

  • Pour suivre, il est préférable d'utiliser un compte Glitch.com ou d'utiliser un environnement de diffusion Web que vous maîtrisez parfaitement.

2. Qu'est-ce que TensorFlow.js ?

54e81d02971f53e8.png

TensorFlow.js est une bibliothèque de machine learning Open Source qui peut s'exécuter partout où JavaScript peut être exécuté. Il est basé sur la bibliothèque TensorFlow d'origine écrite en Python et vise à recréer cette expérience de développement et cet ensemble d'API pour l'écosystème JavaScript.

Où l'utiliser ?

Étant donné la portabilité de JavaScript, vous pouvez désormais écrire dans un seul langage et exécuter facilement des modèles de machine learning sur l'ensemble des plates-formes suivantes:

  • Côté client dans le navigateur Web en utilisant vanilla JavaScript
  • Côté serveur et même appareils IoT comme Raspberry Pi avec Node.js
  • Applications de bureau avec Electron
  • Applications mobiles natives utilisant React Native

TensorFlow.js est également compatible avec plusieurs backends dans chacun de ces environnements (les environnements matériels dans lesquels il peut s'exécuter, tels que le processeur ou WebGL, par exemple). Un "backend" dans ce contexte ne signifie pas qu'il s'agit d'un environnement côté serveur (le backend d'exécution peut par exemple être côté client dans WebGL) pour assurer la compatibilité et garantir un fonctionnement rapide. Actuellement, TensorFlow.js est compatible avec:

  • Exécution WebGL sur la carte graphique de l'appareil (GPU) : il s'agit du moyen le plus rapide d'exécuter des modèles plus volumineux (plus de 3 Mo) avec l'accélération du GPU.
  • Exécution de Web Assembly (WASM) sur le processeur : permet d'améliorer les performances du processeur sur tous les appareils, y compris les téléphones mobiles d'ancienne génération, par exemple. Cela est mieux adapté aux modèles plus petits (moins de 3 Mo) qui peuvent s'exécuter plus rapidement sur processeur avec WASM qu'avec WebGL en raison de la surcharge de l'importation de contenu sur un processeur graphique.
  • Exécution du processeur : le remplacement si aucun des autres environnements n'est disponible. C'est la plus lente des trois, mais elle est toujours là pour vous.

Remarque:Vous pouvez choisir de forcer l'un de ces backends si vous savez sur quel appareil vous allez exécuter l'exécution, ou laisser TensorFlow.js décider à votre place si vous ne le spécifiez pas.

Super-pouvoirs côté client

L'exécution de TensorFlow.js dans un navigateur Web sur la machine cliente peut offrir plusieurs avantages qu'il convient de prendre en compte.

Confidentialité

Vous pouvez à la fois entraîner et classer les données sur la machine cliente sans jamais envoyer de données à un serveur Web tiers. Dans certains cas, il peut être nécessaire de se conformer à des lois locales, telles que le RGPD, ou de traiter des données que l'utilisateur souhaite conserver sur son ordinateur et non envoyées à un tiers.

Vitesse

Comme vous n'avez pas besoin d'envoyer de données à un serveur distant, l'inférence (autrement dit, la classification des données) peut être plus rapide. Qui plus est, vous pouvez accéder directement aux capteurs de l'appareil (appareil photo, micro, GPS, accéléromètre, etc.) si l'utilisateur vous y autorise.

Portée et évolutivité

En un clic, n'importe quel internaute peut cliquer sur un lien que vous lui envoyez, ouvrir la page Web dans son navigateur et utiliser votre création. Pas besoin d'une configuration Linux complexe côté serveur avec des pilotes CUDA et bien plus encore pour utiliser le système de machine learning.

Coût

L'absence de serveur signifie que vous ne payez qu'un CDN pour héberger vos fichiers HTML, CSS et JS, ainsi que vos fichiers de modèle. Le coût d'un CDN est bien moins cher que d'avoir un serveur (éventuellement connecté à une carte graphique) fonctionnant 24h/24, 7j/7.

Fonctionnalités côté serveur

L'implémentation Node.js de TensorFlow.js permet d'activer les fonctionnalités suivantes.

Compatibilité CUDA complète

Côté serveur, pour accélérer l'accélération de la carte graphique, vous devez installer les pilotes NVIDIA CUDA pour permettre à TensorFlow de fonctionner avec la carte graphique (contrairement au navigateur qui utilise WebGL ; aucune installation n'est requise). Cependant, grâce à la prise en charge complète de CUDA, vous pouvez exploiter pleinement les capacités de niveau inférieur de la carte graphique, ce qui accélère les temps d'entraînement et d'inférence. Les performances sont équivalentes à celles de l'implémentation Python de TensorFlow, car elles partagent le même backend C++.

Taille du modèle

Pour les modèles de pointe issus de la recherche, vous travaillez peut-être avec des modèles très volumineux, pouvant atteindre plusieurs gigaoctets. Ces modèles ne peuvent actuellement pas être exécutés dans le navigateur Web en raison des limites d'utilisation de la mémoire par onglet du navigateur. Pour exécuter ces modèles plus volumineux, vous pouvez utiliser Node.js sur votre propre serveur avec les caractéristiques matérielles requises pour exécuter efficacement un tel modèle.

Internet des objets

Node.js est compatible avec les ordinateurs à carte unique courants tels que Raspberry Pi, ce qui signifie que vous pouvez également exécuter des modèles TensorFlow.js sur ces appareils.

Vitesse

Node.js est écrit en JavaScript, ce qui signifie qu'il bénéficie d'une compilation juste à temps. Cela signifie que vous pouvez souvent constater une amélioration des performances lorsque vous utilisez Node.js, car il sera optimisé au moment de l'exécution, en particulier pour les éventuels prétraitements que vous effectuez. Cette étude de cas en est un bon exemple. Elle montre comment Hugging Face a utilisé Node.js pour multiplier par deux les performances de son modèle de traitement du langage naturel.

Vous connaissez désormais les principes de base de TensorFlow.js, ses emplacements d'exécution et certains de ses avantages. Commençons à l'utiliser !

3. Apprentissage par transfert

Qu'est-ce que l'apprentissage par transfert ?

L'apprentissage par transfert consiste à assimiler des connaissances déjà acquises pour apprendre une chose différente, mais similaire.

C'est ce que nous faisons tout le temps, les humains. Tout au long de votre vie, votre cerveau contient des expériences qui vous aideront à reconnaître des choses que vous n'avez jamais vues auparavant. Prenons l'exemple de ce saule:

e28070392cd4afb9.png

Selon l'endroit où vous vous trouvez dans le monde, il est possible que vous n'ayez jamais vu ce type d'arbre auparavant.

Pourtant, si je vous demande de me dire s'il y a des saules dans la nouvelle image ci-dessous, vous pourrez probablement les repérer assez rapidement, même s'ils sont sous un angle différent, et légèrement différent de l'original que je vous ai montré.

d9073a0d5df27222.png

Votre cerveau contient déjà un grand nombre de neurones qui savent identifier des objets semblables à des arbres, et d'autres neurones capables de détecter de longues lignes droites. Vous pouvez réutiliser ces connaissances pour classer rapidement un saule, c'est-à-dire un objet ressemblant à un arbre comportant de nombreuses longues branches verticales droites.

De même, si vous disposez d'un modèle de machine learning déjà entraîné sur un domaine, comme la reconnaissance d'images, vous pouvez le réutiliser pour effectuer une tâche différente, mais associée.

Vous pouvez faire de même avec un modèle avancé comme MobileNet, un modèle de recherche très populaire capable d'effectuer une reconnaissance d'image sur 1 000 types d'objets différents. Des chiens aux voitures, il a été entraîné sur un vaste ensemble de données appelé ImageNet qui contient des millions d'images étiquetées.

Dans cette animation, vous pouvez voir le nombre important de couches qu'elle contient dans ce modèle MobileNet V1:

7d4e1e35c1a89715.gif

Au cours de son entraînement, ce modèle a appris à extraire des caractéristiques communes importantes pour tous ces 1 000 objets. La plupart des caractéristiques de niveau inférieur qu'il utilise pour identifier ces objets peuvent également être utiles pour détecter de nouveaux objets qu'il n'a jamais vus auparavant. Après tout, tout n’est en fin de compte qu’une simple combinaison de lignes, de textures et de formes.

Examinons une architecture de réseau de neurones convolutif (CNN) traditionnelle (semblable à MobileNet) et voyons comment l'apprentissage par transfert peut exploiter ce réseau entraîné pour apprendre de nouvelles choses. L'image ci-dessous illustre l'architecture de modèle typique d'un réseau de neurones convolutif qui, dans ce cas précis, a été entraîné à reconnaître des chiffres manuscrits de 0 à 9:

baf4e3d434576106.png

Si vous pouviez séparer les couches de niveau inférieur pré-entraînées d'un modèle entraîné existant, comme illustré à gauche, des couches de classification vers la fin du modèle présenté à droite (parfois appelées "têtes de classification du modèle"), vous pourriez utiliser les couches de niveau inférieur afin de produire des caractéristiques de sortie pour une image donnée, en fonction des données d'origine sur lesquelles elle a été entraînée. Voici le même réseau, sans la tête de classification:

369a8a9041c6917d.png

En supposant que la nouvelle chose que vous essayez de reconnaître puisse également utiliser les caractéristiques de sortie que le modèle précédent a apprises, il y a de fortes chances qu'elles puissent être réutilisées à d'autres fins.

Dans le diagramme ci-dessus, ce modèle hypothétique a été entraîné avec des chiffres. Par conséquent, ce que vous avez appris sur les chiffres peut également s'appliquer à des lettres telles que a, b et c.

Vous pouvez maintenant ajouter une nouvelle tête de classification qui tente de prédire a, b ou c, comme indiqué ci-dessous:

db97e5e60ae73bbd.png

Ici, les couches de niveau inférieur sont figées et ne sont pas entraînées. Seule la nouvelle tête de classification se met à jour pour apprendre à partir des caractéristiques fournies par le modèle pré-entraîné pré-entraîné à gauche.

C'est ce que l'on appelle l'apprentissage par transfert. Teachable Machine effectue en coulisses.

Vous pouvez également constater qu'en ayant uniquement entraîné le perceptron multicouche à la toute fin du réseau, l'entraînement est beaucoup plus rapide que si vous deviez entraîner l'ensemble du réseau à partir de zéro.

Mais comment mettre la main sur les sous-parties d'un modèle ? Consultez la section suivante pour le savoir.

4. TensorFlow Hub : modèles de base

Trouver un modèle de base approprié

Pour consulter des modèles de recherche plus avancés et populaires tels que MobileNet, accédez au hub TensorFlow, puis filtrez les modèles adaptés à TensorFlow.js qui utilisent l'architecture MobileNet v3 pour obtenir des résultats semblables à ceux présentés ci-dessous:

c5dc1420c6238c14.png

Notez que certains de ces résultats sont du type "classification d'images". (voir le détail en haut à gauche de chaque résultat de la fiche de modèle), tandis que d'autres sont de type "vecteur de caractéristiques image".

Ces résultats de vecteur de caractéristiques d'image correspondent essentiellement aux versions pré-découpées de MobileNet que vous pouvez utiliser pour obtenir les vecteurs de caractéristiques d'image au lieu de la classification finale.

Ces modèles sont souvent appelés "modèles de base", que vous pouvez ensuite utiliser pour effectuer un apprentissage par transfert de la même manière que dans la section précédente, en ajoutant une nouvelle tête de classification et en l'entraînant avec vos propres données.

L'étape suivante consiste à vérifier le format TensorFlow.js sous lequel le modèle est publié pour un modèle de base donné. Si vous ouvrez la page de l'un de ces modèles MobileNet v3 à vecteur de caractéristiques, vous pouvez voir dans la documentation JavaScript qu'il se présente sous la forme d'un modèle graphique basé sur l'exemple d'extrait de code de la documentation qui utilise tf.loadGraphModel().

f97d903d2e46924b.png

Notez également que si vous trouvez un modèle au format "layers" plutôt qu'au format graphique, vous pouvez choisir les couches à figer et celles à libérer pour l'entraînement. Cette solution peut s'avérer très efficace lors de la création d'un modèle pour une nouvelle tâche, souvent appelé "modèle de transfert". Toutefois, pour l'instant, vous utiliserez pour ce tutoriel le type de modèle de graphe par défaut, sous lequel la plupart des modèles TensorFlow Hub sont déployés. Pour en savoir plus sur l'utilisation des modèles de couches, consultez le cours TensorFlow.js, de zéro à héros.

Avantages de l'apprentissage par transfert

Quel est l'intérêt d'utiliser l'apprentissage par transfert plutôt que d'entraîner entièrement l'ensemble de l'architecture du modèle ?

Tout d'abord, le temps d'entraînement est un avantage clé de l'utilisation d'une approche d'apprentissage par transfert, car vous disposez déjà d'un modèle de base entraîné sur lequel vous appuyer.

Deuxièmement, vous pouvez montrer beaucoup moins d'exemples du nouvel élément que vous essayez de classer en raison de l'entraînement qui a déjà eu lieu.

C'est très utile si vous disposez de peu de temps et de ressources pour collecter des exemples de données sur l'élément à classer et que vous devez créer rapidement un prototype avant de collecter plus de données d'entraînement pour le rendre plus robuste.

Étant donné le besoin de moins de données et la vitesse d'entraînement d'un réseau plus petit, l'apprentissage par transfert nécessite moins de ressources. Il est donc particulièrement adapté à l'environnement de navigation, car l'entraînement complet du modèle ne prend que quelques dizaines de secondes sur un ordinateur moderne au lieu de plusieurs heures, jours ou semaines.

Très bien ! Maintenant que vous connaissez l'essence de l'apprentissage par transfert, il est temps de créer votre propre version de Teachable Machine. C'est parti !

5. Préparez-vous au code

Prérequis

  • Un navigateur Web moderne.
  • Connaissances de base sur les langages HTML, CSS et JavaScript, ainsi que dans les outils pour les développeurs Chrome (affichage du résultat de la console)

Commençons à coder

Des modèles récurrents à partir desquels commencer ont été créés pour Glitch.com ou Codepen.io. Vous pouvez simplement cloner l'un ou l'autre modèle comme état de base pour cet atelier de programmation, en un seul clic.

Dans Glitch, cliquez sur le bouton remix this pour le dupliquer et créer un nouvel ensemble de fichiers que vous pourrez modifier.

Dans Codepen, vous pouvez également cliquer sur fork" en bas à droite de l'écran.

Ce squelette très simple vous fournit les fichiers suivants:

  • Page HTML (index.html)
  • Feuille de style (style.css)
  • Fichier pour écrire notre code JavaScript (script.js)

Pour plus de commodité, nous avons ajouté une importation dans le fichier HTML pour la bibliothèque TensorFlow.js. Elle se présente comme suit :

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

Autre possibilité: Utilisez l'éditeur Web de votre choix ou travaillez en local.

Si vous souhaitez télécharger le code et travailler en local ou sur un autre éditeur en ligne, il vous suffit de créer les trois fichiers nommés ci-dessus dans le même répertoire, puis de copier et coller le code de notre code récurrent Glitch dans chacun d'eux.

6. Code HTML standard de l'application

Par où commencer ?

Tous les prototypes nécessitent un échafaudage HTML de base sur lequel vous pouvez afficher vos résultats. Configurez-le maintenant. Vous allez ajouter:

  • Titre de la page.
  • Quelques descriptions.
  • Paragraphe d'état.
  • Une vidéo pour contenir le flux de la webcam lorsqu'il est prêt.
  • Plusieurs boutons permettent de démarrer l'appareil photo, de collecter des données ou de réinitialiser l'expérience.
  • Importations pour les fichiers TensorFlow.js et JS que vous coderez ultérieurement.

Ouvrez index.html et collez le code existant comme suit pour configurer les fonctionnalités ci-dessus:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

Ça va breaker !

Analysons une partie du code HTML ci-dessus pour mettre en évidence les éléments clés que vous avez ajoutés.

  • Vous avez ajouté une balise <h1> pour le titre de la page, ainsi qu'une balise <p> avec l'ID "status", où vous allez imprimer les informations, car vous utiliserez différentes parties du système pour afficher les sorties.
  • Vous avez ajouté un élément <video> avec l'ID "webcam", sur lequel vous restituerez le flux de votre webcam par la suite.
  • Vous avez ajouté cinq éléments <button>. La première, dont l'ID est "enableCam", active l'appareil photo. Les deux boutons suivants ont une classe "dataCollector", qui vous permet de recueillir des exemples d'images pour les objets que vous souhaitez reconnaître. Le code que vous écrirez plus tard sera conçu pour que vous puissiez ajouter autant de boutons que vous le souhaitez. Ils fonctionneront automatiquement comme prévu.

Notez que ces boutons comportent également un attribut spécial défini par l'utilisateur appelé data-1hot, avec une valeur entière commençant à 0 pour la première classe. Il s'agit de l'index numérique que vous utiliserez pour représenter les données d'une classe spécifique. L'index permet d'encoder correctement les classes de sortie avec une représentation numérique au lieu d'une chaîne, car les modèles de ML ne peuvent fonctionner qu'avec des nombres.

Il existe également un attribut data-name contenant le nom lisible que vous souhaitez utiliser pour cette classe, ce qui vous permet de fournir un nom plus explicite à l'utilisateur au lieu d'une valeur d'index numérique issu de l'encodage 1 hot.

Enfin, un bouton d'entraînement et de réinitialisation permet de lancer le processus d'entraînement une fois les données collectées ou de réinitialiser l'application.

  • Vous avez également ajouté deux importations <script>. l'un pour TensorFlow.js et l'autre pour script.js, que vous définirez bientôt.

7. Ajouter style

Valeurs par défaut de l'élément

Ajoutez des styles aux éléments HTML que vous venez d'ajouter afin de vous assurer qu'ils s'affichent correctement. Voici quelques styles qui sont ajoutés correctement aux éléments de position et de taille. Rien de spécial. Vous pourriez certainement en ajouter plus tard pour améliorer l'expérience utilisateur, comme vous l'avez vu dans la vidéo Teachable Machine.

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

Parfait ! C'est tout ce dont vous avez besoin. Si vous prévisualisez maintenant le résultat, vous devriez obtenir un résultat semblable à celui-ci:

81909685d7566dcb.png

8. JavaScript: constantes clés et écouteurs

Définir des constantes clés

Tout d'abord, ajoutez quelques constantes clés que vous utiliserez dans l'application. Commencez par remplacer le contenu de script.js par les constantes suivantes:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

Voyons à quoi elles servent:

  • STATUS contient simplement une référence à la balise de paragraphe dans laquelle vous allez écrire les mises à jour de l'état.
  • VIDEO contient une référence à l'élément vidéo HTML qui affichera le flux de la webcam.
  • ENABLE_CAM_BUTTON, RESET_BUTTON et TRAIN_BUTTON récupèrent les références DOM à tous les boutons clés de la page HTML.
  • MOBILE_NET_INPUT_WIDTH et MOBILE_NET_INPUT_HEIGHT définissent respectivement la largeur et la hauteur d'entrée attendues du modèle MobileNet. En la stockant dans une constante près du haut du fichier, comme ceci, si vous décidez d'utiliser une autre version par la suite, il est plus facile de mettre à jour les valeurs une seule fois au lieu de la remplacer à de nombreux endroits différents.
  • La valeur de STOP_DATA_GATHER est définie sur - 1. Cette opération stocke une valeur d'état afin que vous sachiez quand l'utilisateur a arrêté de cliquer sur un bouton pour collecter des données à partir du flux de la webcam. Attribuer un nom plus explicite à ce nombre permet de rendre le code plus lisible par la suite.
  • CLASS_NAMES agit comme une recherche et contient les noms lisibles pour les prédictions de classe possibles. Ce tableau sera renseigné ultérieurement.

Maintenant que vous avez des références aux éléments clés, vous pouvez leur associer des écouteurs d'événements.

Ajouter des écouteurs d'événements clés

Commencez par ajouter des gestionnaires d'événements de clic aux boutons clés, comme indiqué ci-dessous:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON : appelle la fonction enableCam lorsque l'utilisateur clique dessus.

TRAIN_BUTTON : appelle "trainAndPredict" lorsque l'utilisateur clique dessus.

RESET_BUTTON : les appels sont réinitialisés lorsque l'utilisateur clique dessus.

Enfin, dans cette section, vous trouverez tous les boutons ayant une classe "dataCollector" avec document.querySelectorAll(). Cela renvoie un tableau d'éléments trouvés à partir du document qui correspondent à:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

Explication du code:

Vous itérez ensuite les boutons trouvés et vous associez deux écouteurs d'événements à chacun d'eux. une pour "mousedown" et une pour "mouseup". Cela vous permet de continuer à enregistrer des échantillons tant que vous appuyez sur le bouton, ce qui est utile pour la collecte de données.

Les deux événements appellent une fonction gatherDataForClass que vous définirez ultérieurement.

À ce stade, vous pouvez également transférer les noms de classe lisibles par l'humain de l'attribut de bouton HTML "data-name" vers le tableau CLASS_NAMES.

Ajoutez ensuite des variables pour stocker les éléments clés qui seront utilisés ultérieurement.

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

Examinons-les de plus près.

Tout d'abord, vous avez une variable mobilenet pour stocker le modèle mobilenet chargé. Initialement défini sur "non défini".

Vous avez ensuite une variable appelée gatherDataState. Si une valeur "dataCollector" lorsque l'utilisateur appuie sur le bouton, il est remplacé par l'identifiant unique de ce bouton, tel que défini dans le code HTML. Vous savez ainsi quelle classe de données vous collectez à ce moment-là. Au départ, il est défini sur STOP_DATA_GATHER. Ainsi, la boucle de collecte de données que vous écrivez par la suite ne collectera aucune donnée si vous n'appuyez sur aucun bouton.

videoPlaying détermine si le flux de la webcam est chargé, lu et disponible. Au départ, il est défini sur false, car la webcam n'est pas allumée tant que vous n'avez pas appuyé sur le bouton ENABLE_CAM_BUTTON..

Ensuite, définissez deux tableaux : trainingDataInputs et trainingDataOutputs. Celles-ci stockent les valeurs des données d'entraînement collectées pour les caractéristiques d'entrée générées par le modèle de base MobileNet et la classe de sortie échantillonnée respectivement.

Un dernier tableau, examplesCount,, est ensuite défini pour suivre le nombre d'exemples contenus dans chaque classe une fois que vous commencez à les ajouter.

Enfin, vous disposez d'une variable appelée predict qui contrôle votre boucle de prédiction. Dans un premier temps, il est défini sur false. Aucune prédiction ne peut avoir lieu tant que ce paramètre n'est pas défini sur true.

Maintenant que toutes les variables clés ont été définies, chargeons le modèle de base MobileNet v3 pré-haché qui fournit des vecteurs de caractéristiques d'image au lieu de classifications.

9. Charger le modèle de base MobileNet

Tout d'abord, définissez une nouvelle fonction appelée loadMobileNetFeatureModel, comme indiqué ci-dessous. Il doit s'agir d'une fonction asynchrone, car le chargement d'un modèle est asynchrone:

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

Dans ce code, vous définissez l'emplacement URL où se trouve le modèle à charger à partir de la documentation de TFHub.

Vous pouvez ensuite charger le modèle à l'aide de await tf.loadGraphModel(), en n'oubliez pas de définir la propriété spéciale fromTFHub sur true lorsque vous chargez un modèle à partir de ce site Web Google. Il s'agit d'un cas particulier uniquement pour l'utilisation de modèles hébergés sur TensorFlow Hub, où cette propriété supplémentaire doit être définie.

Une fois le chargement terminé, vous pouvez définir le innerText de l'élément STATUS avec un message afin que vous puissiez voir qu'il s'est chargé correctement et que vous êtes prêt à collecter des données.

Il ne vous reste plus qu'à échauffer le modèle. Avec des modèles plus volumineux comme celui-ci, la configuration initiale peut prendre un moment. Il est donc utile de transmettre des zéros dans le modèle pour éviter d'attendre à l'avenir où la temporalité risque d'être plus importante.

Vous pouvez utiliser tf.zeros() encapsulé dans une tf.tidy() pour vous assurer que les Tensors sont supprimés correctement, avec une taille de lot de 1, ainsi que la hauteur et la largeur correctes que vous avez définies au début dans vos constantes. Enfin, vous spécifiez également les canaux de couleur, qui dans ce cas correspond à 3, car le modèle attend des images RVB.

Consignez ensuite la forme obtenue du Tensor renvoyé à l'aide de answer.shape() pour vous aider à comprendre la taille des caractéristiques d'image produites par ce modèle.

Après avoir défini cette fonction, vous pouvez l'appeler immédiatement pour lancer le téléchargement du modèle lors du chargement de la page.

Si vous consultez votre aperçu en direct, vous verrez au bout de quelques instants que le texte d'état "En attente du chargement de TF.js" est remplacé par "En attente du chargement de TF.js". devient le message "MobileNet v3loading successfully!". comme indiqué ci-dessous. Assurez-vous que cela fonctionne avant de continuer.

a28b734e190afff.png

Vous pouvez également consulter la sortie de la console pour voir la taille imprimée des caractéristiques de sortie produites par ce modèle. Après avoir exécuté des zéros via le modèle MobileNet, une forme de [1, 1024] s'affiche. Le premier élément correspond simplement à la taille de lot de 1, et vous pouvez voir qu'il renvoie en fait 1 024 caractéristiques qui peuvent ensuite être utilisées pour vous aider à classer de nouveaux objets.

10. Définir la nouvelle tête du modèle

Il est temps de définir la tête du modèle, qui est essentiellement un perceptron multicouche minimal.

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

Examinons ce code. Vous commencerez par définir un modèle tf.Sequence auquel vous ajouterez des couches de modèle.

Ensuite, ajoutez une couche dense en tant que couche d'entrée à ce modèle. Sa forme d'entrée est 1024, car les sorties des fonctionnalités de MobileNet v3 sont de cette taille. Vous l'avez découvert à l'étape précédente, après en avoir transmis des via le modèle. Cette couche contient 128 neurones qui utilisent la fonction d'activation ReLU.

Si vous ne connaissez pas encore les fonctions d'activation et les couches de modèle, suivez le cours détaillé au début de cet atelier pour découvrir le fonctionnement de ces propriétés.

La couche suivante à ajouter est la couche de sortie. Le nombre de neurones doit être égal au nombre de classes que vous essayez de prédire. Pour ce faire, vous pouvez utiliser CLASS_NAMES.length afin de connaître le nombre de classes que vous prévoyez de classer, soit le nombre de boutons de collecte de données disponibles dans l'interface utilisateur. Comme il s'agit d'un problème de classification, vous devez utiliser l'activation softmax sur cette couche de sortie, qui doit être utilisée lorsque vous essayez de créer un modèle pour résoudre les problèmes de classification plutôt que de régression.

Imprimez maintenant un model.summary() pour afficher la présentation du modèle nouvellement défini dans la console.

Enfin, compilez le modèle pour le préparer à l'entraînement. Ici, l'optimiseur est défini sur adam, et la perte sera soit binaryCrossentropy si CLASS_NAMES.length est égal à 2, soit categoricalCrossentropy s'il existe au moins trois classes à classer. Les métriques de précision sont également demandées afin de pouvoir être surveillées ultérieurement dans les journaux à des fins de débogage.

Dans la console, un message semblable à celui-ci doit s'afficher:

22eaf32286fea4bb.png

Notez qu'elle contient plus de 130 000 paramètres pouvant être entraînés. Mais comme il s'agit d'une simple couche dense de neurones réguliers, l'entraînement est assez rapide.

Une fois le projet terminé, vous pouvez essayer de modifier le nombre de neurones dans la première couche pour voir si vous pouvez le réduire tout en obtenant des performances convenables. Le machine learning implique souvent un certain degré d'essais et d'erreurs pour trouver les valeurs de paramètres optimales et ainsi obtenir le meilleur compromis entre utilisation des ressources et vitesse.

11. Activer la webcam

Il est maintenant temps d'étoffer la fonction enableCam() que vous avez définie précédemment. Ajoutez une fonction nommée hasGetUserMedia() comme indiqué ci-dessous, puis remplacez le contenu de la fonction enableCam() précédemment définie par le code correspondant ci-dessous.

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

Commencez par créer une fonction nommée hasGetUserMedia() pour vérifier si le navigateur est compatible avec getUserMedia() en vérifiant l'existence des propriétés clés des API du navigateur.

Dans la fonction enableCam(), utilisez la fonction hasGetUserMedia() que vous venez de définir ci-dessus pour vérifier si elle est compatible. Si ce n'est pas le cas, affichez un avertissement dans la console.

Si c'est le cas, définissez certaines contraintes pour votre appel getUserMedia(). Par exemple, vous souhaitez que seul le flux vidéo soit diffusé, et que la taille du width de la vidéo soit de 640 pixels et que la taille de la propriété height soit de 480 pixels. Pourquoi ? Il n'est pas utile d'obtenir une vidéo plus grande, car elle devrait être redimensionnée au format 224 x 224 pixels pour être transmise au modèle MobileNet. Vous pouvez également économiser des ressources de calcul en demandant une résolution plus faible. La plupart des caméras acceptent une résolution de cette taille.

Appelez ensuite navigator.mediaDevices.getUserMedia() avec le constraints détaillé ci-dessus, puis attendez que stream soit renvoyé. Une fois que stream est renvoyé, vous pouvez faire en sorte que votre élément VIDEO lise stream en le définissant comme sa valeur srcObject.

Vous devez également ajouter un écouteur d'événements sur l'élément VIDEO pour savoir quand stream est chargé et lu correctement.

Une fois la diffusion en direct chargée, vous pouvez définir videoPlaying sur "true" et supprimer l'élément ENABLE_CAM_BUTTON pour éviter qu'il ne reçoive un autre clic en définissant sa classe sur "removed".

Exécutez votre code, cliquez sur le bouton "Activer l'appareil photo" et autorisez l'accès à la webcam. Si vous effectuez cette opération pour la première fois, vous devriez voir le rendu de l'élément vidéo sur la page, comme illustré ci-dessous:

b378eb1affa9b883.png

Vous allez maintenant ajouter une fonction pour gérer les clics sur le bouton dataCollector.

12. Gestionnaire d'événements du bouton "Collecte des données"

Il est maintenant temps de remplir votre fonction vide appelée gatherDataForClass().. Il s'agit de celle que vous avez définie comme fonction de gestionnaire d'événements pour les boutons dataCollector au début de l'atelier de programmation.

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

Tout d'abord, vérifiez l'attribut data-1hot sur le bouton actuellement cliqué en appelant this.getAttribute() avec le nom de l'attribut, dans ce cas data-1hot comme paramètre. Comme il s'agit d'une chaîne, vous pouvez ensuite utiliser parseInt() pour la caster en entier et attribuer ce résultat à une variable nommée classNumber..

Définissez ensuite la variable gatherDataState en conséquence. Si la gatherDataState actuelle est égale à STOP_DATA_GATHER (que vous avez définie sur -1), cela signifie que vous ne collectez actuellement aucune donnée et qu'un événement mousedown s'est déclenché. Définissez gatherDataState comme classNumber que vous venez de trouver.

Sinon, cela signifie que vous collectez actuellement des données et que l'événement déclenché est un événement mouseup, et que vous souhaitez maintenant arrêter de collecter des données pour cette classe. Il vous suffit de le redéfinir sur l'état STOP_DATA_GATHER pour mettre fin à la boucle de collecte de données que vous allez définir dans un instant.

Enfin, lancez l'appel à dataGatherLoop(),, qui enregistre les données de la classe.

13. Collecte des données

Définissez maintenant la fonction dataGatherLoop(). Cette fonction permet d'échantillonner les images de la vidéo de la webcam, de les transmettre via le modèle MobileNet et de capturer les sorties de ce modèle (les 1 024 vecteurs de caractéristiques).

Elle les stocke ensuite avec l'ID gatherDataState du bouton sur lequel vous appuyez pour que vous sachiez à quelle classe ces données représentent ces données.

Voyons cela de plus près :

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

Vous ne poursuivrez l'exécution de cette fonction que si videoPlaying a la valeur "true", ce qui signifie que la webcam est active, que gatherDataState n'est pas égal à STOP_DATA_GATHER et que vous appuyez actuellement sur un bouton de collecte des données de classe.

Ensuite, encapsulez votre code dans un tf.tidy() pour supprimer tous les Tensors créés dans le code qui suit. Le résultat de cette exécution de code tf.tidy() est stocké dans une variable appelée imageFeatures.

Vous pouvez maintenant saisir une image de la webcam VIDEO à l'aide de tf.browser.fromPixels(). Le Tensor obtenu contenant les données d'image est stocké dans une variable appelée videoFrameAsTensor.

Redimensionnez ensuite la variable videoFrameAsTensor pour qu'elle corresponde à la forme d'entrée du modèle MobileNet. Utilisez un appel tf.image.resizeBilinear() avec le Tensor que vous souhaitez remodeler comme premier paramètre, puis une forme qui définit la nouvelle hauteur et la nouvelle largeur telles que définies par les constantes que vous avez déjà créées précédemment. Enfin, définissez l'option d'alignement des angles sur "true" en transmettant le troisième paramètre pour éviter tout problème d'alignement lors du redimensionnement. Le résultat de ce redimensionnement est stocké dans une variable appelée resizedTensorFrame.

Notez que ce redimensionnement primitif étire l'image, car la taille de l'image de votre webcam est de 640 x 480 pixels et que le modèle a besoin d'une image carrée de 224 x 224 pixels.

Pour les besoins de cette démonstration, cela devrait fonctionner correctement. Cependant, une fois cet atelier de programmation terminé, vous pouvez essayer de recadrer un carré à partir de cette image pour obtenir des résultats encore meilleurs pour tout système de production que vous créerez par la suite.

Ensuite, normalisez les données de l'image. Les données d'image sont toujours comprises entre 0 et 255 lorsque vous utilisez tf.browser.frompixels(). Il vous suffit donc de diviser le TensorFrame redimensionné par 255 pour vous assurer que toutes les valeurs sont comprises entre 0 et 1, ce qui correspond aux entrées attendues par le modèle MobileNet.

Enfin, dans la section tf.tidy() du code, transmettez ce Tensor normalisé dans le modèle chargé en appelant mobilenet.predict(), auquel vous transmettez la version développée de normalizedTensorFrame à l'aide de expandDims() pour qu'il s'agisse d'un lot de 1, car le modèle attend un lot d'entrées à traiter.

Une fois le résultat renvoyé, vous pouvez immédiatement appeler squeeze() sur ce résultat renvoyé pour l'écraser sur un Tensor unidimensionnel, que vous renvoyer ensuite et attribuer à la variable imageFeatures qui capture le résultat à partir de tf.tidy().

Maintenant que vous disposez des imageFeatures du modèle MobileNet, vous pouvez les enregistrer en les transférant dans le tableau trainingDataInputs que vous avez défini précédemment.

Vous pouvez également enregistrer ce que cette entrée représente en envoyant la valeur gatherDataState actuelle au tableau trainingDataOutputs.

Notez que la variable gatherDataState aurait été définie sur l'ID numérique de la classe actuelle pour laquelle vous enregistrez les données lorsque l'utilisateur a cliqué sur le bouton dans la fonction gatherDataForClass() définie précédemment.

À ce stade, vous pouvez également augmenter le nombre d'exemples disponibles pour une classe donnée. Pour ce faire, vérifiez d'abord si l'index du tableau examplesCount a déjà été initialisé ou non. S'il n'est pas défini, définissez-le sur 0 pour initialiser le compteur pour l'ID numérique d'une classe donnée. Vous pouvez ensuite incrémenter le examplesCount pour le gatherDataState actuel.

À présent, mettez à jour le texte de l'élément STATUS sur la page Web afin d'afficher les décomptes actuels de chaque classe au fur et à mesure de la capture. Pour ce faire, parcourez le tableau CLASS_NAMES et imprimez le nom lisible combiné au nombre de données au même index dans examplesCount.

Enfin, appelez window.requestAnimationFrame() avec dataGatherLoop transmis en tant que paramètre pour appeler à nouveau cette fonction de manière récursive. L'échantillonnage des images de la vidéo se poursuivra jusqu'à ce que l'élément mouseup du bouton soit détecté et que gatherDataState soit défini sur STOP_DATA_GATHER,. La boucle de collecte des données se terminera alors.

Si vous exécutez votre code maintenant, vous devriez pouvoir cliquer sur le bouton d'activation de l'appareil photo, attendre que la webcam se charge, puis cliquer de manière prolongée sur chacun des boutons de collecte de données afin de recueillir des exemples pour chaque classe de données. Ici, vous me voyez recueillir des données pour mon téléphone mobile et ma main, respectivement.

541051644a45131f.gif

Le texte d'état devrait être mis à jour, car il stocke tous les Tensors en mémoire, comme illustré dans la capture d'écran ci-dessus.

14. Entraîner et prédire

L'étape suivante consiste à implémenter le code de votre fonction trainAndPredict() actuellement vide, qui est l'endroit où a lieu l'apprentissage par transfert. Examinons le code:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

Tout d'abord, assurez-vous d'arrêter la mise en œuvre des prédictions actuelles en définissant predict sur false.

Ensuite, brassez vos tableaux d'entrée et de sortie à l'aide de tf.util.shuffleCombo() pour vous assurer que l'ordre ne cause pas de problèmes d'entraînement.

Convertissez votre tableau de sortie, trainingDataOutputs,, en un Tensor1d de type int32, afin de pouvoir l'utiliser dans un encodage one-hot. Il est stocké dans une variable nommée outputsAsTensor.

Utilisez la fonction tf.oneHot() avec cette variable outputsAsTensor, ainsi que le nombre maximal de classes à encoder, qui correspond simplement à CLASS_NAMES.length. Vos sorties à encodage one-hot sont désormais stockées dans un nouveau Tensor appelé oneHotOutputs.

Notez que trainingDataInputs est actuellement un tableau de Tensors enregistrés. Si vous souhaitez les utiliser pour l'entraînement, vous devrez convertir le tableau de Tensors en un Tensor bidimensionnel standard.

Pour ce faire, la bibliothèque TensorFlow.js intègre une fonction performante appelée tf.stack().

qui prend un tableau de Tensors et les empile ensemble pour produire un Tensor de dimension plus élevée en tant que sortie. Dans ce cas, un Tensor 2D est renvoyé. Il s'agit d'un lot d'entrées unidimensionnelles, chacune d'une longueur de 1 024, contenant les caractéristiques enregistrées, ce dont vous avez besoin pour l'entraînement.

Ensuite, await model.fit() pour entraîner la tête du modèle personnalisé. Ici, vous transmettez votre variable inputsAsTensor avec oneHotOutputs pour représenter les données d'entraînement à utiliser respectivement pour les exemples d'entrées et de sorties cibles. Dans l'objet de configuration du 3e paramètre, définissez shuffle sur true, utilisez batchSize de 5, avec epochs défini sur 10, puis spécifiez callback pour onEpochEnd dans la fonction logProgress que vous définirez prochainement.

Enfin, vous pouvez supprimer les Tensors créés, car le modèle est maintenant entraîné. Vous pouvez ensuite redéfinir predict sur true pour permettre à nouveau les prédictions, puis appeler la fonction predictLoop() pour commencer à prédire des images de webcam en direct.

Vous pouvez également définir la fonction logProcess() pour consigner l'état de l'entraînement, qui est utilisé dans model.fit() ci-dessus et qui imprime les résultats dans la console après chaque série d'entraînement.

Vous y êtes presque ! Il est temps d'ajouter la fonction predictLoop() pour effectuer des prédictions.

Boucle de prédiction centrale

Ici, vous allez implémenter la boucle de prédiction principale qui échantillonne les images d'une webcam et prédit en continu le contenu de chaque image, avec des résultats en temps réel dans le navigateur.

Vérifions le code:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

Tout d'abord, vérifiez que predict a la valeur "true", afin que les prédictions ne soient effectuées qu'une fois le modèle entraîné et disponible à l'utilisation.

Vous pouvez ensuite obtenir les caractéristiques de l'image actuelle, comme vous l'avez fait dans la fonction dataGatherLoop(). Pour faire simple, vous allez récupérer un cadre à la webcam à l'aide de tf.browser.from pixels(), le normaliser, le redimensionner pour obtenir une taille de 224 x 224 pixels, puis transmettre ces données via le modèle MobileNet pour obtenir les caractéristiques d'image ainsi obtenues.

Toutefois, vous pouvez désormais utiliser la tête du modèle que vous venez d'entraîner pour effectuer une prédiction en transmettant le résultat imageFeatures que vous venez de trouver via la fonction predict() du modèle entraîné. Vous pouvez ensuite presser le Tensor obtenu pour le rendre à nouveau à une dimension et l'affecter à une variable appelée prediction.

Avec ce prediction, vous pouvez trouver l'index ayant la valeur la plus élevée à l'aide de argMax(), puis convertir ce Tensor obtenu en un tableau à l'aide de arraySync() pour obtenir les données sous-jacentes en JavaScript afin de découvrir la position de l'élément ayant la valeur la plus élevée. Cette valeur est stockée dans la variable appelée highestIndex.

Vous pouvez également obtenir les scores de confiance de prédiction réels de la même manière en appelant arraySync() directement sur le Tensor prediction.

Vous disposez maintenant de tout ce dont vous avez besoin pour mettre à jour le texte STATUS avec les données prediction. Pour obtenir la chaîne lisible par l'humain pour la classe, vous pouvez simplement rechercher highestIndex dans le tableau CLASS_NAMES, puis récupérer la valeur de confiance à partir de predictionArray. Pour le rendre plus lisible en tant que pourcentage, il suffit de multiplier le résultat par 100, puis d'obtenir le résultat math.floor().

Enfin, vous pouvez utiliser window.requestAnimationFrame() pour appeler à nouveau predictionLoop() une fois que tout est prêt, afin d'obtenir la classification en temps réel de votre flux vidéo. Ce processus se poursuit jusqu'à ce que predict soit défini sur false si vous choisissez d'entraîner un nouveau modèle avec de nouvelles données.

Ce qui vous amène à la dernière pièce du puzzle. Implémenter le bouton de réinitialisation

15. Implémenter le bouton de réinitialisation

C'est presque terminé ! La dernière pièce du puzzle consiste à implémenter un bouton de réinitialisation pour recommencer. Vous trouverez ci-dessous le code de la fonction reset() actuellement vide. Modifiez-le comme suit:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

Commencez par arrêter les boucles de prédiction en cours d'exécution en définissant predict sur false. Ensuite, supprimez tout le contenu du tableau examplesCount en définissant sa longueur sur 0, ce qui est un moyen pratique d'effacer tout le contenu d'un tableau.

Examinez maintenant tous les trainingDataInputs actuellement enregistrés et assurez-vous d'utiliser dispose() pour chaque Tensor qu'il contient afin de libérer à nouveau de la mémoire, car les Tensors ne sont pas nettoyés par le récupérateur de mémoire JavaScript.

Une fois cela fait, vous pouvez définir la longueur des tableaux trainingDataInputs et trainingDataOutputs en toute sécurité pour les effacer également.

Enfin, définissez le texte STATUS sur un élément sensible, puis imprimez les Tensors laissés en mémoire afin de vérifier l'intégrité.

Notez qu'il reste quelques centaines de Tensors encore en mémoire, car le modèle MobileNet et le perceptron multicouche que vous avez défini ne sont pas supprimés. Vous devrez les réutiliser avec de nouvelles données d'entraînement si vous décidez d'effectuer un nouvel entraînement après cette réinitialisation.

16. Essayons.

Il est temps de tester votre propre version de Teachable Machine !

Accédez à l'aperçu en direct, activez la webcam, rassemblez au moins 30 échantillons pour la classe 1 d'un objet de votre salle, puis faites de même pour la classe 2 avec un autre objet, cliquez sur "Entraînement" et consultez le journal de la console pour voir la progression. L'entraînement devrait être assez rapide:

bf1ac3cc5b15740.gif

Une fois l'entraînement effectué, montrez les objets à l'appareil photo pour obtenir des prédictions en direct qui seront imprimées dans la zone de texte d'état en haut de la page Web. Si vous rencontrez des difficultés, vérifiez mon code opérationnel terminé pour vérifier si vous n'avez pas effectué de copie.

17. Félicitations

Félicitations ! Vous venez de terminer votre tout premier exemple d'apprentissage par transfert avec TensorFlow.js en direct dans le navigateur.

Essayez, testez-le sur différents objets. Vous remarquerez peut-être que certaines choses sont plus difficiles à reconnaître que d'autres, surtout si elles sont semblables à d'autres. Vous devrez peut-être ajouter d'autres classes ou données d'entraînement pour pouvoir les distinguer.

Résumé

Dans cet atelier de programmation, vous avez appris à:

  1. En quoi consiste l'apprentissage par transfert et ses avantages par rapport à l'entraînement d'un modèle complet.
  2. Obtenir des modèles à réutiliser à partir de TensorFlow Hub
  3. Configurer une application Web adaptée à l'apprentissage par transfert
  4. Charger et utiliser un modèle de base pour générer des caractéristiques d'image
  5. Comment entraîner une nouvelle tête de prédiction capable de reconnaître des objets personnalisés à partir d'images de webcam
  6. Comment utiliser les modèles qui en résultent pour classer des données en temps réel

Et ensuite ?

Maintenant que vous disposez d'une base de travail, quelles idées créatives pouvez-vous proposer pour étendre le code récurrent de ce modèle de machine learning à un cas d'utilisation réel sur lequel vous êtes susceptible de travailler ? Peut-être pourriez-vous révolutionner le secteur dans lequel vous travaillez actuellement pour aider les membres de votre entreprise à entraîner des modèles afin de classer les choses qui sont importantes dans leur travail quotidien ? Les possibilités sont innombrables.

Pour aller plus loin, vous pouvez suivre ce cours complet sans frais, qui vous montre comment combiner les deux modèles que vous utilisez actuellement dans cet atelier de programmation en un seul modèle pour plus d'efficacité.

Si vous souhaitez en savoir plus sur la théorie sous-jacente de l'application Teachable Machine d'origine, consultez ce tutoriel.

Partagez vos créations

Vous pouvez facilement étendre ce que vous avez créé aujourd'hui à d'autres cas d'utilisation créatifs. Nous vous encourageons à sortir des sentiers battus et à continuer à pirater.

N'oubliez pas de nous taguer sur les réseaux sociaux avec le hashtag #MadeWithTFJS. Votre projet sera peut-être présenté sur notre blog TensorFlow, voire dans des événements à venir. Nous serions ravis de découvrir vos créations.

Sites Web à consulter