使用 MediaPipe Tasks 构建手写数字分类器 Android 应用

1. 简介

什么是 MediaPipe?

借助 MediaPipe 解决方案,您可以将机器学习 (ML) 解决方案应用于您的应用。它提供了一个框架,可让您配置预构建的处理流水线,用于为用户提供即时的、有吸引力的有用输出。您甚至可以使用 MediaPipe Model Maker 自定义这些解决方案,以更新默认模型。

图像分类是 MediaPipe Solutions 提供的多项机器学习视觉任务之一。MediaPipe Tasks 适用于 Android、iOS、Python(包括 Raspberry Pi!)和 Web。

在此 Codelab 中,您将从一个 Android 应用开始,该应用可让您在屏幕上绘制数字,然后您将添加功能来将这些绘制的数字分类为 0 到 9 之间的单个值。

学习内容

  • 如何使用 MediaPipe Tasks 在 Android 应用中集成图像分类任务。

所需条件

  • 安装好的 Android Studio 版本(此 Codelab 是使用 Android Studio Giraffe 编写和测试的)。
  • 用于运行应用的 Android 设备或模拟器。
  • 具备 Android 开发基础知识(这不是“Hello World”,但也不算太难!)。

2. 向 Android 应用添加 MediaPipe 任务

下载 Android 起始应用

本 Codelab 将从一个预先创建的示例开始,该示例可让您在屏幕上绘制图形。您可以在官方 MediaPipe 示例代码库的此处找到该启动应用。点击“Code”(代码)>“Download ZIP”(下载 ZIP 文件),克隆代码库或下载 ZIP 文件。

将应用导入 Android Studio

  1. 打开 Android Studio。
  2. Welcome to Android Studio 屏幕中,选择右上角的 Open

a0b5b070b802e4ea.png

  1. 前往您克隆或下载代码库的位置,然后打开 codelabs/digitclassifier/android/start 目录
  2. 点击 Android Studio 右上角的绿色运行箭头 ( 7e15a9c9e1620fe7.png),验证所有内容是否已正确打开
  3. 您应该会看到应用打开后显示一个可供绘图的黑屏,以及用于重置该屏幕的清除按钮。虽然您可以在该界面上绘制,但它无法执行其他操作,因此我们现在就开始解决这个问题。

11a0f6fe021fdc92.jpeg

型号

首次运行应用时,您可能会注意到系统会下载并将名为 mnist.tflite 的文件存储在应用的 assets 目录中。为简单起见,我们已使用项目中的 download_models.gradle 脚本,将用于分类数字的已知模型 MNIST 添加到应用中。如果您决定训练自己的自定义模型(例如手写字母模型),则应移除 download_models.gradle 文件,删除应用级 build.gradle 文件中对它的引用,并在稍后的代码(具体而言,在 DigitClassifierHelper.kt 文件中)更改模型的名称。

更新 build.gradle

您需要先导入该库,然后才能开始使用 MediaPipe Tasks。

  1. 打开位于 app 模块中的 build.gradle 文件,然后向下滚动到 dependencies 代码块。
  2. 您应该会在该代码块底部看到一条注释,其中包含 // STEP 1 Dependency Import
  3. 将该行替换为以下实现
implementation("com.google.mediapipe:tasks-vision:latest.release")
  1. 点击 Android Studio 顶部横幅中显示的 Sync Now 按钮,下载此依赖项。

3. 创建 MediaPipe Tasks 数字分类器帮助程序

在下一步中,您将填写一个类,该类将为机器学习分类完成繁重工作。打开 DigitClassifierHelper.kt,开始吧!

  1. 找到类顶部的注释,其中包含 // STEP 2 Create listener
  2. 将该行替换为以下代码。这将创建一个监听器,用于将 DigitClassifierHelper 类中的结果传回到监听这些结果的任何位置(在本例中,它将是 DigitCanvasFragment 类,但我们很快就会介绍到)
// STEP 2 Create listener

interface DigitClassifierListener {
    fun onError(error: String)
    fun onResults(
        results: ImageClassifierResult,
        inferenceTime: Long
    )
}
  1. 您还需要接受 DigitClassifierListener 作为该类的可选参数:
class DigitClassifierHelper(
    val context: Context,
    val digitClassifierListener: DigitClassifierListener?
) {
  1. 向下滚动到显示 // STEP 3 define classifier 的代码行,然后添加以下代码行,为将用于此应用的 ImageClassifier 创建占位符:

// STEP 3 define classifier

private var digitClassifier: ImageClassifier? = null
  1. 在您看到注释 // STEP 4 set up classifier 的位置,添加以下函数:
// STEP 4 set up classifier
private fun setupDigitClassifier() {

    val baseOptionsBuilder = BaseOptions.builder()
        .setModelAssetPath("mnist.tflite")

    // Describe additional options
    val optionsBuilder = ImageClassifierOptions.builder()
        .setRunningMode(RunningMode.IMAGE)
        .setBaseOptions(baseOptionsBuilder.build())

    try {
        digitClassifier =
            ImageClassifier.createFromOptions(
                context,
                optionsBuilder.build()
            )
    } catch (e: IllegalStateException) {
        digitClassifierListener?.onError(
            "Image classifier failed to initialize. See error logs for " +
                    "details"
        )
        Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
    }
}

上述部分涉及到一些内容,因此我们来看看较小的部分,以便真正了解发生了什么。

val baseOptionsBuilder = BaseOptions.builder()
    .setModelAssetPath("mnist.tflite")

// Describe additional options
val optionsBuilder = ImageClassifierOptions.builder()
    .setRunningMode(RunningMode.IMAGE)
    .setBaseOptions(baseOptionsBuilder.build())

此代码块将定义 ImageClassifier 使用的参数。这包括 BaseOptions 下存储在应用内的模型 (mnist.tflite) 和 ImageClassifierOptions 下的 RunningMode,在本例中为 IMAGE,但 VIDEO 和 LIVE_STREAM 是其他可用选项。其他可用参数包括 MaxResults,用于限制模型返回的结果数上限,以及 ScoreThreshold,用于设置模型在返回结果之前对结果的最低置信度。

try {
    digitClassifier =
        ImageClassifier.createFromOptions(
            context,
            optionsBuilder.build()
        )
} catch (e: IllegalStateException) {
    digitClassifierListener?.onError(
        "Image classifier failed to initialize. See error logs for " +
                "details"
    )
    Log.e(TAG, "MediaPipe failed to load model with error: " + e.message)
}

创建配置选项后,您可以通过传入上下文和选项来创建新的 ImageClassifier。如果该初始化流程出现问题,系统会通过 DigitClassifierListener 返回错误。

  1. 由于我们希望在使用 ImageClassifier 之前对其进行初始化,因此您可以添加一个 init 块来调用 setupDigitClassifier()。
init {
    setupDigitClassifier()
}
  1. 最后,向下滚动到显示“// STEP 5 create classify function”的注释,然后添加以下代码。此函数将接受 Bitmap(在本例中为所绘制的数字),将其转换为 MediaPipe Image 对象 (MPImage),然后使用 ImageClassifier 对该图片进行分类,并记录推理所需的时间,最后通过 DigitClassifierListener 返回这些结果。
// STEP 5 create classify function
fun classify(image: Bitmap) {
    if (digitClassifier == null) {
        setupDigitClassifier()
    }

    // Convert the input Bitmap object to an MPImage object to run inference.
    // Rotating shouldn't be necessary because the text is being extracted from
    // a view that should always be correctly positioned.
    val mpImage = BitmapImageBuilder(image).build()

    // Inference time is the difference between the system time at the start and finish of the
    // process
    val startTime = SystemClock.uptimeMillis()

    // Run image classification using MediaPipe Image Classifier API
    digitClassifier?.classify(mpImage)?.also { classificationResults ->
        val inferenceTimeMs = SystemClock.uptimeMillis() - startTime
        digitClassifierListener?.onResults(classificationResults, inferenceTimeMs)
    }
}

辅助文件就到此为止!在下一部分中,您将完成最后的步骤,开始对所抽出的数字进行分类。

4. 使用 MediaPipe Tasks 运行推理

如需开始本部分,请在 Android Studio 中打开 DigitCanvasFragment 类,所有工作都将在此类中完成。

  1. 在此文件的底部,您应该会看到一条注释,内容为 // STEP 6 Set up listener。您将在此处添加与监听器关联的 onResults() 和 onError() 函数。
// STEP 6 Set up listener
override fun onError(error: String) {
    activity?.runOnUiThread {
        Toast.makeText(requireActivity(), error, Toast.LENGTH_SHORT).show()
        fragmentDigitCanvasBinding.tvResults.text = ""
    }
}

override fun onResults(
    results: ImageClassifierResult,
    inferenceTime: Long
) {
    activity?.runOnUiThread {
        fragmentDigitCanvasBinding.tvResults.text = results
            .classificationResult()
            .classifications().get(0)
            .categories().get(0)
            .categoryName()

        fragmentDigitCanvasBinding.tvInferenceTime.text = requireActivity()
            .getString(R.string.inference_time, inferenceTime.toString())
    }
}

onResults() 尤为重要,因为它会显示从 ImageClassifier 收到的结果。由于此回调是从后台线程触发的,因此您还需要在 Android 的界面线程上运行界面更新。

  1. 由于您在上一步中通过接口添加了新函数,因此您还需要在类顶部添加实现声明。
class DigitCanvasFragment : Fragment(), DigitClassifierHelper.DigitClassifierListener
  1. 在类顶部,您应该会看到一条注释,内容为 // STEP 7a Initialize classifier。您将在此处放置 DigitClassifierHelper 的声明。
// STEP 7a Initialize classifier.
private lateinit var digitClassifierHelper: DigitClassifierHelper
  1. 向下滚动到 // STEP 7b Initialize classifier(初始化分类器),您可以在 onViewCreated() 函数中初始化 digitClassifierHelper。
// STEP 7b Initialize classifier
// Initialize the digit classifier helper, which does all of the
// ML work. This uses the default values for the classifier.
digitClassifierHelper = DigitClassifierHelper(
    context = requireContext(), digitClassifierListener = this
)
  1. 在最后的步骤中,找到注释 // STEP 8a*: classify*,然后添加以下代码以调用您稍后要添加的新函数。当您从应用中的绘制区域移开手指时,此代码块将触发分类。
// STEP 8a: classify
classifyDrawing()
  1. 最后,找到注释 // STEP 8b classify,以添加新的 classifyDrawing() 函数。这将从画布中提取位图,然后将其传递给 DigitClassifierHelper 以执行分类,以便在 onResults() 接口函数中接收结果。
// STEP 8b classify
private fun classifyDrawing() {
    val bitmap = fragmentDigitCanvasBinding.digitCanvas.getBitmap()
    digitClassifierHelper.classify(bitmap)
}

5. 部署和测试应用

完成所有这些操作后,您应该会得到一个可以对屏幕上绘制的数字进行分类的应用!接下来,将应用部署到 Android 模拟器或 Android 实体设备以进行测试。

  1. 点击 Android Studio 工具栏中的“Run”(运行)图标 ( 7e15a9c9e1620fe7.png) 以运行应用。
  2. 在绘图板上画出任意数字,看看应用能否识别出来。它应显示模型认为所画的数字,以及预测该数字所花的时间。

7f37187f8f919638.gif

6. 恭喜!

大功告成!在此 Codelab 中,您学习了如何向 Android 应用添加图片分类功能,具体而言,就是如何使用 MNIST 模型对手写数字进行分类。

后续步骤

  • 现在,您已经可以对数字进行分类了,接下来您可能需要训练自己的模型来分类手写的字母、动物或无数其他内容。如需了解如何使用 MediaPipe Model Maker 训练新的图片分类模型,请访问 developers.google.com/mediapipe 页面。
  • 了解适用于 Android 的其他 MediaPipe Tasks,包括人脸特征点检测、手势识别和音频分类。

我们期待看到您创作的所有精彩内容!