Compila una app para Android de clasificación de dígitos escritos a mano con MediaPipe Tasks

1. Introducción

¿Qué es MediaPipe?

Las MediaPipe Solutions te permiten aplicar soluciones de aprendizaje automático (AA) en tus apps. Proporciona un framework para configurar canalizaciones de procesamiento compiladas previamente que generan resultados inmediatos, atractivos y útiles para los usuarios. Incluso puedes personalizar estas soluciones con MediaPipe Model Maker para actualizar los modelos predeterminados.

La clasificación de imágenes es una de las muchas tareas de visión de AA que ofrece MediaPipe Solutions. MediaPipe Tasks está disponible para Android, iOS, Python (incluida la Raspberry Pi!) y la Web.

En este codelab, comenzarás con una app para Android que te permite dibujar dígitos numéricos en la pantalla. Luego, agregarás funcionalidad que clasifique esos dígitos dibujados como un valor único del 0 al 9.

Qué aprenderás

  • Cómo incorporar una tarea de clasificación de imágenes en una app para Android con MediaPipe Tasks.

Requisitos

  • Una versión instalada de Android Studio (este codelab se escribió y probó con Android Studio Giraffe).
  • Un emulador o dispositivo Android para ejecutar la app
  • Conocimientos básicos sobre el desarrollo de Android (esto no es "Hello World", pero no es demasiado lejano).

2. Cómo agregar tareas de MediaPipe a la app para Android

Descarga la app de partida para Android

Este codelab comenzará con una muestra prediseñada que te permite dibujar en la pantalla. Puedes encontrar esa app de inicio en el repositorio oficial de MediaPipe Samples aquí. Clona el repo o descarga el archivo ZIP haciendo clic en Código > Descargar ZIP.

Importa la app a Android Studio

  1. Abre Android Studio.
  2. En la pantalla Welcome to Android Studio, selecciona Open en la esquina superior derecha.

a0b5b070b802e4ea.png

  1. Navega a la ubicación donde clonaste o descargaste el repositorio y abre el codelabs/digitclassifier/android/start directorio.
  2. Verifica que todo se haya abierto correctamente haciendo clic en la flecha verde run ( 7e15a9c9e1620fe7.png) en la parte superior derecha de Android Studio.
  3. Deberías ver que la app se abre con una pantalla negra en la que puedes dibujar, así como el botón Borrar para restablecer esa pantalla. Si bien puedes dibujar en esa pantalla, no hace mucho más, por lo que comenzaremos a corregir eso ahora.

11a0f6fe021fdc92.jpeg

Modelo

Cuando ejecutes la app por primera vez, es posible que notes que se descarga y almacena un archivo llamado mnist.tflite en el directorio assets de la app. Para simplificar, ya tomamos un modelo conocido, MNIST, que clasifica dígitos, y lo agregamos a la app mediante la secuencia de comandos download_models.gradle en el proyecto. Si decides entrenar tu propio modelo personalizado, como uno para letras escritas a mano, debes quitar el archivo download_models.gradle, borrar la referencia en el archivo build.gradle de nivel de tu aplicación y cambiar el nombre del modelo más adelante en el código (específicamente en el archivo DigitClassifierHelper.kt).

Cómo actualizar build.gradle

Antes de comenzar a usar MediaPipe Tasks, debes importar la biblioteca.

  1. Abre el archivo build.gradle ubicado en el módulo app y, luego, desplázate hacia abajo hasta el bloque dependencias.
  2. Deberías ver un comentario en la parte inferior de ese bloque que dice // STEP 1 Dependency Import.
  3. Reemplaza esa línea con la siguiente implementación.
implementation("com.google.mediapipe:tasks-vision:latest.release")
  1. Haz clic en el botón Sync Now que aparece en el banner en la parte superior de Android Studio para descargar esta dependencia.

3. Crea un clasificador de dígitos de MediaPipe Tasks

Para el siguiente paso, completarás una clase que se encargará del trabajo pesado para tu clasificación de aprendizaje automático. Abre DigitClassifierHelper.kt para comenzar.

  1. Busca el comentario en la parte superior de la clase que dice // PASO 2 Crear objeto de escucha.
  2. Reemplaza esa línea con el siguiente código. Esto creará un objeto de escucha que se usará para pasar los resultados de la clase DigitClassifierHelper a cualquier lugar en el que se detecten esos resultados (en este caso, será tu clase DigitCanvasFragment, pero llegaremos allí pronto).
// STEP 2 Create listener

interface DigitClassifierListener {
    fun onError(error: String)
    fun onResults(
        results: ImageClassifierResult,
        inferenceTime: Long
    )
}
  1. También deberás aceptar un DigitClassifierListener como parámetro opcional para la clase:
class DigitClassifierHelper(
    val context: Context,
    val digitClassifierListener: DigitClassifierListener?
) {
  1. Bajando a la línea que dice // PASO 3 definir clasificador, agrega la siguiente línea para crear un marcador de posición para el ImageClassifier que se usará en esta app:

// PASO 3 define el clasificador

private var digitClassifier: ImageClassifier? = null
  1. Agrega la siguiente función en la que ves el comentario // PASO 4 configura el clasificador:
// STEP 4 set up classifier
private fun setupDigitClassifier() {

    val baseOptionsBuilder = BaseOptions.builder()
        .setModelAssetPath("mnist.tflite")

    // Describe additional options
    val optionsBuilder = ImageClassifierOptions.builder()
        .setRunningMode(RunningMode.IMAGE)
        .setBaseOptions(baseOptionsBuilder.build())

    try {
        digitClassifier =
            ImageClassifier.createFromOptions(
                context,
                optionsBuilder.build()
            )
    } catch (e: IllegalStateException) {
        digitClassifierListener?.onError(
            "Image classifier failed to initialize. See error logs for " +
                    "details"
        )
        Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
    }
}

Como hay algunas cosas que están sucediendo en la sección anterior, veamos partes más pequeñas para entender realmente lo que sucede.

val baseOptionsBuilder = BaseOptions.builder()
    .setModelAssetPath("mnist.tflite")

// Describe additional options
val optionsBuilder = ImageClassifierOptions.builder()
    .setRunningMode(RunningMode.IMAGE)
    .setBaseOptions(baseOptionsBuilder.build())

Este bloque definirá los parámetros que usa ImageClassifier. Esto incluye el modelo almacenado dentro de tu app (mnist.tflite) en BaseOptions y el RunningMode en ImageClassifierOptions, que en este caso es IMAGE, pero VIDEO y LIVE_STREAM son opciones adicionales disponibles. Otros parámetros disponibles son MaxResults, que limita al modelo para que muestre una cantidad máxima de resultados, y ScoreThreshold, que establece la confianza mínima que debe tener el modelo en un resultado antes de mostrarlo.

try {
    digitClassifier =
        ImageClassifier.createFromOptions(
            context,
            optionsBuilder.build()
        )
} catch (e: IllegalStateException) {
    digitClassifierListener?.onError(
        "Image classifier failed to initialize. See error logs for " +
                "details"
    )
    Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
}

Después de crear tus opciones de configuración, puedes crear tu nuevo ImageClassifier pasando un contexto y las opciones. Si algo sale mal con el proceso de inicialización, se mostrará un error a través de tu DigitClassifierListener.

  1. Como desearemos inicializar ImageClassifier antes de que se use, puedes agregar un bloque init para llamar a setupDigitClassifier().
init {
    setupDigitClassifier()
}
  1. Por último, desplázate hacia abajo hasta el comentario que dice // PASO 5 crea una función de clasificación y agrega el siguiente código. Esta función aceptará un Bitmap, que en este caso es el dígito dibujado, lo convertirá en un objeto MediaPipe Image (MPImage) y, luego, clasificará esa imagen con ImageClassifier, además de registrar el tiempo que tarda la inferencia, antes de mostrar esos resultados en DigitClassifierListener.
// STEP 5 create classify function
fun classify(image: Bitmap) {
    if (digitClassifier == null) {
        setupDigitClassifier()
    }

    // Convert the input Bitmap object to an MPImage object to run inference.
    // Rotating shouldn't be necessary because the text is being extracted from
    // a view that should always be correctly positioned.
    val mpImage = BitmapImageBuilder(image).build()

    // Inference time is the difference between the system time at the start and finish of the
    // process
    val startTime = SystemClock.uptimeMillis()

    // Run image classification using MediaPipe Image Classifier API
    digitClassifier?.classify(mpImage)?.also { classificationResults ->
        val inferenceTimeMs = SystemClock.uptimeMillis() - startTime
        digitClassifierListener?.onResults(classificationResults, inferenceTimeMs)
    }
}

Eso es todo con respecto al archivo auxiliar. En la siguiente sección, completarás los últimos pasos para comenzar a clasificar los números dibujados.

4. Ejecuta la inferencia con las tareas de MediaPipe

Para comenzar esta sección, abre la clase DigitCanvasFragment en Android Studio, que es donde se llevará a cabo todo el trabajo.

  1. En la parte inferior de este archivo, deberías ver un comentario que dice // PASO 6 Configura el objeto de escucha. Agregarás las funciones onResults() y onError() asociadas con el objeto de escucha aquí.
// STEP 6 Set up listener
override fun onError(error: String) {
    activity?.runOnUiThread {
        Toast.makeText(requireActivity(), error, Toast.LENGTH_SHORT).show()
        fragmentDigitCanvasBinding.tvResults.text = ""
    }
}

override fun onResults(
    results: ImageClassifierResult,
    inferenceTime: Long
) {
    activity?.runOnUiThread {
        fragmentDigitCanvasBinding.tvResults.text = results
            .classificationResult()
            .classifications().get(0)
            .categories().get(0)
            .categoryName()

        fragmentDigitCanvasBinding.tvInferenceTime.text = requireActivity()
            .getString(R.string.inference_time, inferenceTime.toString())
    }
}

onResults() es particularmente importante, ya que mostrará los resultados recibidos de ImageClassifier. Como esta devolución de llamada se activa desde un subproceso en segundo plano, también deberás ejecutar las actualizaciones de la IU en el subproceso de IU de Android.

  1. A medida que agregas nuevas funciones desde una interfaz en el paso anterior, también deberás agregar la declaración de implementación en la parte superior de la clase.
class DigitCanvasFragment : Fragment(), DigitClassifierHelper.DigitClassifierListener
  1. En la parte superior de la clase, deberías ver un comentario que dice // PASO 7a Inicializar clasificador. Aquí es donde colocarás la declaración del DigitClassifierHelper.
// STEP 7a Initialize classifier.
private lateinit var digitClassifierHelper: DigitClassifierHelper
  1. Desplázate hacia abajo a // PASO 7b Inicializar clasificador. Puedes inicializar digitClassifierHelper dentro de la función onViewCreated().
// STEP 7b Initialize classifier
// Initialize the digit classifier helper, which does all of the
// ML work. This uses the default values for the classifier.
digitClassifierHelper = DigitClassifierHelper(
    context = requireContext(), digitClassifierListener = this
)
  1. Para los últimos pasos, busca el comentario // PASO 8a*: clasificar* y agrega el siguiente código para llamar a una función nueva que agregarás en un momento. Este bloque de código activará la clasificación cuando levantes el dedo del área de trazado de la app.
// STEP 8a: classify
classifyDrawing()
  1. Por último, busca el comentario // STEP 8b classify para agregar la nueva función classifyDrawing(). Esto extraerá un mapa de bits del lienzo y, luego, lo pasará al DigitClassifierHelper para realizar una clasificación y recibir los resultados en la función de interfaz onResults().
// STEP 8b classify
private fun classifyDrawing() {
    val bitmap = fragmentDigitCanvasBinding.digitCanvas.getBitmap()
    digitClassifierHelper.classify(bitmap)
}

5. Implementa y prueba la app

Después de todo esto, deberías tener una app que funcione y pueda clasificar los dígitos dibujados en la pantalla. Implementa la app en un Android Emulator o en un dispositivo Android físico para probarla.

  1. Haz clic en Run ( 7e15a9c9e1620fe7.png) en la barra de herramientas de Android Studio para ejecutar la app.
  2. Dibuja un dígito en el panel de dibujo y comprueba si la app puede reconocerlo. Debe mostrar el dígito que el modelo cree que se dibujó, así como el tiempo que tardó en predecirlo.

7f37187f8f919638.gif

6. ¡Felicitaciones!

¡Lo lograste! En este codelab, aprendiste a agregar la clasificación de imágenes a una app para Android y, en particular, a clasificar dígitos dibujados a mano con el modelo MNIST.

Próximos pasos

  • Ahora que puedes clasificar dígitos, te recomendamos que entrenes tu propio modelo para clasificar letras dibujadas, animales o una cantidad infinita de otros elementos. Puedes encontrar la documentación para entrenar un modelo de clasificación de imágenes nuevo con MediaPipe Model Maker en la página developers.google.com/mediapipe.
  • Obtén información sobre otras tareas de MediaPipe disponibles para Android, como la detección de puntos de referencia facial, el reconocimiento de gestos y la clasificación de audio.

¡Esperamos con ansias todas las cosas geniales que hagas!