1. 准备工作
在此 Codelab 中,您将更新在上一个“移动文本分类入门”Codelab 中构建的应用。
前提条件
- 本 Codelab 专为刚接触机器学习的有经验的开发者而设计。
- 此 Codelab 是开发者在线课程的一部分。如果您尚未完成“构建基本的消息式应用”或“构建垃圾评论机器学习模型”,请立即停止并完成这些操作。
您将 [构建或学习]的内容
- 您将学习如何将自定义模型集成到之前步骤中构建的应用中。
所需条件
- Android Studio,或适用于 iOS 的 CocoaPods
2. 打开现有 Android 应用
您可以按照 Codelab 1 中的说明获取此代码,也可以克隆此代码库并从 TextClassificationStep1 加载该应用。
git clone https://github.com/googlecodelabs/odml-pathways
您可以在 TextClassificationOnMobile->Android 路径中找到此文件。
您还可以通过 TextClassificationStep2 获取完成的代码。
打开后,您就可以继续执行第 2 步了。
3. 导入模型文件和元数据
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个 .TFLITE 模型。
您应该已下载模型文件。如果您没有此模型,可以从本 Codelab 的代码库中获取,该模型可在此处找到。
通过创建资源目录将其添加到项目中。
- 使用项目导航器,确保在顶部选择了 Android。
- 右键点击 app 文件夹。依次选择 New > Directory。

- 在 New Directory 对话框中,选择 src/main/assets。

您会看到应用中现在提供了一个新的 assets 文件夹。

- 右键点击资源。
- 在随即打开的菜单中,您会看到(在 Mac 上)Reveal in Finder。选择相应选项。(在 Windows 上,此项操作为在资源管理器中显示,在 Ubuntu 上为在“文件”中显示。)

系统会启动 Finder 以显示文件位置(在 Windows 上为文件资源管理器,在 Linux 上为文件)。
- 将
labels.txt、model.tflite和vocab文件复制到此目录。

- 返回 Android Studio,您会看到这些文件已显示在 assets 文件夹中。

4. 更新 build.gradle 以使用 TensorFlow Lite
如需使用 TensorFlow Lite 及支持它的 TensorFlow Lite 任务库,您需要更新 build.gradle 文件。
Android 项目通常有多个 build.gradle 文件,因此请务必找到应用级文件。在 Android 视图的项目资源管理器中,在 Gradle Scripts 部分中找到该文件。正确的应用会带有 .app 标签,如下所示:

您需要对此文件进行两项更改。第一个位于底部的 dependencies 部分。为 TensorFlow Lite 任务库添加文本 implementation,如下所示:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
自撰写本文以来,版本号可能已发生变化,因此请务必访问 https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier 查看最新版本。
任务库还要求最低 SDK 版本为 21。在 android > default config 中找到此设置,并将其更改为 21:

现在,您已拥有所有依赖项,可以开始编码了!
5. 添加辅助类
为了将推理逻辑(即您的应用使用模型的位置)与用户界面分离开,请创建另一个类来处理模型推理。我们称之为“辅助”类。
- 右键点击包含
MainActivity代码的软件包名称。 - 依次选择 New > Package。

- 您会在屏幕中央看到一个对话框,其中要求您输入软件包名称。将其添加到当前软件包名称的末尾。(此处称为 helpers)。

- 完成后,在项目资源管理器中右键点击 helpers 文件夹。
- 依次选择 New > Java Class,并将其命名为
TextClassificationClient。您将在下一步中修改该文件。
您的 TextClassificationClient 辅助类将如下所示(不过您的软件包名称可能有所不同)。
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- 使用以下代码更新文件:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
此类将为 TensorFlow Lite 解释器提供封装容器,用于加载模型并抽象化管理应用与模型之间数据交换的复杂性。
在 load() 方法中,它将从模型路径实例化一个新的 NLClassifier 类型。模型路径就是模型名称 model.tflite。NLClassifier 类型属于文本任务库,它通过以下方式为您提供帮助:将字符串转换为令牌、使用正确的序列长度、将其传递给模型,以及解析结果。
(如需详细了解这些内容,请重新访问“构建垃圾评论机器学习模型”一文。)
分类在 classify 方法中执行,您需要向该方法传递一个字符串,它会返回一个 List。使用机器学习模型对内容进行分类时,如果您想确定某个字符串是否为垃圾内容,通常会返回所有答案以及分配的概率。例如,如果您向其传递看起来像垃圾邮件的消息,您会收到 2 个答案的列表;一个答案是该消息为垃圾邮件的概率,另一个答案是该消息不是垃圾邮件的概率。“垃圾内容”/“非垃圾内容”是类别,因此返回的 List 将包含这些概率。您稍后会解析该内容。
现在您已拥有辅助类,请返回到 MainActivity 并更新它,以使用该类对文本进行分类。您将在下一步中看到这一点!
6. 对文本进行分类
在 MainActivity 中,您首先需要导入刚刚创建的辅助函数!
- 在
MainActivity.kt的顶部,与其他导入内容一起添加:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- 接下来,您需要加载辅助函数。在
onCreate中,紧跟在setContentView行之后添加以下代码行,以实例化并加载辅助类:
val client = TextClassificationClient(applicationContext)
client.load()
目前,按钮的 onClickListener 应如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- 将其更新为如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
这会将功能从仅输出用户输入更改为先对用户输入进行分类。
- 通过这一行代码,您可以获取用户输入的字符串并将其传递给模型,从而获得结果:
var results:List<Category> = client.classify(toSend)
只有 2 个类别,即 False 和 True
。(TensorFlow 会按字母顺序对它们进行排序,因此 False 将是项 0,而 True 将是项 1。)
- 如需获取值为
True的概率得分,您可以查看 results[1].score,如下所示:
val score = results[1].score
- 选择一个阈值(在本例中为 0.8),如果 True 类别的得分高于该阈值 (0.8),则表示相应消息为垃圾邮件。否则,该邮件不是垃圾邮件,可以安全地发送:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- 点击此处可查看该模型的实际应用。系统将“Visit my blog to buy stuff!”这条消息标记为极有可能是垃圾内容:

相反,“嘿,有趣的教程,谢谢!”被认为是垃圾邮件的可能性非常低:

7. 更新 iOS 应用以使用 TensorFlow Lite 模型
您可以按照 Codelab 1 中的说明获取此代码,也可以克隆此代码库并从 TextClassificationStep1 加载该应用。您可以在 TextClassificationOnMobile->iOS 路径中找到此文件。
您还可以通过 TextClassificationStep2 获取完成的代码。
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个非常简单的应用,该应用允许用户在 UITextView 中输入消息,然后将消息传递到输出,而无需进行任何过滤。
现在,您将更新该应用,使其使用 TensorFlow Lite 模型在发送之前检测文本中的垃圾评论。只需通过在输出标签中呈现文本来模拟此应用中的发送操作(但实际应用可能具有公告板、聊天功能或类似功能)。
首先,您需要第 1 步中的应用,您可以从代码库中克隆该应用。
如需纳入 TensorFlow Lite,您需要使用 CocoaPods。如果您尚未安装这些应用,可以按照 https://cocoapods.org/ 上的说明进行安装。
- 安装 CocoaPods 后,在 TextClassification 应用的
.xcproject所在的目录中创建一个名为 Podfile 的文件。此文件的内容应如下所示:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
应用名称应位于第一行,而不是“TextClassificationStep2”。
使用终端,导航到该目录并运行 pod install。如果成功,您将获得一个名为 Pods 的新目录,并且系统会为您创建一个新的 .xcworkspace 文件。您日后将使用该值,而不是 .xcproject。
如果失败,请确保 Podfile 与 .xcproject 位于同一目录中。Podfile 位于错误的目录中,或者目标名称有误,通常是主要原因!
8. 添加模型和词汇文件
使用 TensorFlow Lite Model Maker 创建模型时,您可以输出模型(以 model.tflite 形式)和词汇(以 vocab.txt 形式)。
- 将它们从 Finder 拖放到项目窗口中,即可将其添加到项目中。确保选中添加到目标:

完成后,您应该会在项目中看到这些文件:

- 选择您的项目(在上面的屏幕截图中,它是蓝色图标 TextClassificationStep2),然后查看 Build Phases 标签页,以仔细检查它们是否已添加到软件包中(以便将它们部署到设备):

9. 加载词汇
在进行 NLP 分类时,模型会使用编码为向量的字词进行训练。模型会使用一组特定的名称和值对字词进行编码,这些名称和值会在模型训练时学习。请注意,大多数模型会有不同的词汇,因此您必须使用在训练时生成的模型词汇。这是您刚刚添加到应用中的 vocab.txt 文件。
您可以在 Xcode 中打开该文件,以查看编码。“song”等字词编码为 6,“love”编码为 12。实际上,顺序是按频次排序,因此“I”是数据集中最常见的字词,其次是“check”。
当用户输入文字时,您需要先使用此词汇表对文字进行编码,然后再将其发送给模型进行分类。
我们来探索一下该代码。首先加载词汇。
- 定义一个类级变量来存储字典:
var words_dictionary = [String : Int]()
- 然后,在类中创建一个
func,以将词汇加载到此字典中:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- 您可以通过从
viewDidLoad内调用此函数来运行它:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. 将字符串转换为一系列令牌
用户将以句子的形式输入文字,这些文字将成为一个字符串。如果句子中的每个字词都存在于字典中,则会根据词汇表中的定义将这些字词编码为相应字词的键值。
NLP 模型通常接受固定序列长度。使用 ragged tensors 构建的模型存在例外情况,但在大多数情况下,您会发现它已固定。您在创建模型时指定了此长度。请务必在 iOS 应用中使用相同的长度。
您之前使用的 TensorFlow Lite Model Maker 的 Colab 中的默认值为 20,因此在此处也进行相应设置:
let SEQUENCE_LENGTH = 20
添加以下 func,它将获取字符串、将其转换为小写并去除所有标点符号:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
请注意,序列将为 Int32。之所以特意选择此值,是因为在将值传递给 TensorFlow Lite 时,您将处理低级内存,而 TensorFlow Lite 会将字符串序列中的整数视为 32 位整数。这样,在向模型传递字符串时,您的工作会(稍微)轻松一些。
11. 执行分类
若要对句子进行分类,必须先根据句子中的字词将其转换为一系列令牌。此操作已在第 9 步中完成。
现在,您将获取该句子并将其传递给模型,让模型对该句子进行推理,然后解析结果。
这将使用 TensorFlow Lite 解释器,您需要导入该解释器:
import TensorFlowLite
首先,创建一个接受序列的 func,该序列是 Int32 类型的数组:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
这会从软件包中加载模型文件,并使用该文件调用解释器。
下一步是将序列中存储的底层内存复制到名为 myData, 的缓冲区中,以便将其传递给张量。在实现 TensorFlow Lite pod 以及解释器时,您获得了对张量类型的访问权限。
按如下方式启动代码(仍在 classify func 中):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
如果您在 copyingBufferOf 上收到错误消息,请不必担心。稍后将作为扩展程序实现。
现在,您可以开始在解释器上分配张量,将您刚刚创建的数据缓冲区复制到输入张量,然后调用解释器进行推理:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
调用完成后,您可以查看解释器的输出,了解结果。
这些将是原始值(每个神经元 4 个字节),您随后必须读取并转换这些值。由于此特定模型有 2 个输出神经元,因此您需要读入 8 个字节,这些字节将转换为 Float32 以进行解析。您正在处理低级内存,因此会显示 unsafeData。
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
现在,解析数据以确定垃圾内容的质量相对容易。该模型有 2 个输出,第一个输出表示相应消息不是垃圾邮件的概率,第二个输出表示相应消息是垃圾邮件的概率。因此,您可以查看 results[1] 来找到垃圾邮件值:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
为方便起见,以下是完整的方法:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. 添加 Swift 扩展程序
上述代码使用数据类型的扩展功能,以便您将 Int32 数组的原始位复制到 Data 中。相应扩展程序的代码如下:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
处理低级内存时,您会使用“不安全”数据,而上述代码需要您初始化一个不安全的数据数组。此扩展程序可实现此目的:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. 运行 iOS 应用
运行并测试应用。
如果一切顺利,您应该会在设备上看到如下所示的应用:

在发送“Buy my book to learn online trading!”消息的位置,应用会返回垃圾内容检测到的提醒,概率为 0 .99!
14. 恭喜!
您现在已创建了一个非常简单的应用,该应用使用在用于向博客发送垃圾内容的训练数据上训练的模型来过滤文本中的垃圾评论。
典型开发者生命周期的下一步是探索如何根据您自己社区中的数据自定义模型。您将在下一个开发者在线课程活动中了解具体操作方法。