使用 MediaPipe Tasks 构建手写数字分类器 Android 应用

1. 简介

什么是 MediaPipe?

借助 MediaPipe 解决方案,您可以将机器学习 (ML) 解决方案应用于自己的应用。它提供了一个框架,用于配置预构建的处理流水线,以便为用户提供即时、富有吸引力的实用输出。您甚至可以使用 MediaPipe Model Maker 自定义这些解决方案,以更新默认模型。

图像分类是 MediaPipe 解决方案提供的若干机器学习视觉任务之一。MediaPipe Tasks 适用于 Android、iOS、Python(包括 Raspberry Pi!)和 Web。

在此 Codelab 中,您将首先构建一个 Android 应用,该应用可让您在屏幕上画出数字,然后添加相应功能,将这些绘制的数字归类为单个值(范围为 0 到 9)。

学习内容

  • 如何使用 MediaPipe Tasks 在 Android 应用中整合图片分类任务。

所需条件

  • 已安装版本的 Android Studio(此 Codelab 是使用 Android Studio Giraffe 编写和测试的)。
  • 用于运行应用的 Android 设备或模拟器。
  • 具备 Android 开发方面的基础知识(这不是“Hello World”,但也很浅显易懂!)。

2. 将 MediaPipe Tasks 添加到 Android 应用

下载 Android 起始应用

此 Codelab 将从一个预制示例开始,允许您在屏幕上绘图。您可以点击此处,在官方 MediaPipe 示例代码库中找到启动应用。克隆代码库或依次点击“代码”>“下载”以下载 ZIP 文件下载 ZIP 文件。

将应用导入 Android Studio

  1. 打开 Android Studio。
  2. Welcome to Android Studio 屏幕中,选择右上角的 Open

a0b5b070b802e4ea.png

  1. 转到您克隆或下载代码库的位置,然后打开 codelabs/digitcategories/android/start 目录
  2. 点击 Android Studio 右上角的绿色 run 箭头 ( 7e15a9c9e1620fe7),验证所有内容是否都能正确打开
  3. 您应该会看到该应用打开,其中显示可在其中进行绘图的黑屏,以及用于重置该屏幕的清除按钮。虽然您可以在该屏幕上绘图,但它没有什么其他作用,所以我们现在就开始解决这个问题。

11a0f6fe021fdc92.jpeg

型号

首次运行应用时,您可能会注意到名为 mnist.tflite 的文件已下载并存储在您应用的 assets 目录中。为简单起见,我们已经采用一个已知模型 MNIST,该模型会对数字进行分类,并通过在项目中使用 download_models.gradle 脚本将其添加到应用中。如果您决定训练自己的自定义模型(例如用于手写字母的模型),则应移除 download_models.gradle 文件,删除应用级 build.gradle 文件中对它的引用,然后在代码中(特别是 DigitClassifierHelper.kt 文件中)更改模型名称。

更新 build.gradle

您需要先导入库,然后才能开始使用 MediaPipe Tasks。

  1. 打开位于 app 模块中的 build.gradle 文件,然后向下滚动到 dependencies 代码块。
  2. 您应该会在该代码块的底部看到一条内容为 // STEP 1 Dependency Import 的注释。
  3. 将该行替换为以下实现
implementation("com.google.mediapipe:tasks-vision:latest.release")
  1. 点击 Android Studio 顶部横幅中显示的 Sync Now 按钮,即可下载此依赖项。

3. 创建 MediaPipe Tasks 数字分类器帮助程序

在下一步中,您将填写一个类,该类会为机器学习分类完成繁重的工作。打开 DigitClassifierHelper.kt 我们开始吧!

  1. 在类顶部找到内容为 // STEP 2 Create listener 的注释
  2. 将该行替换为以下代码。这将创建一个监听器,用于从 DigitClassifierHelper 类将结果传递回监听这些结果的任何位置(在本例中,它是您的 DigitCanvasFragment 类,但我们很快会介绍它)
// STEP 2 Create listener

interface DigitClassifierListener {
    fun onError(error: String)
    fun onResults(
        results: ImageClassifierResult,
        inferenceTime: Long
    )
}
  1. 您还需要接受 DigitClassifierListener 作为类的可选参数:
class DigitClassifierHelper(
    val context: Context,
    val digitClassifierListener: DigitClassifierListener?
) {
  1. 找到 // STEP 3define Classification 这一行,添加以下行,为将用于此应用的 ImageClassifier 创建占位符:

// 第 3 步定义分类器

private var digitClassifier: ImageClassifier? = null
  1. 在显示 // STEP 4 set up Classification 注释的位置添加以下函数:
// STEP 4 set up classifier
private fun setupDigitClassifier() {

    val baseOptionsBuilder = BaseOptions.builder()
        .setModelAssetPath("mnist.tflite")

    // Describe additional options
    val optionsBuilder = ImageClassifierOptions.builder()
        .setRunningMode(RunningMode.IMAGE)
        .setBaseOptions(baseOptionsBuilder.build())

    try {
        digitClassifier =
            ImageClassifier.createFromOptions(
                context,
                optionsBuilder.build()
            )
    } catch (e: IllegalStateException) {
        digitClassifierListener?.onError(
            "Image classifier failed to initialize. See error logs for " +
                    "details"
        )
        Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
    }
}

上一部分发生了一些事情,让我们看看小部分内容,真正了解发生了什么。

val baseOptionsBuilder = BaseOptions.builder()
    .setModelAssetPath("mnist.tflite")

// Describe additional options
val optionsBuilder = ImageClassifierOptions.builder()
    .setRunningMode(RunningMode.IMAGE)
    .setBaseOptions(baseOptionsBuilder.build())

此代码块将定义 ImageClassifier 使用的参数。这包括存储在应用 (mnist.tflite) 中的 BaseOptions 和 RunningMode 下的 ImageClassifierOptions(在本例中为 IMAGE),但 VIDEO 和 LIVE_STREAM 是其他可用选项。其他可用参数包括 MaxResults 和 ScoreThreshold,前者限制模型返回最大数量的结果,后者用于设置模型在返回结果之前所需的最低置信度。

try {
    digitClassifier =
        ImageClassifier.createFromOptions(
            context,
            optionsBuilder.build()
        )
} catch (e: IllegalStateException) {
    digitClassifierListener?.onError(
        "Image classifier failed to initialize. See error logs for " +
                "details"
    )
    Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
}

创建配置选项后,您可以通过传入上下文和选项来创建新的 ImageClassifier。如果该初始化过程出现问题,系统会通过您的 DigitClassifierListener 返回一个错误。

  1. 由于我们需要先初始化 ImageClassifier,然后才能使用它,因此可以添加一个 init 块来调用 setupDigitClassifier()。
init {
    setupDigitClassifier()
}
  1. 最后,向下滚动到内容为 // STEP 5 create classify function 的注释,并添加以下代码。此函数将接受一个位图(在本例中是绘制的数字),将其转换为 MediaPipe Image 对象 (MPImage),然后使用 ImageClassifier 对该图像进行分类,并记录推理所用的时间,最后通过 DigitClassifierListener 返回这些结果。
// STEP 5 create classify function
fun classify(image: Bitmap) {
    if (digitClassifier == null) {
        setupDigitClassifier()
    }

    // Convert the input Bitmap object to an MPImage object to run inference.
    // Rotating shouldn't be necessary because the text is being extracted from
    // a view that should always be correctly positioned.
    val mpImage = BitmapImageBuilder(image).build()

    // Inference time is the difference between the system time at the start and finish of the
    // process
    val startTime = SystemClock.uptimeMillis()

    // Run image classification using MediaPipe Image Classifier API
    digitClassifier?.classify(mpImage)?.also { classificationResults ->
        val inferenceTimeMs = SystemClock.uptimeMillis() - startTime
        digitClassifierListener?.onResults(classificationResults, inferenceTimeMs)
    }
}

帮助程序文件到此结束!在下一部分中,您将填写最后几步,开始对绘制的数字进行分类。

4. 使用 MediaPipe Tasks 运行推断

您可以通过在 Android Studio 中打开 DigitCanvasFragment 类来开始此部分,所有工作都将在此完成。

  1. 在此文件的最底部,您应该会看到一条内容为 // STEP 6 Set up listener 的注释。您将在此处添加与监听器相关联的 onResults() 和 onError() 函数。
// STEP 6 Set up listener
override fun onError(error: String) {
    activity?.runOnUiThread {
        Toast.makeText(requireActivity(), error, Toast.LENGTH_SHORT).show()
        fragmentDigitCanvasBinding.tvResults.text = ""
    }
}

override fun onResults(
    results: ImageClassifierResult,
    inferenceTime: Long
) {
    activity?.runOnUiThread {
        fragmentDigitCanvasBinding.tvResults.text = results
            .classificationResult()
            .classifications().get(0)
            .categories().get(0)
            .categoryName()

        fragmentDigitCanvasBinding.tvInferenceTime.text = requireActivity()
            .getString(R.string.inference_time, inferenceTime.toString())
    }
}

onResults() 尤为重要,因为它会显示从 ImageClassifier 接收的结果。由于此回调是从后台线程触发的,因此您还需要在 Android 的界面线程上运行界面更新。

  1. 在上述步骤中通过接口添加新函数时,您还需要在类的顶部添加实现声明。
class DigitCanvasFragment : Fragment(), DigitClassifierHelper.DigitClassifierListener
  1. 在类的顶部,您应该会看到一条内容为 // STEP 7a Initialize Classification 的注释。您将在此处放置 DigitClassifierHelper 的声明。
// STEP 7a Initialize classifier.
private lateinit var digitClassifierHelper: DigitClassifierHelper
  1. 向下移至 // STEP 7b 初始化分类器,您可以在 onViewCreated() 函数中初始化 digitClassifierHelper。
// STEP 7b Initialize classifier
// Initialize the digit classifier helper, which does all of the
// ML work. This uses the default values for the classifier.
digitClassifierHelper = DigitClassifierHelper(
    context = requireContext(), digitClassifierListener = this
)
  1. 对于最后几步,请找到 // STEP 8a*: classify* 注释,然后添加以下代码以调用稍后将添加的新函数。当您从应用的绘图区域抬起手指时,此代码块就会触发分类。
// STEP 8a: classify
classifyDrawing()
  1. 最后,查找 // STEP 8b classify 注释以添加新的 classifyDrawing() 函数。这会从画布中提取一个位图,然后将其传递给 DigitClassifierHelper 以执行分类,以便在 onResults() 接口函数中接收结果。
// STEP 8b classify
private fun classifyDrawing() {
    val bitmap = fragmentDigitCanvasBinding.digitCanvas.getBitmap()
    digitClassifierHelper.classify(bitmap)
}

5. 部署并测试应用

完成以上所有操作后,您应该会得到一个可正常运行的应用,能够对屏幕上绘制的数字进行分类!接下来,将应用部署到 Android 模拟器或 Android 实体设备,以对其进行测试。

  1. 点击 Android Studio 工具栏中的“Run”(7e15a9c9e1620fe7) 以运行该应用。
  2. 在绘图板上画出任意数字,看看应用能否识别该数字。它应该同时显示模型认为已绘制的数字,以及预测该数字所需的时间。

7f37187f8f919638.gif

6. 恭喜!

大功告成!在此 Codelab 中,您学习了如何向 Android 应用添加图片分类,特别是如何使用 MNIST 模型对手绘数字进行分类。

后续步骤

  • 现在,你可以对数字进行分类了,你可能需要训练自己的模型来对绘制的字母、动物或无数其他项目进行分类。您可以在 developers.google.com/mediapipe 页面上找到有关使用 MediaPipe Model Maker 训练新图片分类模型的文档。
  • 了解适用于 Android 的其他 MediaPipe 任务,包括人脸特征点检测、手势识别和音频分类。

我们期待着您创作出的所有炫酷产品!