1. 准备工作
在此 Codelab 中,您将更新在上一个“移动文本分类入门”Codelab 中构建的应用。
前提条件
- 此 Codelab 面向刚接触机器学习的资深开发者。
- 此 Codelab 是有序衔接课程的一部分。如果您尚未构建基本即时通讯样式应用或构建垃圾评论机器学习模型,请立即停止并执行此操作。
您将 [构建或学习]的内容
- 您将学习如何将自定义模型集成到上一步中创建的应用。
所需物品
- Android Studio 或 iOS 版 CocoaPods
2. 打开现有 Android 应用
您可以通过遵循 Codelab 1 来获取代码,也可以克隆此代码库并从 TextClassificationStep1
加载应用。
git clone https://github.com/googlecodelabs/odml-pathways
您可以在 TextClassificationOnMobile->Android
路径中找到它。
finished 代码也以 TextClassificationStep2
的形式提供给您。
打开帐号后,您可以继续进行第 2 步。
3.导入模型文件和元数据
在“构建垃圾评论机器学习模型” Codelab 中,您创建了一个 .TFLITE 模型。
您应该已经下载了模型文件。如果没有,您可以从此 Codelab 的代码库中获取;要获取模型,请点击此处。
通过创建资源目录将其添加到项目中。
- 使用项目导航器,确保在顶部选择了 Android。
- 右键点击 app 文件夹。依次选择新建 > 目录。
- 在 New Directory 对话框中,选择 src/main/assets。
您会看到应用中新增了 assets 文件夹。
- 右键点击 assets.
- 在打开的菜单中,您会看到(在 Mac 上)在 Finder 中显示。选择它。(在 Windows 上,系统会显示在资源管理器中显示;而在 Ubuntu 上,则会显示在文件中显示)。
系统会启动访达以显示文件位置(Windows 上为文件资源管理器,Linux 上为文件)。
- 将
labels.txt
、model.tflite
和vocab
文件复制到此目录。
- 返回 Android Studio,您将在 assets 文件夹中看到这些素材资源。
4.更新您的 build.gradle 以使用 TensorFlow Lite
如需使用 TensorFlow Lite 以及支持它的 TensorFlow Lite 任务库,您需要更新 build.gradle
文件。
Android 项目通常有多个级别,因此请务必找到应用级别 1。在 Android 视图中的项目资源管理器中,找到 Gradle Scripts 部分。正确的请求将带有 .app 标签,如下所示:
您需要对此文件进行两次更改。第一种位于底部的 dependencies 部分中。为 TensorFlow Lite 任务库添加文本 implementation
,如下所示:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
版本号可能已更改,因此请务必查看 https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier 获取最新版本。
任务库还需要最低 SDK 版本 21。在 android
> default config
中找到此设置,并将其更改为 21:
现在,您已拥有所有依赖项,接下来就可以开始编码了!
5. 添加帮助程序类
为了将推理逻辑(即您的应用使用模型的位置)与用户界面分离开,请再创建一个类来处理模型推理。将此类称为“helper”类。
- 右键点击
MainActivity
代码所在的软件包名称。 - 依次选择 New > Package。
- 屏幕中央会显示一个对话框,要求您输入软件包名称。将它添加到当前软件包名称的末尾。(此处称为帮助程序)。
- 完成此操作后,在 Project Explorer 中右键点击 helpers 文件夹。
- 依次选择 New > Java Class,并将其命名为
TextClassificationClient
。您将在下一步中编辑该文件。
您的 TextClassificationClient
辅助类将如下所示(尽管您的软件包名称可能不同)。
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- 使用以下代码更新文件:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
此类将为 TensorFlow Lite 解释器提供一个封装容器,用于加载模型并抽象化您的应用与模型之间的数据交换的复杂性。
在 load()
方法中,它会从模型路径实例化一个新的 NLClassifier
类型。模型路径只是模型名称 model.tflite
。NLClassifier
类型是文本任务库的一部分,它可以帮助您将字符串转换为令牌、使用正确的序列长度,将其传递给模型并解析结果。
(如需了解详情,请再次参阅“垃圾评论”机器学习模型。)
分类在分类方法中执行,您可以在其中传递一个字符串,然后返回 List
。使用机器学习模型对内容进行分类时,如果您想判断某个字符串是否是垃圾内容,系统通常会返回所有具有概率的答案。例如,如果您向其传递似乎是垃圾邮件的邮件,则会收到包含 2 个答复的列表;一个是垃圾邮件的概率,另一个不是垃圾邮件的概率。“垃圾邮件/非垃圾邮件”是类别,因此返回的 List
将包含这些概率。稍后您将进行解析。
现在您已经有了帮助程序类,请返回到 MainActivity
并进行更新,以便用它对您的文本进行分类。您将在下一步中看到这一点!
6.对文本进行分类
在 MainActivity
中,您首先要导入刚刚创建的帮助程序!
- 在
MainActivity.kt
的顶部,与其他导入项一起添加:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- 接下来,您需要加载帮助程序。在
onCreate
中的setContentView
代码行后面,添加以下几行代码以实例化并加载帮助程序类:
val client = TextClassificationClient(applicationContext)
client.load()
目前,您的按钮的 onClickListener
应如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- 请按如下所示进行更新:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
这会将功能从仅输出用户输入更改为首先分类。
- 以下代码行会将用户输入的字符串传递给模型,从而获得结果:
var results:List<Category> = client.classify(toSend)
只有 2 个类别,分别为 False
和 True
(TensorFlow 按字母顺序对其进行排序,因此 False 将是项 0,True 将是项 1。)
- 如需获取值为
True
的概率的得分,您可以查看 results[1].score,如下所示:
val score = results[1].score
- 选择了阈值(在本例中为 0.8),您说如果 True 类别的得分高于阈值 (0.8),则邮件为垃圾邮件。否则,这不是垃圾邮件,且邮件可以安全发送:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- 在此处查看该模型的实际运用。“访问我的博客以买东西!”消息已标记为疑似垃圾内容的可能性:
反过来说:“谢谢,有趣的教程,谢谢!”被认为是垃圾内容的可能性极低:
7. 更新您的 iOS 应用以使用 TensorFlow Lite 模型
您可以通过遵循 Codelab 1 来获取代码,也可以克隆此代码库并从 TextClassificationStep1
加载应用。您可以在 TextClassificationOnMobile->iOS
路径中找到它。
finished 代码也以 TextClassificationStep2
的形式提供给您。
在“构建垃圾评论机器学习模型”Codelab 中,您创建了一个非常简单的应用,允许用户在 UITextView
中输入消息,并未经任何过滤将其传递到输出。
现在,您将更新该应用,以便在发送之前使用 TensorFlow Lite 模型检测文本中的垃圾评论。只需通过在输出标签中渲染文本来模拟此应用中的发送(但真实应用可能具有公告板、聊天或类似内容)。
要开始使用,您需要从第 1 步获取应用,该应用可从代码库中克隆。
如需整合 TensorFlow Lite,您需要使用 CocoaPods。如果您还没有安装这些扩展程序,可以按照 https://cocoapods.org/ 上的说明操作。
- 安装 CocoaPods 后,在 TextClassification 应用的
.xcproject
目录下创建一个名为 Podfile 的文件。此文件的内容应如下所示:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
应用的名称应位于第一行,而不是“TextClassificationStep2”。
使用终端导航到该目录并运行 pod install
。如果成功,您将获得一个名为 Pod 的新目录,并且系统将为您创建新的 .xcworkspace
文件。您稍后将使用该变量,而非 .xcproject
。
如果下载失败,请确保 Podfile 与 .xcproject
位于同一目录中。错误目录或错误目标名称中的 podfile 通常是罪魁祸首!
8. 添加模型和 Vocab 文件
使用 TensorFlow Lite Model Maker 创建模型时,您可以输出模型(如 model.tflite
)和词汇表(如 vocab.txt
)。
- 将它们从 Finder 拖放到您的项目窗口中,即可将它们添加到您的项目中。确保已选中添加到目标:
完成后,您应该在项目中看到它们:
- 仔细检查您的项目是否已添加到 bundle 中(以便将其部署到设备),方法是选择您的项目(在上面的屏幕截图中,它是蓝色图标 TextClassificationStep2),然后查看 Build Phases 标签页:
9. 加载词汇表
进行 NLP 分类时,系统会使用编码向量的字词训练模型。该模型使用在模型训练时学习的一组特定名称和值对字词进行编码。请注意,大多数模型的词汇都不同,因此请务必在训练时为模型使用词汇表。这是您刚刚添加到应用的 vocab.txt
文件。
您可以在 Xcode 中打开该文件,以查看编码。比如,“song”会编码为 6,将“love”编码为 12。该顺序实际上是频率顺序,因此,“I”是该数据集中最常见的字词,其次是“check”。
当用户输入字词时,您需要先使用此词汇对其进行编码,然后再将其发送到要分类的模型。
我们来了解一下该代码。首先加载词汇表。
- 定义用于存储字典的类级变量:
var words_dictionary = [String : Int]()
- 然后,在类中创建一个
func
,以将词汇表加载到此字典中:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- 您可以通过在
viewDidLoad
内调用它来运行它:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. 将字符串转换为令牌序列
您的用户会以句子的形式输入字词,这将成为字符串。字典中的每个句子(如字典中存在)将被编码为词汇表中定义的单词的键值。
NLP 模型通常接受固定的序列长度。使用 ragged tensors
构建的模型也有例外情况,但在大多数情况下,您会发现此问题已得到修复。您在创建模型时指定了该长度。请务必在 iOS 应用中使用相同的长度。
您之前使用的 Colab for TensorFlow Lite Model Maker 的默认值为 20,因此也在此处进行设置:
let SEQUENCE_LENGTH = 20
添加此 func
,它会获取字符串、将其转换为小写形式并去除所有标点符号:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
请注意,序列将是 Int32。这是故意选择的,因为在向 TensorFlow Lite 传递值时,您会处理低级内存,并且 TensorFlow Lite 会将字符串序列中的整数视为 32 位整数。这使得将字符串传递给模型时,您的生活(稍微)变轻松。
11. 进行分类
如需对句子进行分类,必须先根据句子中的字词将它转换为一系列词条。此操作将在第 9 步中完成。
现在,您将获取句子并将其传递给模型,让模型对句子进行推断,并解析结果。
这将使用 TensorFlow Lite 解释器,您需要导入:
import TensorFlowLite
以接受序列(一个 Int32 类型的数组)的 func
开头:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
这将从 bundle 中加载模型文件,并使用该文件调用解释器。
下一步是将存储在序列中的底层内存复制到名为 myData,
的缓冲区中,以便将其传递给张量。在实现 TensorFlow Lite pod 和解释器时,您可以访问 Tensor Type。
启动如下代码(仍在 func
分类中):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
如果您在 copyingBufferOf
时遇到错误,请不要担心。稍后,将将其作为扩展程序实现。
现在该在解释器上分配张量了,将您刚刚创建的数据缓冲区复制到输入张量,然后调用解释器来进行推理:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
调用完成后,您可以查看解释器的输出以查看结果。
这些是原始值(每个神经元 4 字节),您必须读取并转换这些值。由于此特定模型具有 2 个输出神经元,因此您需要读取 8 个字节,这些文件将转换为 Float32 进行解析。您正在处理的是低级别内存,因此需设置 unsafeData
。
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
现在,解析数据相对容易,从而确定网络垃圾质量。模型有 2 项输出,第一项表示此邮件不是垃圾邮件的概率,第二项表示其为垃圾邮件的概率。因此,您可以查看 results[1]
来查找垃圾邮件值:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
为方便起见,下面给出了完整方法:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. 添加 Swift 扩展程序
上述代码使用了数据类型的扩展,以便您将 Int32 数组的原始位复制到 Data
中。以下是该扩展程序的代码:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
处理低级别内存时,您需要使用“不安全”数据,并且上述代码需要初始化不安全的数据数组。该扩展程序使实现以下目标:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. 运行 iOS 应用
运行并测试应用。
如果一切顺利,您应在设备上看到如下应用:
下方显示“购买我的图书,学习在线交易!”发送,那么应用发送回垃圾内容的概率为 0 .99%!
14. 恭喜!
您现在已经创建了一个非常简单的应用,该应用使用根据垃圾博客数据训练过的模型来过滤垃圾评论的文本。
典型的开发者生命周期中的下一步是探索根据您所在社区中的数据自定义模型的过程。您将在下一个衔接课程活动中看到具体操作方式。