1. 事前準備
在本程式碼研究室中,您將更新先前「開始使用行動裝置文字分類」程式碼研究室中建構的應用程式。
必要條件
- 本程式碼研究室是為機器學習新手的資深開發人員所設計。
- 程式碼研究室是序列課程的一部分。如果您尚未完成「建構基本訊息風格應用程式」或「建構垃圾留言機器學習模型」,請先停下來完成這兩項課程。
您將 [建構或學習] 的項目
- 您將瞭解如何將自訂模型整合至先前步驟中建立的應用程式。
軟硬體需求
- Android Studio 或 iOS 版 CocoaPods
2. 開啟現有的 Android 應用程式
您可以按照程式碼研究室 1 的步驟取得程式碼,也可以複製這個存放區,然後從 TextClassificationStep1
載入應用程式。
git clone https://github.com/googlecodelabs/odml-pathways
您可以在 TextClassificationOnMobile->Android
路徑中找到這個值。
已完成的程式碼也可做為 TextClassificationStep2
使用。
開啟後,您就可以繼續進行步驟 2。
3. 匯入模型檔案和中繼資料
在「建構留言垃圾內容機器學習模型」程式碼研究室中,您已建立 .TFLITE 模型。
您應該已下載模型檔案。如果您沒有這個模型,可以前往這個程式碼研究室的存放區取得,模型則可在這裡取得。
建立素材資源目錄,將其新增至專案。
- 使用專案導覽工具,確認已選取頂端的「Android」Android。
- 在「app」資料夾上按一下滑鼠右鍵。依序選取「New」 >「Directory」。
- 在「New Directory」對話方塊中,選取「src/main/assets」。
您會看到應用程式中現在有新的「assets」資料夾。
- 按一下「assets」的滑鼠右鍵。
- 隨即開啟的選單會顯示「在 Finder 中顯示」 (Mac 使用者)。選取該項目。(Windows 會顯示「在檔案總管中顯示」,而 Ubuntu 會顯示「在檔案中顯示」)。
Finder 會啟動並顯示檔案位置 (Windows 上的「檔案總管」,Linux 上的「Files」)。
- 將
labels.txt
、model.tflite
和vocab
檔案複製到這個目錄。
- 返回 Android Studio,您會在「assets」資料夾中看到這些檔案。
4. 更新 build.gradle 以使用 TensorFlow Lite
如要使用 TensorFlow Lite 和支援 TensorFlow Lite 的 TensorFlow Lite 工作程式庫,您必須更新 build.gradle
檔案。
Android 專案通常會有多個,因此請務必找出 app 層級。在 Android 檢視畫面的專案瀏覽器中,前往「Gradle Scripts」專區尋找該檔案。正確的專案會加上 .app 標籤,如下所示:
您需要對這個檔案進行兩項變更。第一個位於底部的「依附元件」部分。為 TensorFlow Lite 工作程式庫新增文字 implementation
,如下所示:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
自本文撰寫以來,版本編號可能已變更,因此請務必前往 https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier 查看最新版本。
工作程式庫也需要至少 21 以上的 SDK 版本。在 android
> default config
中找到這項設定,並將其變更為 21:
現在您已擁有所有依附元件,可以開始編寫程式了!
5. 新增輔助類別
如要區分應用程式使用該模型的推論邏輯,請建立另一個類別來處理模型推論。將此稱為「helper」類別。
- 在
MainActivity
程式碼所在的套件名稱上按一下滑鼠右鍵。 - 依序選取「New」(新增) >「Package」(套件)。
- 畫面中央會顯示一個對話方塊,要求您輸入套件名稱。並將其加到目前的套件名稱結尾處。(這裡稱為「輔助程式」)。
- 完成後,請在專案總管中按一下「helpers」資料夾的滑鼠右鍵。
- 依序選取「New」>「Java Class」,並將其命名為
TextClassificationClient
。您將在下一個步驟中編輯檔案。
TextClassificationClient
輔助類別會如下所示 (但您的套件名稱可能不同)。
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- 使用下列程式碼更新檔案:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
這個類別會為 TensorFlow Lite 轉譯器提供包裝函式,載入模型,並抽象化管理應用程式與模型之間資料交換的複雜性。
在 load()
方法中,系統會從模型路徑將新的 NLClassifier
類型例項化。模型路徑就是模型的名稱 model.tflite
。NLClassifier
類型是文字工作階段程式庫的一部分,可協助您將字串轉換為符記,並使用正確的序列長度將其傳遞至模型,然後剖析結果。
(如需進一步瞭解這些項目,請參閱「建立評論垃圾內容機器學習模型」一文)。
分類會在分類方法中執行,也就是傳遞字串,且傳回 List
。當您使用機器學習模型分類內容,並想判斷字串是否為垃圾內容時,系統通常會傳回所有答案,並指派機率。舉例來說,如果您傳送的訊息看起來像垃圾內容,系統會回傳 2 個答案清單,一個是垃圾內容的機率,另一個是垃圾內容以外的機率。垃圾內容/非垃圾內容是類別,因此傳回的 List
會包含這些機率。您稍後就會剖析出這些內容。
有了輔助類別後,請返回 MainActivity
並更新該類別,以便使用輔助類別來分類文字。下一個步驟將會用到!
6. 分類文字
在 MainActivity
中,您首先要匯入剛才建立的輔助程式!
- 在
MainActivity.kt
頂端,與其他匯入項目一起新增:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- 接下來,請建立輔助程式。在
onCreate
中,於setContentView
行後方立即新增這些行,以便例項化及載入輔助程式類別:
val client = TextClassificationClient(applicationContext)
client.load()
目前,按鈕的 onClickListener
應如下所示:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- 更新為以下內容:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
這會將功能從只輸出使用者輸入內容,改為先將輸入內容分類。
- 透過這行程式碼,您可以取得使用者輸入的字串,並將其傳遞至模型,取得結果:
var results:List<Category> = client.classify(toSend)
只有 2 個類別:False
和 True
. (TensorFlow 會依字母順序排列,因此 False 為項目 0,True 值則為項目 1)。
- 如要取得值為
True
的機率分數,您可以查看 results[1].score,如下所示:
val score = results[1].score
- 選取門檻值 (本例中為 0.8);此時,當 True 類別的分數超過門檻值 (0.8),即表示郵件為垃圾郵件。否則,這不是垃圾訊息,可以安全傳送:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- 請按這裡查看模型的實際運作情形。「Visit my blog to buy stuff!」這則訊息被標示為垃圾內容的可能性很高:
反之,「Hey, fun tutorial, thanks!」則被視為垃圾內容的可能性極低:
7. 更新 iOS 應用程式以使用 TensorFlow Lite 模型
您可以按照程式碼研究室 1 的步驟取得程式碼,也可以複製這個存放區,然後從 TextClassificationStep1
載入應用程式。您可以在 TextClassificationOnMobile->iOS
路徑中找到這項資訊。
您也可以使用 TextClassificationStep2
的完成程式碼。
在「建立留言垃圾內容機器學習模型」程式碼研究室中,您建立了一個非常簡單的應用程式,讓使用者在 UITextView
中輸入訊息,並將訊息傳遞至輸出內容,而無需進行任何篩選。
接下來,您將更新該應用程式,使用 TensorFlow Lite 模型偵測文字中的垃圾留言,再傳送出去。只要在輸出標籤中算繪文字,即可模擬在這個應用程式中傳送訊息 (但實際應用程式可能會有公告板、聊天或類似內容)。
如要開始使用,您需要從步驟 1 取得應用程式,您可以從存放區複製應用程式。
如要整合 TensorFlow Lite,您必須使用 CocoaPods。如果您尚未安裝這些工具,請按照 https://cocoapods.org/ 的操作說明進行安裝。
- 安裝 CocoaPods 後,請在與 TextClassification 應用程式的
.xcproject
目錄中建立名稱為 Podfile 的檔案。這個檔案的內容應如下所示:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
應用程式名稱應放在第一行,而非「TextClassificationStep2」。
使用終端機前往該目錄,然後執行 pod install
。如果成功,您就會看到名為 Pods 的新目錄,以及系統為您建立的新 .xcworkspace
檔案。您日後將使用這個值,而非 .xcproject
。
如果失敗,請確認 Podfile 位於 .xcproject
所在的同一個目錄。Podfile 通常位於錯誤的目錄中,或目標名稱錯誤,通常出乎意料。
8. 新增模型和字彙檔案
使用 TensorFlow Lite Model Maker 建立模型時,您可以輸出模型 (格式為 model.tflite
) 和詞彙 (格式為 vocab.txt
)。
- 將這些檔案從 Finder 拖曳至專案視窗,即可新增至專案。確認已勾選 [新增至目標]:
完成後,您應該會在專案中看到這些項目:
- 請選取專案 (在上述螢幕截圖中為藍色圖示 TextClassificationStep2),並查看「Build Phases」分頁,確認這些檔案已新增至套件 (以便部署至裝置):
9. 載入 Vocab
進行 NLP 分類時,模型會使用經過編碼處理的字詞訓練向量。模型會使用一組特定的名稱和值編碼字詞,這些名稱和值是在模型訓練時學習到的。請注意,大多數模型的字彙都會不同,因此請務必使用訓練時產生的模型字彙。這是您剛剛新增至應用程式的 vocab.txt
檔案。
您可以在 Xcode 中開啟檔案,查看編碼。「歌」等單字會編碼為 6,而「愛」編碼為 12。此順序實際上是頻率順序,因此「I」是資料集中最常見的字詞,後面接著「check」。
使用者輸入字詞時,您應先使用這個字彙表對字詞進行編碼,再將其傳送至模型進行分類。
讓我們來探索該程式碼。首先載入字彙。
- 定義類別層級變數來儲存字典:
var words_dictionary = [String : Int]()
- 接著,在類別中建立
func
,將字彙載入這個字典:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- 您可以從
viewDidLoad
中呼叫此方法來執行:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. 將字串轉換為符記序列
使用者會輸入字詞組成的句子,系統會將這些字詞組成字串。如果句子中的每個字詞都出現在字典中,就會根據詞彙表中定義的字詞,將這些字詞編碼為鍵值。
NLP 模型通常會接受固定的序列長度。使用 ragged tensors
建構的模型有些例外狀況,但大多數情況下都能修正。您在建立模型時已指定這個長度。請務必在 iOS 應用程式中使用相同的長度。
您先前使用的 Colab 中 TensorFlow Lite Model Maker 的預設值為 20,因此請在這裡也設定這個值:
let SEQUENCE_LENGTH = 20
新增這個 func
,它會擷取字串、將字串轉換為小寫,並移除所有標點符號:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
請注意,序列會是 Int32 的介面。這是刻意選擇的做法,因為在將值傳遞至 TensorFlow Lite 時,您會處理低階記憶體,而 TensorFlow Lite 會將字串序列中的整數視為 32 位元整數。讓您更輕鬆地將字串傳送至模型。
11. 進行分類
如要分類句子,必須先根據句子中的字詞,將其轉換為符記序列。這項作業已在步驟 9 中完成。
現在您要擷取語句並傳遞至模型,讓模型推斷語句,然後剖析結果。
這會使用 TensorFlow Lite 解譯器,您需要匯入以下項目:
import TensorFlowLite
請先從 func
開始,該函式會接收序列,也就是 Int32 型別的陣列:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
這會從套件載入模型檔案,並使用該檔案叫用轉譯器。
接下來,我們會將儲存在序列中的基礎記憶體複製到名為 myData,
的緩衝區,以便傳遞至張量。實作 TensorFlow Lite 程式包和解譯器時,您可以存取張量類型。
開始編寫程式碼,如下所示 (仍在 classify func
中):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
如果 copyingBufferOf
發生錯誤,請別擔心。這項功能稍後會以擴充功能的形式實作。
接下來,您需要在轉譯器上配置張量、將剛剛建立的資料緩衝區複製到輸入張量,然後叫用轉譯器執行推論:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
叫用作業完成後,您可以查看轉譯器的輸出內容,瞭解結果。
您必須讀取並轉換這些原始值 (每個神經元 4 個位元組)。這個特定模型有 2 個輸出神經元,因此您需要在 8 個位元組內讀取 8 個位元組,將轉換為 Float32 以剖析。您正在處理低層級記憶體,因此使用 unsafeData
。
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
您現在可以輕鬆剖析資料,判斷垃圾內容的品質。模型有 2 項輸出內容,第一項可能為郵件非垃圾郵件,第二個結果為郵件非垃圾郵件。因此,您可以查看 results[1]
來找出垃圾郵件值:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
為方便您查看,以下是完整方法:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. 新增 Swift 擴充功能
上述程式碼使用了資料類型的擴充功能,可讓您將 Int32 陣列的原始位元複製到 Data
中。以下是該擴充功能的程式碼:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
處理低層級記憶體時,您會使用「不安全」的資料,而上述程式碼需要您初始化不安全資料的陣列。這項擴充功能可讓您:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. 執行 iOS 應用程式
執行及測試應用程式。
如果一切順利,您應該會在裝置上看到如下所示的應用程式:
在傳送「Buy my book to learn online trading!」訊息時,應用程式會傳回「垃圾郵件偵測」警示,機率為 .99%!
14. 恭喜!
現在,您已經開發出一款非常簡單的應用程式,它使用訓練類網誌垃圾網誌的資料來篩選垃圾留言。
在典型的開發人員生命週期中,下一個步驟是探索如何根據社群中的資料自訂模型。我們將在下一個課程中說明如何進行這項操作。