使用 TensorFlow Lite (Android) 构建和部署自定义对象检测模型

1. 准备工作

在此 Codelab 中,您将学习如何使用一组训练图片来训练一个TFLite 模型制作工具,然后使用以下代码将您的模型部署到 Android 应用TFLite 任务库。您将学习以下内容:

  • 构建一款 Android 应用,用于检测餐点图片中的成分。
  • 集成 TFLite 预训练对象检测模型,看看该模型能够检测出的限制。
  • 使用名为 salad 和 TFLite Model Maker 的自定义数据集训练自定义对象检测模型,以检测膳食的食材/组成部分。
  • 使用 TFLite Task 库将自定义模型部署到 Android 应用。

最后,您需要创建与下图类似的图片:

b9705235366ae162.png

前提条件

此 Codelab 面向希望获得机器学习相关经验的移动开发者。您应熟悉以下内容/操作:

  • 使用 Kotlin 和 Android Studio 进行 Android 开发
  • 基本 Python 语法

学习内容

  • 如何使用 TFLite Model Maker 训练自定义对象检测模型。
  • 如何使用 TFLite Task 库部署 TFLite 对象检测模型。

所需物品

  • 最新版 Android Studio(v4.2 及更高版本)
  • Android Studio 模拟器或实体 Android 设备
  • 示例代码
  • 使用 Kotlin 进行 Android 开发的基础知识

2. 对象检测

对象检测是一组计算机视觉任务,可以检测和定位数字图片中的对象。在给定图片或视频串流后,对象检测模型可以识别其中可能存在哪组已知对象,并提供有关对象在图像中所处位置的信息。

TensorFlow 提供已预先训练且针对移动设备进行了优化的模型,可以检测汽车、橘子等常见对象。您只需使用几行代码,便可将这些预训练模型集成到移动应用中。但是,您可能希望或需要检测更多独特或另类类别的对象。这需要收集您自己的训练图片,然后训练并部署您自己的对象检测模型。

TensorFlow Lite

TensorFlow Lite 是一个跨平台机器学习库,针对在边缘设备(包括 Android 和 iOS 移动设备)上运行机器学习模型进行了优化。

TensorFlow Lite 实际上是机器学习套件内运行机器学习模型的核心引擎。TensorFlow Lite 生态系统中的以下两个组件可助您轻松地在移动设备上训练和部署机器学习模型:

  • Model Maker 是一个 Python 库,只需几行代码即可让您轻松地使用自己的数据训练 TensorFlow Lite 模型,而无需具备机器学习专业知识。
  • Task 库是一个跨平台库,通过在您的移动应用中只需几行代码即可轻松部署 TensorFlow Lite 模型。

此 Codelab 重点介绍 TFLite。我们并未解释与 TFLite 和对象检测无关的概念和代码块,而是供您直接复制和粘贴。

3.进行设置

下载代码

点击下面的链接可下载本 Codelab 的所有代码:

解压下载的 ZIP 文件。此操作会解压缩一个根文件夹 (odml-pathways-main),其中包含您需要的所有资源。对于本 Codelab,您只需要 object-detection/codelab2/android 子目录中的源代码。

object-detection/codelab2/android 代码库中的 android 子目录包含两个目录:

  • android_studio_folder.pngstarter - 本 Codelab 的起始代码。
  • android_studio_folder.pngfinal - 完成后的示例应用的完整代码。

导入起始应用

首先,将入门应用导入 Android Studio。

  1. 打开 Android Studio,然后选择 Import Project(Gradle、Eclipse ADT 等)
  2. 打开您之前下载的源代码中的 starter 文件夹。

7c0f27882a2698ac.png

为确保所有依赖项都可供您的应用使用,您应该在导入过程完成后将项目与 gradle 文件同步。

  1. 从 Android Studio 工具栏中选择 Sync Project with Gradle Files ( b451ab2d04d835f9.png)。导入 starter/app/build.gradle

运行起始应用

现在,您已将项目导入 Android Studio,可以首次运行应用了。

通过 USB 将 Android 设备连接到计算机,或启动 Android Studio 模拟器,然后点击 Android Studio 工具栏中的 Run ( execute.png)。

4.了解起始应用

为了简化此 Codelab 并使其专注于机器学习开发工作,入门应用包含的一些样板代码会为您执行一些操作:

  • 它可以使用设备的相机拍摄照片。
  • 其中包含一些库存图片,您可以在 Android 模拟器上试用对象检测。
  • 它提供了一种在输入位图上绘制对象检测结果的便捷方法。

您将主要在应用框架中与以下方法交互:

  • fun runObjectDetection(bitmap: Bitmap)当您选择预设图片或拍照时,系统会调用此方法。bitmap 是对象检测的输入图片。在此 Codelab 的后面部分,您将向此方法添加对象检测代码。
  • data class DetectionResult(val boundingBoxes: Rect, val text: String)。这是一个数据类,表示要直观呈现的对象检测结果。boundingBoxes 是对象所在的矩形,text 是与对象的边界框一起显示的检测结果字符串。
  • fun drawDetectionResult(bitmap: Bitmap, detectionResults: List<DetectionResult>): Bitmap 此方法会在输入 bitmap 上绘制 detectionResults 的对象检测结果,并返回修改后的副本。

以下是 drawDetectionResult 实用程序方法的输出示例。

f6b1e6dad726e129.png

5. 添加设备端对象检测

现在,您将通过将可以检测常见对象的预训练 TFLite 模型集成到入门应用中来构建原型。

下载预训练的 TFLite 对象检测模型

您可以使用 TensorFlow Hub 中的多种对象检测器模型。在此 Codelab 中,您将下载 EfficientDet-Lite 对象检测模型,该模型使用 COCO 2017 数据集训练、针对 TFLite 进行了优化,并针对移动 CPU、GPU 的性能进行了优化和 Edge TPU。

接下来,使用 TFLite 任务库将预训练的 TFLite 模型集成到起始应用中。借助 TFLite 任务库,您可以轻松地将针对移动设备进行了优化的机器学习模型集成到移动应用中。它支持许多常见的机器学习用例,包括对象检测、图像分类和文本分类。只需几行代码即可加载 TFLite 模型并运行它。

将模型添加到起始应用

  1. 将您刚刚下载的模型复制到起始应用的 assets 文件夹中。您可以在 Android Studio 的 Project 导航面板中找到该文件夹。

c2609599b7d22641.png

  1. 将文件命名为 model.tflite

c83e9397177c4561.png

更新 Gradle 文件任务库依赖项

转到 app/build.gradle 文件,并将以下代码行添加到 dependencies 配置中:

implementation 'org.tensorflow:tensorflow-lite-task-vision:0.3.1'

将项目与 Gradle 文件同步

为了确保所有依赖项对您的应用都可用,此时应将项目与 gradle 文件同步。从 Android Studio 工具栏中选择 Sync Project with Gradle Files ( b451ab2d04d835f9.png)。

(如果此按钮已停用,请确保您仅导入 starter/app/build.gradle,而不是导入整个代码库)。

对图片设置并运行设备端对象检测

只需 3 个简单步骤即可加载和运行对象检测模型,包括 3 个 API:

  • 准备图片 / 数据流:TensorImage
  • 创建检测器对象:ObjectDetector
  • 连接上述 2 个对象:detect(image)

您可以在 MainActivity.kt 文件内的 runObjectDetection(bitmap: Bitmap) 函数中实现这些函数。

/**
* TFLite Object Detection Function
*/
private fun runObjectDetection(bitmap: Bitmap) {
    //TODO: Add object detection code here
}

目前,该函数是空的。请转到以下步骤来实现 TFLite 对象检测器。在此过程中,Android Studio 将提示您添加必要的导入:

  • org.tensorflow.lite.support.image.TensorImage
  • org.tensorflow.lite.task.vision.detector.ObjectDetector

创建图片对象

您将在此 Codelab 中使用的图片将来自设备端的相机,或者来自您在应用界面上选择的预设图片。输入图像会被解码为 Bitmap 格式并传递给 runObjectDetection 方法。

TFLite 提供一个简单的 API,用于从 Bitmap 创建 TensorImage。将以下代码添加到 runObjectDetection(bitmap:Bitmap) 的顶部:

// Step 1: create TFLite's TensorImage object
val image = TensorImage.fromBitmap(bitmap)

创建检测器实例

TFLite 任务库遵循构建器设计模式。将配置传递给构建器,然后从中获取检测器。有几个配置选项可用,包括用于调整对象检测器的灵敏度的选项:

  • 结果数上限(模型应检测的最大对象数)
  • 分数阈值(对象检测器应返回已检测对象的置信度)
  • 标签许可名单/拒绝名单(允许/拒绝预定义列表中的对象)

通过指定 TFLite 模型文件名和配置选项,初始化对象检测器实例:

// Step 2: Initialize the detector object
val options = ObjectDetector.ObjectDetectorOptions.builder()
    .setMaxResults(5)
    .setScoreThreshold(0.5f)
    .build()
val detector = ObjectDetector.createFromFileAndOptions(
    this, // the application context
    "model.tflite", // must be same as the filename in assets folder
    options
)

将图片馈送到检测器

将以下代码添加到 fun runObjectDetection(bitmap:Bitmap)。这会将您的图片提供给检测器。

// Step 3: feed given image to the model and print the detection result
val results = detector.detect(image)

完成后,检测器会返回 Detection 列表,其中每个列表都包含模型在图片中发现的对象的相关信息。各个对象的描述如下:

  • boundingBox:声明某个对象及其在图片中的位置的矩形
  • categories:该对象的类型以及模型对检测结果的信心。该模型会返回多个类别,最可靠的类别是第一个。
  • label:对象类别的名称。
  • classificationConfidence:介于 0.0 和 1.0 之间的浮点数,1.0 表示 100%

将以下代码添加到 fun runObjectDetection(bitmap:Bitmap)。这将调用一个方法,将对象检测结果输出到 Logcat。

// Step 4: Parse the detection result and show it
debugPrint(results)

然后,将此 debugPrint() 方法添加到 MainActivity 类:

private fun debugPrint(results : List<Detection>) {
    for ((i, obj) in results.withIndex()) {
        val box = obj.boundingBox

        Log.d(TAG, "Detected object: ${i} ")
        Log.d(TAG, "  boundingBox: (${box.left}, ${box.top}) - (${box.right},${box.bottom})")

        for ((j, category) in obj.categories.withIndex()) {
            Log.d(TAG, "    Label $j: ${category.label}")
            val confidence: Int = category.score.times(100).toInt()
            Log.d(TAG, "    Confidence: ${confidence}%")
        }
    }
}

现在,您的对象检测器已就绪!点击 Android Studio 工具栏中的 Run 图标 ( execute.png) 来编译并运行应用。当应用在设备上显示后,点按任意预设图像来启动对象检测器。然后查看 IDE 内的 Logcat 窗口*(* 16bd6ea224cf8cf1.png*)*,您应该会看到类似下图的内容:

D/TFLite-ODT: Detected object: 0
D/TFLite-ODT:   boundingBox: (0.0, 15.0) - (2223.0,1645.0)
D/TFLite-ODT:     Label 0: dining table
D/TFLite-ODT:     Confidence: 77%
D/TFLite-ODT: Detected object: 1
D/TFLite-ODT:   boundingBox: (702.0, 3.0) - (1234.0,797.0)
D/TFLite-ODT:     Label 0: cup
D/TFLite-ODT:     Confidence: 69%

这表示检测器检测到了 2 个对象。第一个是:

  • 对象位于 (0, 15) – (2223, 1645) 的矩形内
  • 标签为餐桌
  • 此模型确信第 1 个是餐桌 (77%)

从技术上讲,这是 TFLite Task 库正常运行所需的全部功能:您目前已获得所有权限! 恭喜

但在界面方面,您仍然处于起步阶段。现在,您必须通过对检测到的结果进行后处理,来利用界面上检测到的结果。

6.在输入图片上绘制检测结果

在前面的步骤中,您将检测结果输出到 logcat 中:简单快速。在此步骤中,您将利用在起始应用中为您实现的实用程序方法,以便:

  • 在图片上绘制边界框
  • 在边界框内绘制类别名称和置信度百分比
  1. debugPrint(results) 调用替换为以下代码段:
val resultToDisplay = results.map {
    // Get the top-1 category and craft the display text
    val category = it.categories.first()
    val text = "${category.label}, ${category.score.times(100).toInt()}%"

    // Create a data object to display the detection result
    DetectionResult(it.boundingBox, text)
}
// Draw the detection result on the bitmap and show it.
val imgWithResult = drawDetectionResult(bitmap, resultToDisplay)
runOnUiThread {
    inputImageView.setImageBitmap(imgWithResult)
}
  1. 现在,点击 Android Studio 工具栏中的 Run ( execute.png)。
  2. 应用加载后,点按其中一张预设图片即可查看检测结果。

想要使用自己的照片吗?点按拍照按钮,拍摄您周围物体的一些照片。

8b024362b15096a6.png

7. 训练自定义对象检测模型

在上一步中,您将预训练的 TFLite 对象检测模型集成到了 Android 应用中,发现它能够检测到示例图像中的常见对象(例如碗或餐桌)。但是,您的目标是检测图片中菜肴的成分,因此常规对象检测不适合您的用例。您希望使用包含要检测的成分的训练数据集训练自定义对象检测模型。

这是一个包含图片和标签的数据集,可用于练习训练您自己的自定义模型。它使用 Open Images Dataset V4 中的图片创建。

Colaboratory

接下来,前往 Google Colab 训练自定义模型。

训练自定义模型大约需要 30 分钟。

如果您很着急,可以在提供的数据集上下载我们为您预训练的模型,然后继续执行下一步。

8. 将自定义 TFLite 模型集成到 Android 应用中

现在,您已经训练了沙拉检测模型,可以集成该模型,并让您的应用从常见的对象检测器(具体而言,就是沙拉检测器)进行转换。

  1. 将 salad TFLite 模型复制到 assets 文件夹。将新模型命名为 salad.tflite

91e8d37c4f78eddb.png

  1. 打开 MainActivity.kt 文件并找到 ObjectDetector 初始化代码。
  2. 将 EfficientDet-Lite 模型 (model.tflite) 替换为沙拉模型 (salad.tflite)
val detector = ObjectDetector.createFromFileAndOptions(
    this, // the application context
    "salad.tflite", // must be same as the filename in assets folder
    options
)
  1. 点击 Android Studio 工具栏中的 Run ( execute.png) 可使用新模型重新运行应用。就是这样!此应用现在可以识别奶酪、沙拉、烘焙食品。

b9705235366ae162.png

9. 恭喜!

您已使用 TFLite 训练自定义模型并向应用添加对象检测功能。您只需完成此步骤便可开始投放广告!

所学内容

  • 如何在 TensorFlow Hub 中查找预训练的 TFLite 对象检测模型
  • 如何使用 TFLite Task 库将异议检测模型集成到 Android 应用中
  • 如何使用 TFLite Model Maker 训练自定义对象检测模型

后续步骤

  • 使用 Firebase 增强 TFLite 模型部署
  • 收集训练数据以训练您自己的模型
  • 在您自己的 Android 应用中应用对象检测

了解详情