1. Antes de comenzar
El uso del modelo de TensorFlow.js creció exponencialmente en los últimos años y muchos desarrolladores de JavaScript ahora buscan tomar modelos de vanguardia existentes y volver a entrenarlos para que trabajen con datos personalizados que son exclusivos de su industria. El aprendizaje por transferencia es el acto de tomar un modelo existente (a menudo denominado modelo base) y usarlo en un dominio similar, pero diferente.
El aprendizaje por transferencia tiene muchas ventajas frente a comenzar con un modelo completamente en blanco. Puedes reutilizar el conocimiento ya aprendido de un modelo entrenado anterior y necesitas menos ejemplos del nuevo elemento que deseas clasificar. Además, el entrenamiento suele ser mucho más rápido debido a que solo hay que volver a entrenar las últimas capas de la arquitectura del modelo en lugar de toda la red. Por este motivo, el aprendizaje por transferencia es muy adecuado para el entorno de navegador web, en el que los recursos pueden variar según el dispositivo de ejecución, pero también tiene acceso directo a los sensores para facilitar la adquisición de datos.
En este codelab, aprenderás a compilar una app web a partir de un lienzo en blanco recreando las populares " Teachable Machine” sitio web. El sitio web te permite crear una app web funcional que cualquier usuario puede usar para reconocer un objeto personalizado con solo algunas imágenes de ejemplo de su cámara web. El propósito del sitio web es mínimo para que puedas enfocarte en los aspectos relacionados con el aprendizaje automático de este codelab. Sin embargo, al igual que con el sitio web original de Teachable Machine, hay mucho alcance para aplicar tu experiencia de desarrollador web existente para mejorar la UX.
Requisitos previos
Este codelab está dirigido a desarrolladores web que están familiarizados con los modelos prediseñados de TensorFlow.js y con el uso básico de las APIs, y que desean comenzar a usar el aprendizaje por transferencia en TensorFlow.js.
- Para este lab, se presupone que tienes conocimientos básicos de TensorFlow.js, HTML5, CSS y JavaScript.
Si es la primera vez que usas TensorFlow.js, considera realizar este curso gratuito de cero a hero, que presupone que no tienes experiencia en aprendizaje automático ni TensorFlow.js, y te enseña todo lo que necesitas saber en pasos más pequeños.
Qué aprenderás
- Qué es TensorFlow.js y por qué deberías usarlo en tu próxima app web
- Cómo crear una página web simplificada de HTML/CSS /JS que replique la experiencia del usuario de Teachable Machine
- Cómo usar TensorFlow.js para cargar un modelo base previamente entrenado, específicamente MobileNet, para generar atributos de imágenes que se pueden usar en el aprendizaje por transferencia
- Cómo recopilar datos de la cámara web de un usuario para varias clases de datos que deseas reconocer.
- Cómo crear y definir un perceptrón multicapa que toma los atributos de la imagen y aprende a clasificar objetos nuevos con ellos
Vamos a hackear...
Requisitos
- Es preferible que utilices una cuenta de Glitch.com, o bien puedes usar un entorno de entrega web que te resulte cómodo editar y ejecutar por tu cuenta.
2. ¿Qué es TensorFlow.js?
TensorFlow.js es una biblioteca de aprendizaje automático de código abierto que se puede ejecutar en cualquier lugar que pueda JavaScript. Se basa en la biblioteca original de TensorFlow escrita en Python y tiene como objetivo recrear esta experiencia de desarrollador y el conjunto de APIs para el ecosistema de JavaScript.
¿Dónde se puede utilizar?
Dada la portabilidad de JavaScript, ahora puedes escribir en 1 lenguaje y realizar el aprendizaje automático en todas las siguientes plataformas con facilidad:
- Del lado del cliente en el navegador web con JavaScript convencional
- En el servidor y también en dispositivos de IoT como Raspberry Pi con Node.js
- Apps de escritorio que usan Electron
- Apps nativas para dispositivos móviles con React Native
TensorFlow.js también admite múltiples backends dentro de cada uno de estos entornos (los entornos reales basados en hardware que puede ejecutar dentro de él, como la CPU o WebGL). Un "backend" en este contexto no significa un entorno del servidor: el backend para la ejecución podría ser del cliente en WebGL, por ejemplo, para garantizar la compatibilidad y también mantener el funcionamiento rápido. En la actualidad, TensorFlow.js admite lo siguiente:
- Ejecución de WebGL en la tarjeta gráfica del dispositivo (GPU): Es la forma más rápida de ejecutar modelos más grandes (más de 3 MB de tamaño) con aceleración de GPU.
- Ejecución de Web Assembly (WASM) en la CPU: Para mejorar el rendimiento de la CPU en todos los dispositivos, incluidos, por ejemplo, los teléfonos celulares de generaciones anteriores. Esto es más adecuado para modelos más pequeños (menos de 3 MB de tamaño) que pueden ejecutarse más rápido en la CPU con WASM que con WebGL debido a la sobrecarga que implica subir contenido a un procesador de gráficos.
- Ejecución de CPU: El resguardo no debe estar disponible ninguno de los otros entornos. Esta es la más lenta de las tres, pero siempre está ahí para ti.
Nota: Puedes optar por forzar uno de estos backends si sabes en qué dispositivo lo ejecutarás o, si no lo especificas, puedes dejar que TensorFlow.js decida por ti.
Superpoderes del cliente
La ejecución de TensorFlow.js en el navegador web de la máquina cliente puede generar varios beneficios que vale la pena considerar.
Privacidad
Puedes entrenar y clasificar datos en la máquina cliente sin tener que enviarlos a un servidor web de terceros. En algunas ocasiones, esto puede ser un requisito para cumplir con las leyes locales, como el GDPR, o cuando se procesan datos que el usuario quiere conservar en su máquina y no se envían a un tercero.
Velocidad
Como no tienes que enviar datos a un servidor remoto, la inferencia (el acto de clasificar los datos) puede ser más rápida. Aún mejor, tienes acceso directo a los sensores del dispositivo, como la cámara, el micrófono, el GPS, el acelerómetro y otros, si el usuario te otorga el acceso.
Alcance y escala
Con un solo clic, cualquier persona en el mundo puede hacer clic en el vínculo que le envías, abrir la página web en su navegador y usar lo que has hecho. No se necesita una compleja configuración de Linux en el servidor con controladores CUDA y mucho más solo para usar el sistema de aprendizaje automático.
Costo
Sin servidores significa que lo único que debes pagar es una CDN para alojar tus archivos HTML, CSS, JS y de modelo. El costo de una CDN es mucho más económico que mantener un servidor (posiblemente con una tarjeta gráfica adjunta) en funcionamiento las 24 horas, todos los días.
Funciones del servidor
Aprovechar la implementación de Node.js de TensorFlow.js habilita las siguientes funciones.
Compatibilidad total con CUDA
En el lado del servidor, para acelerar la tarjeta gráfica, debes instalar los controladores CUDA de NVIDIA para permitir que TensorFlow funcione con la tarjeta gráfica (a diferencia del navegador que usa WebGL, no es necesario instalarlo). Sin embargo, gracias a la total compatibilidad con CUDA, puedes aprovechar al máximo las capacidades de nivel inferior de la tarjeta gráfica, lo que agiliza los tiempos de inferencia y entrenamiento. El rendimiento está a la par con la implementación de TensorFlow en Python, ya que ambas comparten el mismo backend de C++.
Tamaño del modelo
Para modelos de vanguardia de la investigación, es posible que se trabaje con modelos muy grandes, quizás de gigabytes. Actualmente, estos modelos no se pueden ejecutar en el navegador web debido a las limitaciones del uso de memoria por pestaña del navegador. Para ejecutar estos modelos más grandes, puedes usar Node.js en tu propio servidor con las especificaciones de hardware que necesitas para ejecutar un modelo de este tipo de manera eficiente.
IoT
Node.js es compatible con computadoras de placa única populares como Raspberry Pi, lo que, a su vez, significa que también puedes ejecutar modelos de TensorFlow.js en esos dispositivos.
Velocidad
Node.js está escrito en JavaScript, lo que significa que se beneficia de una compilación inmediata. Esto significa que, a menudo, puedes ver mejoras en el rendimiento cuando usas Node.js, ya que se optimizará en el entorno de ejecución, especialmente para cualquier procesamiento previo que realices. Un excelente ejemplo de esto se puede observar en este caso de éxito, en el que se muestra cómo Hugging Face usó Node.js para duplicar el rendimiento de su modelo de procesamiento de lenguaje natural.
Ahora que comprendes los conceptos básicos de TensorFlow.js, dónde se puede ejecutar y algunos de sus beneficios, comencemos a hacer cosas útiles con él.
3. Aprendizaje por transferencia
¿Qué es exactamente el aprendizaje por transferencia?
El aprendizaje por transferencia implica tomar conocimiento que ya se aprendió para ayudar a aprender algo diferente pero similar.
Los seres humanos hacemos esto todo el tiempo. Tienes toda una vida de experiencias en tu cerebro que puedes usar para ayudarte a reconocer cosas nuevas que nunca antes viste. Tomemos este sauce como ejemplo:
Según el lugar del mundo en el que te encuentres, es posible que nunca hayas visto este tipo de árboles.
Sin embargo, si te pido que me digas si hay sauces en la nueva imagen de abajo, probablemente puedas detectarlos bastante rápido, aunque estén en un ángulo diferente y sean ligeramente diferentes al original que te mostré.
Ya hay un grupo de neuronas en el cerebro que saben identificar objetos en forma de árbol y otras neuronas que son capaces de encontrar líneas rectas largas. Puedes reutilizar ese conocimiento para clasificar rápidamente un sauce, que es un objeto similar a un árbol que tiene muchas ramas verticales largas y rectas.
Del mismo modo, si tienes un modelo de aprendizaje automático que ya está entrenado en un dominio, como el reconocimiento de imágenes, puedes volver a usarlo para realizar una tarea diferente, pero relacionada.
Puedes hacer lo mismo con un modelo avanzado como MobileNet, un modelo de investigación muy popular que puede realizar reconocimiento de imágenes en 1,000 tipos de objetos diferentes. Desde perros hasta automóviles, se entrenó con un enorme conjunto de datos conocido como ImageNet que tiene millones de imágenes etiquetadas.
En esta animación, puedes ver la gran cantidad de capas que tiene en este modelo de MobileNet V1:
Durante su entrenamiento, este modelo aprendió a extraer atributos comunes importantes para todos esos 1,000 objetos, y muchos de los atributos de nivel inferior que usa para identificar tales objetos pueden ser útiles para detectar objetos nuevos que nunca había visto antes. Después de todo, todo es, en definitiva, solo una combinación de líneas, texturas y formas.
Veamos la arquitectura tradicional de la red neuronal convolucional (CNN) (similar a MobileNet) y veamos cómo el aprendizaje por transferencia puede aprovechar esta red entrenada para aprender algo nuevo. La siguiente imagen muestra la arquitectura típica de un modelo de una CNN que, en este caso, se entrenó para reconocer dígitos escritos a mano del 0 al 9:
Si pudieras separar las capas de nivel inferior previamente entrenadas de un modelo entrenado existente como el que se muestra a la izquierda, de las capas de clasificación cerca del final del modelo que se muestra a la derecha (a veces denominadas “cabeza de clasificación del modelo”), podrías usar las capas de nivel inferior para producir atributos de salida para cualquier imagen determinada en función de los datos originales con los que se entrenó. Esta es la misma red sin el encabezado de clasificación:
Suponiendo que lo nuevo que intentas reconocer también puede hacer uso de esos atributos de salida que el modelo anterior aprendió, entonces es muy probable que se puedan reutilizar para un nuevo propósito.
En el diagrama anterior, este modelo hipotético se entrenó con dígitos, por lo que tal vez lo que se aprendió sobre los dígitos también se pueda aplicar a letras como a, b y c.
Así que ahora podrías agregar un nuevo encabezado de clasificación que intente predecir a, b o c, como se muestra a continuación:
Aquí, las capas de nivel inferior se inmovilizan y no están entrenadas; solo el nuevo encabezado de clasificación se actualizará para aprender de los atributos proporcionados por el modelo recortado previamente entrenado que se encuentra a la izquierda.
El hecho de hacer esto se conoce como aprendizaje por transferencia y es lo que Teachable Machine hace en segundo plano.
También puedes ver que, con solo entrenar el perceptrón multicapa al final de la red, se entrena mucho más rápido que si tuvieras que entrenar toda la red desde cero.
Pero ¿cómo puedes controlar las subpartes de un modelo? Ve a la siguiente sección para averiguarlo.
4. TensorFlow Hub: Modelos base
Busca un modelo base adecuado para usar
Para obtener modelos de investigación más avanzados y populares, como MobileNet, puedes ir a TensorFlow Hub y filtrar modelos adecuados para TensorFlow.js que usen la arquitectura MobileNet v3 para encontrar resultados como los que se muestran aquí:
Ten en cuenta que algunos de estos resultados son del tipo “clasificación de imágenes” (detallados en la parte superior izquierda del resultado de cada tarjeta de modelo) y otras son del tipo “vector de atributos de imagen”.
Estos resultados de vectores de atributos de imagen son, básicamente, versiones recortadas de MobileNet que puedes usar para obtener los vectores de atributos de la imagen, en lugar de la clasificación final.
Los modelos como este suelen llamarse “modelos base”, que puedes usar para realizar el aprendizaje por transferencia de la misma manera que se muestra en la sección anterior, agregando un nuevo encabezado de clasificación y entrenando con tus propios datos.
Lo siguiente que debes verificar es, para un modelo base de interés determinado, con qué formato de TensorFlow.js se lanza el modelo. Si abres la página de uno de estos modelos de vector de atributos de MobileNet v3, puedes ver, en la documentación de JS, que tiene la forma de un modelo de gráfico basado en el fragmento de código de ejemplo en la documentación que usa tf.loadGraphModel()
.
También debes tener en cuenta que, si encuentras un modelo en el formato de capas en lugar del formato de gráfico, puedes elegir qué capas inmovilizar y cuáles desbloquear para el entrenamiento. Esto puede ser muy potente cuando se crea un modelo para una nueva tarea, lo que con frecuencia se conoce como el “modelo de transferencia”. Sin embargo, por ahora usarás el tipo de modelo de grafo predeterminado para este instructivo, con el que se implementan la mayoría de los modelos de TF Hub. Para obtener más información sobre cómo trabajar con modelos de capas, consulta el curso 00 to hero TensorFlow.js.
Ventajas del aprendizaje por transferencia
¿Cuáles son las ventajas de usar el aprendizaje por transferencia en lugar de entrenar toda la arquitectura del modelo desde cero?
En primer lugar, el tiempo de entrenamiento es una ventaja clave de usar un enfoque de aprendizaje por transferencia, dado que ya tienes un modelo de base entrenado en el que basarte.
En segundo lugar, puedes mostrar muchos menos ejemplos de lo nuevo que intentas clasificar debido al entrenamiento ya realizado.
Esto es realmente genial si tienes tiempo y recursos limitados para recopilar datos de ejemplo de lo que deseas clasificar y necesitas hacer un prototipo rápidamente antes de recopilar más datos de entrenamiento para hacerlo más sólido.
Dada la necesidad de menos datos y la velocidad de entrenamiento de una red más pequeña, el aprendizaje por transferencia usa menos recursos. Esto lo hace muy adecuado para el entorno del navegador, ya que tarda solo decenas de segundos en una máquina moderna en lugar de horas, días o semanas para completar el entrenamiento del modelo.
¡Muy bien! Ahora que conoces la esencia de lo que es el aprendizaje por transferencia, es momento de crear tu propia versión de Teachable Machine. Comencemos.
5. Prepárate para codificar
Requisitos
- Un navegador web moderno
- Conocimientos básicos de HTML, CSS, JavaScript y Herramientas para desarrolladores de Chrome (visualización del resultado de la consola)
Comencemos a programar
Se crearon plantillas plantillas para comenzar en Glitch.com o Codepen.io. Puedes clonar cualquiera de las plantillas como tu estado base para este codelab con solo un clic.
En Glitch, haz clic en el botón “remix this” para bifurcarlo y crear un nuevo conjunto de archivos que puedas editar.
Como alternativa, en Codepen, haz clic en" fork" en la parte inferior derecha de la pantalla.
Este esqueleto simple te proporciona los siguientes archivos:
- Página HTML (index.html)
- Hoja de estilo (style.css)
- Archivo para escribir el código JavaScript (script.js)
Para tu comodidad, hay una importación adicional en el archivo HTML de la biblioteca de TensorFlow.js. El aspecto resultante será el siguiente:
index.html
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
Alternativa: Usa tu editor web preferido o trabaja de manera local
Si quieres descargar el código y trabajar de forma local o en un editor en línea diferente, simplemente crea los 3 archivos nombrados anteriormente en el mismo directorio y copia y pega el código de nuestro código estándar de Glitch en cada uno de ellos.
6. Código estándar HTML de la app
¿Por dónde empiezo?
Todos los prototipos requieren un andamiaje HTML básico en el que puedes procesar tus hallazgos. Configúralo ahora. Vas a agregar lo siguiente:
- Es un título para la página.
- Un texto descriptivo.
- Un párrafo de estado.
- Un video para conservar el feed de la cámara web una vez que esté listo.
- Varios botones para iniciar la cámara, recopilar datos o restablecer la experiencia.
- Importaciones para archivos de TensorFlow.js y JS que codificarás más adelante.
Abre index.html
y pega el código existente con lo siguiente para configurar las funciones anteriores:
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Transfer Learning - TensorFlow.js</title>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- Import the webpage's stylesheet -->
<link rel="stylesheet" href="/style.css">
</head>
<body>
<h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
<p id="status">Awaiting TF.js load</p>
<video id="webcam" autoplay muted></video>
<button id="enableCam">Enable Webcam</button>
<button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
<button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
<button id="train">Train & Predict!</button>
<button id="reset">Reset</button>
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>
<!-- Import the page's JavaScript to do some stuff -->
<script type="module" src="/script.js"></script>
</body>
</html>
Desglosar
Analicemos parte del código HTML anterior para destacar algunos aspectos clave que agregaste.
- Agregaste una etiqueta
<h1>
al título de la página y una etiqueta<p>
con el ID "status". que es donde imprimirás la información, ya que usarás diferentes partes del sistema para ver los resultados. - Agregaste un elemento
<video>
con el ID “webcam”. y renderizarás la transmisión de tu cámara web más tarde. - Agregaste 5 elementos
<button>
. La primera, con el ID de "enableCam", habilita la cámara. Los siguientes dos botones tienen una clase de “dataCollector”, que te permite recopilar imágenes de ejemplo para los objetos que deseas reconocer. El código que escribas más adelante se diseñará de manera que puedas agregar cualquier cantidad de estos botones y funcionarán como se espera automáticamente.
Ten en cuenta que estos botones también tienen un atributo especial definido por el usuario llamado data-1hot, con un valor entero que comienza en 0 para la primera clase. Este es el índice numérico que usarás para representar los datos de una clase determinada. El índice se usará para codificar las clases de salida correctamente con una representación numérica en lugar de una string, ya que los modelos de AA solo pueden trabajar con números.
También hay un atributo de nombre de datos que contiene el nombre legible por humanos que deseas usar para esta clase, lo que te permite proporcionar un nombre más significativo al usuario en lugar de un valor de índice numérico de la codificación 1.
Por último, tienes un botón de entrenamiento y restablecimiento para iniciar el proceso de entrenamiento una vez que se hayan recopilado los datos, o para restablecer la app, respectivamente.
- También agregaste 2 importaciones
<script>
. Una para TensorFlow.js y la otra para script.js que definirás pronto.
7. Agr. estilo
Valores predeterminados de elementos
Agrega estilos a los elementos HTML que acabas de agregar para asegurarte de que se rendericen correctamente. Estos son algunos estilos que se agregaron a los elementos de posición y tamaño correctamente. Nada especial. Sin duda, podrías agregar algo más adelante para mejorar la UX, como vimos en el video de Teachable Machine.
style.css
body {
font-family: helvetica, arial, sans-serif;
margin: 2em;
}
h1 {
font-style: italic;
color: #FF6F00;
}
video {
clear: both;
display: block;
margin: 10px;
background: #000000;
width: 640px;
height: 480px;
}
button {
padding: 10px;
float: left;
margin: 5px 3px 5px 10px;
}
.removed {
display: none;
}
#status {
font-size:150%;
}
¡Genial! ¡Eso es todo lo que necesitas! Si obtienes una vista previa del resultado ahora mismo, debería verse de la siguiente manera:
8. JavaScript: Constantes de clave y objetos de escucha
Define las constantes de clave
Primero, agrega algunas constantes de clave que usarás en toda la app. Para comenzar, reemplaza el contenido de script.js
con estas constantes:
script.js
const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];
Analicemos para qué sirven:
STATUS
simplemente contiene una referencia a la etiqueta de párrafo en la que escribirás actualizaciones de estado.VIDEO
contiene una referencia al elemento de video HTML que renderizará el feed de la cámara web.ENABLE_CAM_BUTTON
,RESET_BUTTON
yTRAIN_BUTTON
toman referencias del DOM a todos los botones clave de la página HTML.MOBILE_NET_INPUT_WIDTH
yMOBILE_NET_INPUT_HEIGHT
definen el ancho y la altura de entrada esperados del modelo de MobileNet, respectivamente. Al almacenar esto en una constante cerca de la parte superior del archivo como esta, si decides usar una versión diferente más adelante, es más fácil actualizar los valores una vez, en lugar de tener que reemplazarla en muchos lugares diferentes.STOP_DATA_GATHER
se establece en - 1. Esto almacena un valor de estado para que sepas cuándo el usuario dejó de hacer clic en un botón para recopilar datos del feed de la cámara web. Si le das a este número un nombre más significativo, el código será más legible más adelante.CLASS_NAMES
actúa como una búsqueda y contiene nombres legibles por humanos para las posibles predicciones de clase. Este array se propagará más adelante.
Ahora que tienes referencias a elementos clave, es momento de asociarlos con algunos objetos de escucha de eventos.
Cómo agregar objetos de escucha de eventos clave
Comienza por agregar controladores de eventos de clic a los botones de teclas como se muestra a continuación:
script.js
ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);
function enableCam() {
// TODO: Fill this out later in the codelab!
}
function trainAndPredict() {
// TODO: Fill this out later in the codelab!
}
function reset() {
// TODO: Fill this out later in the codelab!
}
ENABLE_CAM_BUTTON
: Llama a la función enableCam cuando se hace clic en él.
TRAIN_BUTTON
: Llama a trainAndPredict cuando se hace clic en él.
RESET_BUTTON
: Las llamadas se restablecen cuando se hace clic en él.
Por último, en esta sección, puedes encontrar todos los botones con la clase "dataCollector" usando document.querySelectorAll()
. Esto devuelve un array de elementos encontrados en el documento que coinciden:
script.js
let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
// Populate the human readable names for classes.
CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}
function gatherDataForClass() {
// TODO: Fill this out later in the codelab!
}
Explicación del código:
Luego, iteras a través de los botones encontrados y asocias 2 objetos de escucha de eventos a cada uno. Uno para "mousedown" y otro para "mouseup". Esto te permite seguir grabando muestras mientras se presiona el botón, lo cual es útil para la recopilación de datos.
Ambos eventos llaman a una función gatherDataForClass
que definirás más adelante.
En este punto, también puedes enviar los nombres de clase legibles por humanos del atributo data-name del botón HTML al array CLASS_NAMES
.
A continuación, agrega algunas variables para almacenar elementos clave que se usarán más adelante.
script.js
let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;
Veamos cuáles son.
Primero, tienes una variable mobilenet
para almacenar el modelo de mobilenet cargado. Inicialmente, configúralo como indefinido.
A continuación, tienes una variable llamada gatherDataState
. Si un recopilador de datos botón, se convertirá en el ID activo de ese botón, como se define en el código HTML, para que sepas qué clase de datos estás recopilando en ese momento. Inicialmente, se establece en STOP_DATA_GATHER
para que el bucle de recopilación de datos que escribas más tarde no recopile datos cuando no se presione ningún botón.
videoPlaying
realiza un seguimiento de si la transmisión de la cámara web se cargó y se está reproduciendo correctamente, y si está disponible para usarse. Inicialmente, esto se establece en false
, ya que la cámara web no estará encendida hasta que presiones ENABLE_CAM_BUTTON.
.
A continuación, define 2 arrays: trainingDataInputs
y trainingDataOutputs
. Estos almacenan los valores de datos de entrenamiento recopilados, cuando haces clic en el botón “dataCollector” para los atributos de entrada generados por el modelo base de MobileNet y la clase de salida muestreada, respectivamente.
Luego, se define un array final, examplesCount,
, para hacer un seguimiento de la cantidad de ejemplos que contiene cada clase una vez que comienzas a agregarlos.
Por último, tienes una variable llamada predict
que controla el bucle de predicción. Se establece inicialmente en false
. No se podrán realizar predicciones hasta que la configuración se establezca como true
más adelante.
Ahora que se definieron todas las variables clave, carguemos el modelo base MobileNet v3 previamente cortado que proporciona vectores de atributos de imágenes en lugar de clasificaciones.
9. Carga el modelo base de MobileNet
Primero, define una nueva función llamada loadMobileNetFeatureModel
, como se muestra a continuación. Debe ser una función asíncrona, ya que el acto de cargar un modelo es asíncrono:
script.js
/**
* Loads the MobileNet model and warms it up so ready for use.
**/
async function loadMobileNetFeatureModel() {
const URL =
'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
STATUS.innerText = 'MobileNet v3 loaded successfully!';
// Warm up the model by passing zeros through it once.
tf.tidy(function () {
let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
console.log(answer.shape);
});
}
// Call the function immediately to start loading.
loadMobileNetFeatureModel();
En este código, defines el URL
en el que se ubica el modelo que se cargará en la documentación de TFHub.
Luego, puedes cargar el modelo con await tf.loadGraphModel()
. Recuerda establecer la propiedad especial fromTFHub
en true
mientras cargas un modelo desde este sitio web de Google. Este es un caso especial solo para los modelos alojados en TF Hub en los que se debe configurar esta propiedad adicional.
Una vez que se complete la carga, puedes configurar el innerText
del elemento STATUS
con un mensaje para que puedas ver visualmente que se cargó de forma correcta y que tengas todo listo para comenzar a recopilar datos.
Lo único que queda por hacer ahora es preparar el modelo. Con modelos más grandes como este, la primera vez que usas el modelo, puede tomar un momento configurar todo. Por lo tanto, ayuda pasar ceros por el modelo para evitar cualquier espera en el futuro donde el tiempo pueda ser más crítico.
Puedes usar tf.zeros()
unido a un tf.tidy()
para asegurarte de que los tensores se eliminen correctamente, con un tamaño de lote de 1 y la altura y el ancho correctos que definiste en tus constantes al comienzo. Por último, también especificas los canales de color, que en este caso es 3, ya que el modelo espera imágenes RGB.
A continuación, registra la forma resultante del tensor que se muestra con answer.shape()
para ayudarte a comprender el tamaño de los atributos de la imagen que produce este modelo.
Después de definir esta función, puedes llamarla de inmediato para iniciar la descarga del modelo en la carga de la página.
Si puedes acceder a la vista previa en vivo ahora mismo, después de unos momentos, verás que el texto del estado cambia de "Awaiting TF.js load" (Esperando carga de TF.js). para que se convierta en un mensaje de “MobileNet v3” como se muestra a continuación. Asegúrate de que funcione antes de continuar.
También puedes verificar el resultado de la consola para ver el tamaño de impresión de los atributos de salida que produce este modelo. Después de ejecutar ceros a través del modelo de MobileNet, verás impresa la forma de [1, 1024]
. El primer elemento es solo el tamaño de lote de 1, y puedes ver que en realidad muestra 1,024 atributos que luego se pueden usar para ayudarte a clasificar objetos nuevos.
10. Define el encabezado del modelo nuevo
Ahora es el momento de definir el encabezado del modelo, que es, en esencia, un perceptrón multicapa mínimo.
script.js
let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));
model.summary();
// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
// Adam changes the learning rate over time which is useful.
optimizer: 'adam',
// Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
// Else categoricalCrossentropy is used if more than 2 classes.
loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
// As this is a classification problem you can record accuracy in the logs too!
metrics: ['accuracy']
});
Analicemos este código. Comenzarás definiendo un modelo tf.secuencial al que agregarás capas del modelo.
A continuación, agrega una capa densa como capa de entrada a este modelo. Tiene una forma de entrada de 1024
, ya que los resultados de las funciones de MobileNet v3 son de este tamaño. Descubriste esto en el paso anterior después de pasar unos por el modelo. Esta capa tiene 128 neuronas que usan la función de activación ReLU.
Si no tienes experiencia con las funciones de activación y las capas de modelos, considera realizar el curso detallado al comienzo de este taller para comprender qué hacen estas propiedades en segundo plano.
La siguiente capa que se debe agregar es la capa de salida. El número de neuronas debe ser igual al número de clases que intentas predecir. Para ello, puedes usar CLASS_NAMES.length
a fin de encontrar la cantidad de clases que planeas clasificar, lo que equivale a la cantidad de botones de recopilación de datos que se encuentran en la interfaz de usuario. Como este es un problema de clasificación, debes usar la activación softmax
en esta capa de salida, que debe usarse cuando intentas crear un modelo para resolver problemas de clasificación en lugar de regresión.
Ahora imprime un model.summary()
para imprimir la descripción general del modelo recién definido en la consola.
Por último, compila el modelo para que esté listo para entrenarse. Aquí, el optimizador se establece en adam
, y la pérdida será binaryCrossentropy
si CLASS_NAMES.length
es igual a 2
, o usará categoricalCrossentropy
si hay 3 o más clases para clasificar. También se solicitan métricas de precisión para que puedan supervisarse en los registros más adelante con fines de depuración.
En la consola, deberías ver algo como esto:
Ten en cuenta que tiene más de 130,000 parámetros entrenables. Pero como se trata de una capa simple y densa de neuronas regulares, se entrenará bastante rápido.
Una vez que se complete el proyecto, puedes probar cambiar la cantidad de neuronas en la primera capa para ver qué tan bajo puedes llegar y, al mismo tiempo, obtener un rendimiento decente. A menudo, con el aprendizaje automático, se necesita cierto nivel de ensayo y error para encontrar valores de parámetros óptimos que ofrezcan la mejor relación entre el uso de recursos y la velocidad.
11. Habilitar la cámara web
Ahora es el momento de completar la función enableCam()
que definiste antes. Agrega una nueva función llamada hasGetUserMedia()
como se muestra a continuación y, luego, reemplaza el contenido de la función enableCam()
definida anteriormente por el código correspondiente a continuación.
script.js
function hasGetUserMedia() {
return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}
function enableCam() {
if (hasGetUserMedia()) {
// getUsermedia parameters.
const constraints = {
video: true,
width: 640,
height: 480
};
// Activate the webcam stream.
navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
VIDEO.srcObject = stream;
VIDEO.addEventListener('loadeddata', function() {
videoPlaying = true;
ENABLE_CAM_BUTTON.classList.add('removed');
});
});
} else {
console.warn('getUserMedia() is not supported by your browser');
}
}
Primero, crea una función llamada hasGetUserMedia()
para verificar si el navegador admite getUserMedia()
. Para ello, comprueba la existencia de propiedades clave de las APIs del navegador.
En la función enableCam()
, usa la función hasGetUserMedia()
que acabas de definir antes para verificar si es compatible. De lo contrario, imprime una advertencia en la consola.
Si lo admite, define algunas restricciones para tu llamada a getUserMedia()
, por ejemplo, solo quieres transmitir video por Internet y elegir que width
del video tenga un tamaño de 640
píxeles, y height
tenga 480
píxeles. ¿Por qué? Bueno, no tiene mucho sentido hacer que un video se vea más grande, ya que se debería cambiar el tamaño a 224 por 224 píxeles para introducirlo en el modelo MobileNet. También puedes ahorrar algunos recursos de procesamiento si solicitas una resolución más pequeña. La mayoría de las cámaras admiten una resolución de este tamaño.
A continuación, llama a navigator.mediaDevices.getUserMedia()
con el constraints
detallado arriba y espera a que se muestre stream
. Una vez que se muestre el stream
, podrás hacer que tu elemento VIDEO
reproduzca el stream
configurándolo como su valor srcObject
.
También debes agregar un eventListener al elemento VIDEO
para saber cuándo se cargó el stream
y cuándo se reproduce correctamente.
Una vez que se cargue el flujo, puedes establecer videoPlaying
como verdadero y quitar el ENABLE_CAM_BUTTON
para evitar que se vuelva a hacer clic en él. Para ello, establece su clase en "removed
".
Ahora ejecuta tu código, haz clic en el botón para habilitar la cámara y permite el acceso a la cámara web. Si es la primera vez que haces esto, deberías verte renderizado en el elemento de video en la página, como se muestra a continuación:
Muy bien, ahora es el momento de agregar una función para controlar los clics en el botón dataCollector
.
12. Controlador de eventos del botón de recopilación de datos
Ahora es el momento de completar la función vacía actual llamada gatherDataForClass().
. Esto es lo que asignaste como función de controlador de eventos para los botones dataCollector
al comienzo del codelab.
script.js
/**
* Handle Data Gather for button mouseup/mousedown.
**/
function gatherDataForClass() {
let classNumber = parseInt(this.getAttribute('data-1hot'));
gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
dataGatherLoop();
}
Primero, verifica el atributo data-1hot
en el botón en el que se hace clic. Para ello, llama a this.getAttribute()
con el nombre del atributo, en este caso, data-1hot
como el parámetro. Como es una cadena, puedes usar parseInt()
para convertirla en un número entero y asignar este resultado a una variable llamada classNumber.
A continuación, configura la variable gatherDataState
según corresponda. Si el gatherDataState
actual es igual a STOP_DATA_GATHER
(que configuraste como -1), significa que no estás recopilando datos en este momento y que se activó un evento mousedown
. Configura gatherDataState
para que sea el classNumber
que acabas de encontrar.
De lo contrario, significa que estás recopilando datos y el evento que se activó fue un evento mouseup
, y ahora deseas dejar de recopilar datos para esa clase. Solo vuelve a establecerlo en el estado STOP_DATA_GATHER
para finalizar el bucle de recopilación de datos que definirás en breve.
Por último, inicia la llamada a dataGatherLoop(),
, que realiza la grabación de los datos de la clase.
13. Recopilación de datos
Ahora, define la función dataGatherLoop()
. Esta función se encarga de muestrear imágenes del video de la cámara web, pasarlas por el modelo MobileNet y capturar las salidas de ese modelo (los vectores de atributos 1024).
Luego, los almacena junto con el ID de gatherDataState
del botón que se está presionando en ese momento para que sepas qué clase representan estos datos.
Veamos cómo hacerlo paso a paso.
script.js
function dataGatherLoop() {
if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
let imageFeatures = tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let normalizedTensorFrame = resizedTensorFrame.div(255);
return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
});
trainingDataInputs.push(imageFeatures);
trainingDataOutputs.push(gatherDataState);
// Intialize array index element if currently undefined.
if (examplesCount[gatherDataState] === undefined) {
examplesCount[gatherDataState] = 0;
}
examplesCount[gatherDataState]++;
STATUS.innerText = '';
for (let n = 0; n < CLASS_NAMES.length; n++) {
STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
}
window.requestAnimationFrame(dataGatherLoop);
}
}
Solo continuarás la ejecución de esta función si videoPlaying
es verdadero, lo que significa que la cámara web está activa, gatherDataState
no es igual a STOP_DATA_GATHER
y se está presionando un botón para recopilar datos de la clase.
A continuación, une tu código en un tf.tidy()
para eliminar cualquier tensor creado en el código que sigue. El resultado de esta ejecución de código tf.tidy()
se almacena en una variable llamada imageFeatures
.
Ahora puedes tomar un marco de la cámara web VIDEO
utilizando tf.browser.fromPixels()
. El tensor resultante que contiene los datos de la imagen se almacena en una variable llamada videoFrameAsTensor
.
A continuación, cambia el tamaño de la variable videoFrameAsTensor
de modo que tenga la forma correcta para la entrada del modelo MobileNet. Usa una llamada tf.image.resizeBilinear()
con el tensor al que quieras cambiar la forma como primer parámetro y, luego, una forma que defina la altura y el ancho nuevos, como lo definen las constantes que ya creaste antes. Por último, pasa el tercer parámetro para establecer la alineación de las esquinas como verdadera y, de esa forma, evitar problemas de alineación cuando se cambie el tamaño. El resultado de este cambio de tamaño se almacena en una variable llamada resizedTensorFrame
.
Ten en cuenta que este primitivo cambio de tamaño estira la imagen, ya que la imagen de tu cámara web tiene un tamaño de 640 por 480 píxeles y el modelo necesita una imagen cuadrada de 224 por 224 píxeles.
A los fines de esta demostración, esto debería funcionar bien. Sin embargo, una vez que completes este codelab, es posible que quieras recortar un cuadrado de esta imagen para obtener mejores resultados en cualquier sistema de producción que crees más adelante.
A continuación, normaliza los datos de la imagen. Los datos de imagen siempre están en el rango de 0 a 255 cuando se usa tf.browser.frompixels()
, por lo que puedes dividir simplemente renamedTensorFrame por 255 para asegurarte de que todos los valores estén entre 0 y 1, que es lo que el modelo MobileNet espera como entradas.
Por último, en la sección tf.tidy()
del código, envía este tensor normalizado a través del modelo cargado llamando a mobilenet.predict()
, al que pasas la versión expandida de normalizedTensorFrame
con expandDims()
de modo que sea un lote de 1, ya que el modelo espera un lote de entradas para el procesamiento.
Una vez que se muestre el resultado, podrás llamar de inmediato a squeeze()
en el resultado que se muestra para reducirlo a un tensor 1D, que luego mostrarás y asignarás a la variable imageFeatures
que captura el resultado de tf.tidy()
.
Ahora que tienes los imageFeatures
del modelo MobileNet, puedes enviarlos enviándolos al array trainingDataInputs
que definiste antes para registrarlos.
También puedes enviar el gatherDataState
actual al array trainingDataOutputs
para registrar lo que representa esta entrada.
Ten en cuenta que la variable gatherDataState
se habría establecido en el ID numérico de la clase actual del que estás registrando datos cuando se hizo clic en el botón en la función gatherDataForClass()
definida con anterioridad.
En este punto, también puedes aumentar la cantidad de ejemplos que tienes para una clase determinada. Para ello, primero verifica si el índice dentro del array examplesCount
se inicializó antes o no. Si no está definido, configúralo en 0 para inicializar el contador del ID numérico de una clase determinada y, luego, puedes incrementar el examplesCount
para el gatherDataState
actual.
Ahora, actualiza el texto del elemento STATUS
en la página web para mostrar los recuentos actuales de cada clase a medida que se capturan. Para ello, realiza un bucle a lo largo del array CLASS_NAMES
y, luego, imprime el nombre legible por humanos combinado con el recuento de datos en el mismo índice en examplesCount
.
Por último, llama a window.requestAnimationFrame()
con dataGatherLoop
pasado como parámetro para volver a llamar de forma recursiva a esta función. Se seguirán muestreando fotogramas del video hasta que se detecte el mouseup
del botón y gatherDataState
se establecerá en STOP_DATA_GATHER,
, momento en el que finalizará el bucle de recopilación de datos.
Si ejecutas tu código ahora, deberías poder hacer clic en el botón para habilitar la cámara, esperar que la cámara web se cargue y, luego, hacer clic y mantener presionado cada uno de los botones de recopilación de datos para recopilar ejemplos para cada clase de datos. Aquí se ve cómo recopilo datos del teléfono celular y de la mano, respectivamente.
Deberías ver el texto de estado actualizado a medida que almacena todos los tensores en la memoria, como se muestra en la captura de pantalla anterior.
14. Entrenar y predecir
El siguiente paso es implementar código para la función trainAndPredict()
que está vacía actualmente, que es donde se lleva a cabo el aprendizaje por transferencia. Echemos un vistazo al código:
script.js
async function trainAndPredict() {
predict = false;
tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
let inputsAsTensor = tf.stack(trainingDataInputs);
let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10,
callbacks: {onEpochEnd: logProgress} });
outputsAsTensor.dispose();
oneHotOutputs.dispose();
inputsAsTensor.dispose();
predict = true;
predictLoop();
}
function logProgress(epoch, logs) {
console.log('Data for epoch ' + epoch, logs);
}
Primero, asegúrate de detener las predicciones actuales. Para ello, establece predict
en false
.
A continuación, reorganiza tus arrays de entrada y salida con tf.util.shuffleCombo()
para asegurarte de que el orden no cause problemas en el entrenamiento.
Convierte tu array de salida, trainingDataOutputs,
, en un tensor1d de tipo int32 para que esté listo para usarse en una codificación one-hot. Esto se almacena en una variable llamada outputsAsTensor
.
Usa la función tf.oneHot()
con esta variable outputsAsTensor
junto con la cantidad máxima de clases para codificar, que es solo el CLASS_NAMES.length
. Las salidas con codificación directa ahora se almacenan en un nuevo tensor llamado oneHotOutputs
.
Ten en cuenta que, actualmente, trainingDataInputs
es un array de tensores grabados. Si desea usarlos para el entrenamiento, deberá convertir el array de tensores con el fin de que se convierta en un tensor 2D normal.
Para hacer eso, hay una gran función dentro de la biblioteca de TensorFlow.js llamada tf.stack()
,
que toma un array de tensores y los apila para producir un tensor de mayor dimensión como salida. En este caso, se muestra un tensor 2D, que es un lote de entradas 1 dimensional con una longitud de 1,024 m y que contiene los atributos registrados, que es lo que necesitas para el entrenamiento.
A continuación, utiliza await model.fit()
para entrenar el encabezado del modelo personalizado. Aquí, pasas la variable inputsAsTensor
junto con el oneHotOutputs
para representar los datos de entrenamiento que se usarán en las entradas de ejemplo y las salidas objetivo, respectivamente. En el objeto de configuración del tercer parámetro, establece shuffle
en true
, usa batchSize
de 5
, con epochs
establecido en 10
y, luego, especifica un callback
para onEpochEnd
en la función logProgress
que definirás en breve.
Por último, puedes deshacerte de los tensores creados a medida que el modelo está entrenado. Luego, puedes volver a establecer predict
en true
para permitir que se vuelvan a realizar las predicciones y, luego, llamar a la función predictLoop()
para comenzar a predecir imágenes de cámaras web en vivo.
También puedes definir la función logProcess()
para registrar el estado del entrenamiento, que se usa en el ejemplo anterior de model.fit()
y que imprime los resultados en la consola después de cada ronda de entrenamiento.
Falta muy poco. Es hora de agregar la función predictLoop()
para hacer predicciones.
Bucle de predicción principal
Aquí implementas el bucle de predicción principal que toma muestras de fotogramas de una cámara web y predice continuamente el contenido de cada fotograma con resultados en tiempo real en el navegador.
Comprobemos el código:
script.js
function predictLoop() {
if (predict) {
tf.tidy(function() {
let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT,
MOBILE_NET_INPUT_WIDTH], true);
let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
let prediction = model.predict(imageFeatures).squeeze();
let highestIndex = prediction.argMax().arraySync();
let predictionArray = prediction.arraySync();
STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
});
window.requestAnimationFrame(predictLoop);
}
}
Primero, verifica que predict
sea verdadero, de modo que las predicciones solo se realicen después de que un modelo esté entrenado y esté disponible para su uso.
A continuación, puedes obtener las funciones de imagen para la imagen actual como lo hiciste en la función dataGatherLoop()
. Básicamente, tomas un fotograma de la cámara web con tf.browser.from pixels()
, lo normalizas, lo cambias de tamaño para que tenga un tamaño de 224 por 224 píxeles y, luego, pasas esos datos por el modelo de MobileNet para obtener las características de imagen resultantes.
Sin embargo, ahora puedes usar tu encabezado de modelo recién entrenado para realizar una predicción. Para ello, pasa el imageFeatures
resultante que se acaba de encontrar a través de la función predict()
del modelo entrenado. Luego, puedes apretar el tensor resultante para volver a 1 dimensión y asignarlo a una variable llamada prediction
.
Con este prediction
, puedes encontrar el índice que tiene el valor más alto usando argMax()
y, luego, convertir este tensor resultante en un array con arraySync()
para obtener los datos subyacentes en JavaScript y descubrir la posición del elemento de mayor valor. Este valor se almacena en la variable llamada highestIndex
.
También puedes obtener las puntuaciones reales de confianza de la predicción de la misma manera si llamas directamente a arraySync()
en el tensor prediction
.
Ahora tienes todo lo que necesitas para actualizar el texto STATUS
con los datos de prediction
. Con el fin de obtener una cadena legible para la clase, puedes buscar highestIndex
en el array CLASS_NAMES
y, luego, tomar el valor de confianza de predictionArray
. Para que sea más legible como un porcentaje, solo multiplica el resultado por 100 y math.floor()
.
Por último, puedes usar window.requestAnimationFrame()
para volver a llamar a predictionLoop()
una vez que esté todo listo, de modo que obtengas una clasificación en tiempo real en tu transmisión de video por Internet. Esto continúa hasta que predict
se establece en false
si eliges entrenar un modelo nuevo con datos nuevos.
Lo que te lleva a la pieza final del rompecabezas. Implementa el botón de restablecimiento.
15. Implementa el botón de restablecimiento
Falta poco. La última pieza del rompecabezas consiste en implementar un botón de restablecimiento para empezar de nuevo. A continuación, se muestra el código de la función reset()
vacía. Actualízala de la siguiente manera:
script.js
/**
* Purge data and start over. Note this does not dispose of the loaded
* MobileNet model and MLP head tensors as you will need to reuse
* them to train a new model.
**/
function reset() {
predict = false;
examplesCount.length = 0;
for (let i = 0; i < trainingDataInputs.length; i++) {
trainingDataInputs[i].dispose();
}
trainingDataInputs.length = 0;
trainingDataOutputs.length = 0;
STATUS.innerText = 'No data collected';
console.log('Tensors in memory: ' + tf.memory().numTensors);
}
Primero, detén cualquier bucle de predicción en ejecución configurando predict
como false
. A continuación, borra todo el contenido del array examplesCount
estableciendo su longitud en 0, que es una forma práctica de borrar todo el contenido de un array.
Ahora revisa todos los trainingDataInputs
grabados actuales y asegúrate de que dispose()
de cada tensor contenido en él para liberar memoria nuevamente, ya que el recolector de elementos no utilizados de JavaScript no limpia los tensores.
Una vez que hayas terminado, puedes establecer la longitud del array en 0 de forma segura en los arrays trainingDataInputs
y trainingDataOutputs
para borrarlos también.
Por último, configura el texto STATUS
como algo sensible y, luego, imprime los tensores que queden en la memoria como una verificación de estado.
Ten en cuenta que quedarán algunos cientos de tensores aún en la memoria, ya que no se eliminan el modelo MobileNet ni el perceptrón multicapa que definiste. Deberás volver a usarlos con datos de entrenamiento nuevos si decides volver a entrenar después de este restablecimiento.
16. Probémoslo
Es hora de probar tu propia versión de Teachable Machine.
Dirígete a la vista previa en vivo, habilita la cámara web, reúne al menos 30 muestras de algún objeto de tu aula para la clase 1 y, luego, haz lo mismo con otro objeto de la clase 2, haz clic en Entrenar y consulta el registro de la consola para ver el progreso. Debería entrenarse bastante rápido:
Una vez entrenado, muestra los objetos a la cámara para obtener predicciones en vivo que se imprimirán en el área de texto de estado de la página web, cerca de la parte superior. Si tienes problemas, revisa el código que funcionaba completo para ver si te faltó copiar algo.
17. Felicitaciones
¡Felicitaciones! Acabas de completar tu primer ejemplo de aprendizaje por transferencia con TensorFlow.js en vivo en el navegador.
Realiza la prueba en una variedad de objetos. Es posible que notes que algunas cosas son más difíciles de reconocer que otras, en especial si son similares a otras. Es posible que debas agregar más clases o datos de entrenamiento para poder diferenciarlos.
Resumen
En este codelab, aprendiste lo siguiente:
- Qué es el aprendizaje por transferencia y sus ventajas en comparación con entrenar un modelo completo
- Cómo obtener modelos para reutilizar desde TensorFlow Hub.
- Cómo configurar una app web adecuada para el aprendizaje por transferencia
- Cómo cargar y usar un modelo base para generar atributos de imagen
- Cómo entrenar un nuevo cabezal de predicción que pueda reconocer objetos personalizados a partir de imágenes de cámaras web
- Cómo usar los modelos resultantes para clasificar datos en tiempo real
¿Qué sigue?
Ahora que tienes una base de trabajo desde la cual comenzar, ¿qué ideas creativas se te ocurren para extender este modelo de aprendizaje automático estándar para un caso de uso del mundo real en el que estás trabajando? ¿Quizá podrías revolucionar la industria en la que trabajas actualmente para ayudar a las personas de tu empresa a entrenar modelos para clasificar elementos que son importantes en su trabajo diario? Las posibilidades son infinitas.
Para avanzar, considera realizar este curso completo de forma gratuita, en el que se muestra cómo combinar los 2 modelos que tienes actualmente en este codelab en 1 solo para aumentar la eficiencia.
Además, si quieres más información sobre la teoría detrás de la aplicación original de Teachable Machine, consulta este instructivo.
Comparte tus creaciones con nosotros
También puedes extender fácilmente lo que creaste hoy para otros casos de uso creativos, y te recomendamos que pienses de forma innovadora y sigas hackeando.
Recuerda etiquetarnos en redes sociales con el hashtag #MadeWithTFJS para tener la oportunidad de que tu proyecto se destaque en nuestro blog de TensorFlow o incluso en eventos futuros. Nos encantaría ver tus productos.
Sitios web que puedes revisar
- Sitio web oficial de TensorFlow.js
- Modelos prediseñados de TensorFlow.js
- API de TensorFlow.js
- TensorFlow.js Show & Cuéntales: Inspírate y descubre las creaciones de otros usuarios.