1. Before you begin
In this codelab, you'll update the app you built in the previous codelabs of the Get started with mobile text classification pathway.
Prerequisites
- This codelab has been designed for experienced developers new to machine learning.
- The codelab is part of a sequenced pathway. If you have not already completed Build a basic messaging style app or Build a comment spam machine learning model, please stop and do so now.
What you'll learn
- You'll learn how to integrate the custom model you built in the previous steps into your app.
What you'll need
- Android Studio for Android, or Xcode and CocoaPods for iOS
2. Open the Existing Android App
You can get the code for this by following Codelab 1, or by cloning this repo and loading the app from TextClassificationStep1:
git clone https://github.com/googlecodelabs/odml-pathways
You can find this in the TextClassificationOnMobile->Android path.
The finished code is also available for you as TextClassificationStep2.
Once it's opened up, you're ready to move on to the next step.
3. Import the Model File and Metadata
In the Build a comment spam machine learning model codelab, you created a .tflite model. You should have downloaded the model file. If you don't have it, you can get it from the repo for this codelab.
Add it to your project by creating an assets directory.
- Using the project navigator, make sure Android is selected at the top.
- Right-click the app folder. Select New > Directory.
- In the New Directory dialog, select src/main/assets.
You'll see a new assets folder is now available in the app.
- Right-click assets.
- On the menu that opens, you'll see (on macOS) Reveal in Finder. Select it. (On Windows it will say Show in Explorer; on Ubuntu it will say Show in Files.)
Finder will launch to show the file's location (File Explorer on Windows, Files on Linux).
- Copy the labels.txt, model.tflite, and vocab files to this directory.
- Return to Android Studio, and you'll see them available in your assets folder.
4. Update your build.gradle to use TensorFlow Lite
To use TensorFlow Lite, and the TensorFlow Lite task libraries that support it, you'll need to update your build.gradle file.
Android projects often have more than one, so be sure to find the app-level one. In the project explorer in Android view, find it in your Gradle Scripts section. The correct one is labelled with .app as shown here:
You'll need to make two changes to this file. The first is in the dependencies section at the bottom. Add an implementation entry for the TensorFlow Lite task text library, like this:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
The version number may have changed since this was written, so be sure to check https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier for the latest.
The task libraries also require a minimum SDK version of 21. Find this setting in the android > defaultConfig section, and change it to 21.
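For reference, here's a minimal sketch of how the relevant parts of the app-level build.gradle might look after both changes (the other settings and exact version numbers in your file will differ):

android {
    defaultConfig {
        // The TensorFlow Lite task libraries need API level 21 or higher
        minSdkVersion 21
    }
}

dependencies {
    // Check the task library docs for the latest version number
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
}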
You now have all your dependencies, so it's time to start coding!
5. Add a Helper Class
To separate the inference logic (where your app uses the model) from the user interface, create another class to handle the model inference. Call it a "helper" class.
- Right-click the package name that your MainActivity code is in.
- Select New > Package.
- You'll see a dialog in the center of the screen asking you to enter the package name. Add it at the end of the current package name. (Here, it's called helpers.)
- Once this is done, right-click the helpers folder in project explorer.
- Select New > Java Class, and call it TextClassificationClient. You'll edit the file in the next step.
Your TextClassificationClient helper class will look like this (although your package name may be different):
package com.google.devrel.textclassificationstep1.helpers;

public class TextClassificationClient {
}
- Update the file with this code:
package com.google.devrel.textclassificationstep2.helpers;

import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;

import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;

public class TextClassificationClient {
    private static final String MODEL_PATH = "model.tflite";
    private static final String TAG = "CommentSpam";
    private final Context context;

    NLClassifier classifier;

    public TextClassificationClient(Context context) {
        this.context = context;
    }

    public void load() {
        try {
            classifier = NLClassifier.createFromFile(context, MODEL_PATH);
        } catch (IOException e) {
            Log.e(TAG, e.getMessage());
        }
    }

    public void unload() {
        classifier.close();
        classifier = null;
    }

    public List<Category> classify(String text) {
        List<Category> apiResults = classifier.classify(text);
        return apiResults;
    }
}
This class will provide a wrapper to the TensorFlow Lite interpreter, loading the model and abstracting the complexity of managing the data interchange between your app and the model.
In the load() method, it will instantiate a new NLClassifier type from the model path. The model path is simply the name of the model, model.tflite. The NLClassifier type is part of the text tasks libraries, and it helps you by converting your string into tokens, using the correct sequence length, passing it to the model, and parsing the results.
(For more details on these, revisit Build a comment spam machine learning model.)
The classification is performed in the classify method: you pass it a string, and it returns a List of Category results. When using machine learning models to classify content, where you want to determine if a string is spam or not, it's common for all answers to be returned with assigned probabilities. For example, if you pass it a message that looks like spam, you'll get a list of two answers back: one with the probability that it is spam, and one with the probability that it isn't. Spam and Not Spam are the categories, so the List returned will contain these probabilities. You'll parse that out later.
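If you want to see what comes back, here's a minimal sketch (the message text and log tag are just examples, and it assumes android.util.Log is imported) that logs the label and score of each Category the helper returns:

val results = client.classify("Buy cheap watches now!")
for (category in results) {
    // Each Category carries a label (e.g. "True") and a probability score
    Log.d("CommentSpam", "${category.label}: ${category.score}")
}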
Now that you have the helper class, go back to your MainActivity and update it to use the helper to classify your text. You'll see that in the next step!
6. Classify the Text
In your MainActivity, you'll first want to import the helpers that you just created!
- At the top of MainActivity.kt, along with the other imports, add:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- Next, you'll want to load the helper. In onCreate, immediately after the setContentView line, add these lines to instantiate and load the helper class:
val client = TextClassificationClient(applicationContext)
client.load()
At the moment, your button's onClickListener should look like this:
btnSendText.setOnClickListener {
    val toSend: String = txtInput.text.toString()
    txtOutput.text = toSend
}
- Update it to look like this:
btnSendText.setOnClickListener {
    val toSend: String = txtInput.text.toString()
    val results: List<Category> = client.classify(toSend)
    val score = results[1].score
    if (score > 0.8) {
        txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
    } else {
        txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
    }
    txtInput.text.clear()
}
This changes the functionality from just outputting the user's input, to classifying it first.
- With this line, you'll take the string the user entered and pass it to the model, getting back results:
val results: List<Category> = client.classify(toSend)
There are only two categories, False and True. (TensorFlow sorts them alphabetically, so False will be item 0 and True will be item 1.)
- To get the score for the probability that the value is True, you can look at results[1].score, like this:
val score = results[1].score
- Pick a threshold value (in this case, 0.8). If the score for the True category is above the threshold, then the message is spam. Otherwise, it isn't spam and the message is safe to send:
if (score > 0.8) {
    txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
    txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- See the model in action here. The message "Visit my blog to buy stuff!" was flagged as having a high likelihood of being spam:
And conversely, "Hey, fun tutorial, thanks!" was seen as having a very low likelihood of being spam:
7. Update your iOS App to use the TensorFlow Lite Model
You can get the code for this by following Codelab 1, or by cloning this repo and loading the app from TextClassificationStep1. You can find this in the TextClassificationOnMobile->iOS path.
The finished code is also available for you as TextClassificationStep2.
In the Build a basic messaging style app codelab, you created a very simple app that allowed the user to type a message into a UITextView and have it passed through to an output without any filtering.
Now you'll update that app to use a TensorFlow Lite model to detect comment spam in the text prior to sending. This app just simulates the sending by rendering the text in an output label, but a real app might have a bulletin board, a chat, or something similar.
To get started, you'll need the app from step 1, which you can clone from the repo.
To incorporate TensorFlow Lite, you'll use CocoaPods. If you don't have it installed already, you can do so with the instructions at https://cocoapods.org/.
- Once you have CocoaPods installed, create a file named Podfile in the same directory as the .xcodeproj for the TextClassification app. The contents of this file should look like this:
target 'TextClassificationStep2' do
  use_frameworks!

  # Pods for NLPClassifier
  pod 'TensorFlowLiteSwift'
end
The name of your app should be in the first line, instead of 'TextClassificationStep2'.
Using Terminal, navigate to that directory and run pod install. If it's successful, you'll have a new directory called Pods, and a new .xcworkspace file created for you. You'll use that from now on instead of the .xcodeproj.
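As a sketch, the Terminal session might look like this (the project path is just an example):

cd ~/projects/TextClassificationStep2   # the directory containing your Podfile
pod install
open TextClassificationStep2.xcworkspace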
If it fails, make sure your Podfile is in the same directory as the .xcodeproj. A Podfile in the wrong directory, or the wrong target name, are usually the main culprits!
8. Add the Model and Vocab Files
When you created the model with TensorFlow Lite Model Maker, you were able to output the model (as model.tflite) and the vocab (as vocab.txt).
- Add them to your project by dragging and dropping them from Finder into your project window. Make sure add to targets is checked:
When you're done, you should see them in your project:
- Double-check that they are added to the bundle (so that they get deployed to a device) by selecting your project (in the above screenshot, it's the blue icon TextClassificationStep2), and looking at the Build Phases tab:
9. Load the Vocab
When doing NLP classification, the model is trained with words encoded as vectors. The specific mapping of words to values is learned as the model trains, so different models will have different vocabularies, and it's important that you use the vocab for your model that was generated at training time. This is the vocab.txt file you just added to your app.
You can open the file in Xcode to see the encodings. Words like "song" are encoded to 6 and "love" to 12. The order is actually frequency order, so "I" was the most common word in the dataset, followed by "check."
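Each line of the file is simply a word, a space, and that word's index. As a purely illustrative sketch (apart from "song" and "love" mentioned above, these entries are made up, and your file will depend entirely on your training data), it looks something like this:

i 1
check 2
...
song 6
...
love 12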
When your user types in words, you'll want to encode them with this vocabulary prior to sending them to the model to be classified.
Let's explore that code. Start by loading the vocabulary.
- Define a class level variable to store the dictionary:
var words_dictionary = [String : Int]()
- Then create a func in the class to load the vocab into this dictionary:
func loadVocab(){
    // This func will take the file at vocab.txt and load it into a hash table
    // called words_dictionary. This will be used to tokenize the words before passing them
    // to the model trained by TensorFlow Lite Model Maker
    if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
        do {
            let dictionary_contents = try String(contentsOfFile: filePath)
            let lines = dictionary_contents.split(whereSeparator: \.isNewline)
            for line in lines {
                let tokens = line.components(separatedBy: " ")
                let key = String(tokens[0])
                let value = Int(tokens[1])
                words_dictionary[key] = value
            }
        } catch {
            print("Error vocab could not be loaded")
        }
    } else {
        print("Error -- vocab file not found")
    }
}
- You can run this by calling it from within viewDidLoad:
override func viewDidLoad() {
    super.viewDidLoad()
    txtInput.delegate = self
    loadVocab()
}
10. Turn a string into a sequence of tokens
Your users will type words in as a sentence, which will become a string. Each word in the sentence, if present in the dictionary, will be encoded into its key value as defined in the vocab.
An NLP model typically accepts a fixed sequence length. There are exceptions with models built using ragged tensors, but for the most part you'll see it's fixed. When you created your model, you specified this length. Be sure you use the same length in your iOS app.
The default in the Colab for TensorFlow Lite Model Maker you used earlier was 20, so set that up here too:
let SEQUENCE_LENGTH = 20
Add this func, which will take the string, convert it to lowercase, and strip out any punctuation:
func convert_sentence(sentence: String) -> [Int32] {
    // This func will split a sentence into individual words, while stripping punctuation
    // If the word is present in the dictionary its value from the dictionary will be added to
    // the sequence. Otherwise we'll continue
    // Initialize the sequence to be all 0s, and the length to be determined
    // by the const SEQUENCE_LENGTH. This should be the same length as the
    // sequences that the model was trained for
    var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
    var words: [String] = []
    sentence.enumerateSubstrings(
        in: sentence.startIndex..<sentence.endIndex, options: .byWords) {
        (substring, _, _, _) -> () in words.append(substring!)
    }
    var thisWord = 0
    for word in words {
        if thisWord >= SEQUENCE_LENGTH {
            break
        }
        let seekword = word.lowercased()
        if let val = words_dictionary[seekword] {
            sequence[thisWord] = Int32(val)
            thisWord = thisWord + 1
        }
    }
    return sequence
}
Note that the sequence will be made up of Int32 values. This is deliberate: when it comes to passing values to TensorFlow Lite, you'll be dealing with low-level memory, and TensorFlow Lite treats the integers in a string sequence as 32-bit integers. This will make your life (a little) easier when it comes to passing strings to the model.
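As a quick illustration, here's what calling the function might look like; the token values in the comment are made up, since they depend entirely on your vocab:

let sequence = convert_sentence(sentence: "Hey, fun tutorial, thanks!")
// sequence might now be [87, 412, 3201, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
// Known words map to their vocab indices; the rest of the 20 slots stay 0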
11. Do the Classification
To classify a sentence, it must first be converted into a sequence of tokens based on the words in the sentence. This was done in the previous step.
You'll now take the sentence and pass it to the model, have the model do inference on the sentence, and parse the results.
This will use the TensorFlow Lite interpreter, which you'll need to import:
import TensorFlowLite
Start with a func that takes in your sequence, which is an array of Int32 types:
func classify(sequence: [Int32]) {
    // Model Path is the location of the model in the bundle
    let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
    var interpreter: Interpreter
    do {
        interpreter = try Interpreter(modelPath: modelPath!)
    } catch _ {
        print("Error loading model!")
        return
    }
This will load the model file from the bundle, and invoke an interpreter with it.
The next step is to copy the underlying memory of the sequence into a buffer called myData, so it can be passed to a tensor. The TensorFlow Lite pod gives you access to a Tensor type, as well as the interpreter.
Start the code like this (still in the classify func):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
Don't worry if you get an error on copyingBufferOf. This will be implemented as an extension later.
Now it's time to allocate tensors on the interpreter, copy the data buffer you just created to the input tensor, and then invoke the interpreter to do the inference:
do {
    // Allocate memory for the model's input `Tensor`s.
    try interpreter.allocateTensors()

    // Copy the data to the input `Tensor`.
    try interpreter.copy(myData, toInputAt: 0)

    // Run inference by invoking the `Interpreter`.
    try interpreter.invoke()
Once the invocation is complete, you can look at the output of the interpreter to see the results.
These will be raw values (4 bytes per neuron) which you'll then have to read in and convert. As this particular model has 2 output neurons, you'll need to read in 8 bytes that will be converted into Float32 values for parsing. You are dealing with low-level memory, hence the unsafeData.
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)

// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability that the message is not spam
// Value at index 1 is the probability that the message is spam
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
Now it's relatively easy to parse the data to determine whether the message is spam. The model has 2 outputs: the first is the probability that the message is not spam, and the second is the probability that it is. So you can look at results[1] to find the spam value:
let positiveSpamValue = results[1]
var outputString = ""
if positiveSpamValue > 0.8 {
    outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
    outputString = "Message sent!"
}
txtOutput.text = outputString
For convenience, here's the full method:
func classify(sequence: [Int32]) {
    // Model Path is the location of the model in the bundle
    let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
    var interpreter: Interpreter
    do {
        interpreter = try Interpreter(modelPath: modelPath!)
    } catch _ {
        print("Error loading model!")
        return
    }

    let tSequence = Array(sequence)
    let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
    let outputTensor: Tensor

    do {
        // Allocate memory for the model's input `Tensor`s.
        try interpreter.allocateTensors()

        // Copy the data to the input `Tensor`.
        try interpreter.copy(myData, toInputAt: 0)

        // Run inference by invoking the `Interpreter`.
        try interpreter.invoke()

        // Get the output `Tensor` to process the inference results.
        outputTensor = try interpreter.output(at: 0)

        // Turn the output tensor into an array. This will have 2 values
        // Value at index 0 is the probability that the message is not spam
        // Value at index 1 is the probability that the message is spam
        let resultsArray = outputTensor.data
        let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []

        let positiveSpamValue = results[1]
        var outputString = ""
        if positiveSpamValue > 0.8 {
            outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
        } else {
            outputString = "Message sent!"
        }
        txtOutput.text = outputString
    } catch let error {
        print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
    }
}
12. Add the Swift Extensions
The above code uses an extension to the Data type that allows you to copy the raw bits of an Int32 array into a Data. Here's the code for that extension:
extension Data {
    /// Creates a new buffer by copying the buffer pointer of the given array.
    ///
    /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
    ///   for bit with no indirection or reference-counting operations; otherwise, reinterpreting
    ///   data from the resulting buffer has undefined behavior.
    /// - Parameter array: An array with elements of type `T`.
    init<T>(copyingBufferOf array: [T]) {
        self = array.withUnsafeBufferPointer(Data.init)
    }
}
When dealing with low-level memory, you use "unsafe" data, and the above code needs you to be able to initialize an array from unsafe data. This extension makes that possible:
extension Array {
    /// Creates a new array from the bytes of the given unsafe data.
    ///
    /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
    ///   with no indirection or reference-counting operations; otherwise, copying the raw bytes in
    ///   the `unsafeData`'s buffer to a new array returns an unsafe copy.
    /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
    ///   `MemoryLayout<Element>.stride`.
    /// - Parameter unsafeData: The data containing the bytes to turn into an array.
    init?(unsafeData: Data) {
        guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
        #if swift(>=5.0)
        self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
        #else
        self = unsafeData.withUnsafeBytes {
            .init(UnsafeBufferPointer<Element>(
                start: $0,
                count: unsafeData.count / MemoryLayout<Element>.stride
            ))
        }
        #endif  // swift(>=5.0)
    }
}
13. Run the iOS App
Run and test the app.
If all went well, you should see the app on your device like this:
When the message "Buy my book to learn online trading!" was sent, the app responded with a spam-detected alert showing a probability of 0.99!
14. Congratulations!
You've now created a very simple app that filters text for comment spam, using a model that was trained on comment spam from blogs.
The next step in the typical developer lifecycle is to explore what it would take to customize the model based on data found in your own community. You'll see how to do that in the next pathway activity.