Add on-device Text Classification to your app with TensorFlow Lite and Firebase - iOS Codelab

Welcome to the Text Classification with TensorFlow Lite and Firebase codelab. In this codelab you'll learn how to use TensorFlow Lite and Firebase to train and deploy a text classification model to your app. This codelab is based on this TensorFlow Lite example.

Text classification is the process of assigning tags or categories to text according to its content. It's one of the fundamental tasks in Natural Language Processing (NLP) with broad applications such as sentiment analysis, topic labeling, spam detection, and intent detection.

Sentiment analysis is the interpretation and classification of emotions (positive, negative and neutral) within text data using text analysis techniques. Sentiment analysis allows businesses to identify customer sentiment toward products, brands or services in online conversations and feedback.

This tutorial shows how to build a machine learning model for sentiment analysis, in particular classifying text as positive or negative. This is an example of binary—or two-class—classification, an important and widely applicable kind of machine learning problem.

What you'll learn

  • Train a TF Lite sentiment analysis model with TF Lite Model Maker
  • Deploy TF Lite models to Firebase ML and access them from your app
  • Track user feedback to measure model accuracy with Firebase Analytics
  • Profile model performance via Firebase Performance Monitoring
  • Select which one of multiple deployed models is loaded through Remote Config
  • Experiment with different models via Firebase A/B Testing

What you'll need

  • Xcode 11 (or higher)
  • CocoaPods 1.9.1 (or higher)

Add Firebase to the project

  1. Go to the Firebase console.
  2. Select Create New Project and name your project "Firebase ML iOS Codelab".

Download the Code

Begin by cloning the sample project and running pod install --repo-update in the project directory:

git clone https://github.com/FirebaseExtended/codelab-textclassification-ios.git
cd codelab-textclassification-ios
pod install --repo-update

If you don't have git installed, you can also download the sample project from its GitHub page or by clicking on this link. Once you've downloaded the project, run it in Xcode and play around with the text classification to get a feel for how it works.

Set up Firebase

Follow the documentation to create a new Firebase project. Once you've got your project, download your project's GoogleService-Info.plist file from Firebase console and drag it to the root of the Xcode project.

Add Firebase to your Podfile and run pod install.

pod 'Firebase/MLCommon'
pod 'FirebaseMLModelInterpreter', '0.20.0'

At the top of AppDelegate.swift, import Firebase:

import Firebase

Then add a call to configure Firebase in the application(_:didFinishLaunchingWithOptions:) method:

FirebaseApp.configure()

Run the project again to make sure the app is configured correctly and does not crash on launch.

We will use TensorFlow Lite Model Maker to train a text classification model to predict sentiment of a given text.

This step is presented as a Python notebook that you can open in Google Colab.

After finishing this step, you will have a TensorFlow Lite sentiment analysis model that is ready for deployment to a mobile app.

Deploying a model to Firebase ML is useful for two main reasons:

  1. We can keep the app install size small and only download the model if needed
  2. The model can be updated regularly and with a different release cycle than the entire app

The model can be deployed either via the console, or programmatically, using the Firebase Admin SDK. In this step we will deploy via the console.

First, open the Firebase console and click Machine Learning in the left navigation panel. If this is your first time opening it, click "Get Started". Then navigate to "Custom" and click the "Add model" button.

When prompted, name the model sentiment_analysis and upload the file that you downloaded from Colab in the previous step.

Choosing when to download the remote model from Firebase into your app can be tricky because TFLite models can grow relatively large. Ideally, we want to avoid downloading the model as soon as the app launches: if the model is used for only one feature and the user never uses that feature, we'll have downloaded a significant amount of data for no reason. We can also set download options, such as only fetching models when connected to Wi-Fi. If you want to ensure that the model is available even without a network connection, it's important to also bundle it with the app as a backup.
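If you do keep a bundled copy as a backup, the choice between the two files might look like the sketch below. The ModelSource type and the resolveModelSource function are hypothetical names used only for illustration; they are not part of the sample project or of any Firebase SDK. The file check is injectable so the logic can be tried out without real model files.

```swift
import Foundation

/// Where the model file used for inference came from (hypothetical type).
enum ModelSource: Equatable {
  case downloaded(path: String)
  case bundled(path: String)
}

/// Prefers the model downloaded from Firebase and falls back to the copy
/// shipped inside the app bundle. Returns nil when neither file is available.
func resolveModelSource(
  downloadedPath: String?,
  bundledPath: String?,
  fileExists: (String) -> Bool = { FileManager.default.fileExists(atPath: $0) }
) -> ModelSource? {
  if let path = downloadedPath, fileExists(path) {
    return .downloaded(path: path)
  }
  if let path = bundledPath, fileExists(path) {
    return .bundled(path: path)
  }
  return nil
}
```

In an app, bundledPath would typically come from Bundle.main.path(forResource:ofType:), and downloadedPath from the model download step.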

For the sake of simplicity, we'll remove the default bundled model and always download a model from Firebase when the app starts for the first time. This way when running sentiment analysis you can be sure that the inference is running with the model provided from Firebase.

At the top of ModelDownloader.swift, import the Firebase module.

import Firebase

Then implement the following methods.

static func downloadModel(named name: String,
                          completion: @escaping (RemoteModel?, DownloadError?) -> Void) {
  guard FirebaseApp.app() != nil else {
    completion(nil, .firebaseNotInitialized)
    return
  }
  guard success == nil && failure == nil else {
    completion(nil, .downloadInProgress)
    return
  }

  let remoteModel = CustomRemoteModel(name: name)
  let conditions = ModelDownloadConditions(allowsCellularAccess: true,
                                           allowsBackgroundDownloading: true)

  success = NotificationCenter.default.addObserver(forName: .firebaseMLModelDownloadDidSucceed,
                                                   object: nil,
                                                   queue: nil) { (notification) in
    defer { success = nil; failure = nil }
    guard let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue] as? RemoteModel
    else {
      completion(nil, .downloadReturnedEmptyModel)
      return
    }
    guard model.name == name else {
      completion(nil, .downloadReturnedWrongModel)
      return
    }
    completion(model, nil)
  }
  failure = NotificationCenter.default.addObserver(forName: .firebaseMLModelDownloadDidFail,
                                                   object: nil,
                                                   queue: nil) { (notification) in
    defer { success = nil; failure = nil }
    guard let userInfo = notification.userInfo,
        let error = userInfo[ModelDownloadUserInfoKey.error.rawValue] as? Error
    else {
      completion(nil, .mlkitError(underlyingError: DownloadError.unknownError))
      return
    }
    completion(nil, .mlkitError(underlyingError: error))
  }
  ModelManager.modelManager().download(remoteModel, conditions: conditions)
}

// Attempts to fetch the model from disk, downloading the model if it does not already exist.
static func fetchModel(named name: String,
                       completion: @escaping (String?, DownloadError?) -> Void) {
  let remoteModel = CustomRemoteModel(name: name)
  if ModelManager.modelManager().isModelDownloaded(remoteModel) {
    ModelManager.modelManager().getLatestModelFilePath(remoteModel) { (path, error) in
      completion(path, error.map { DownloadError.mlkitError(underlyingError: $0) })
    }
  } else {
    downloadModel(named: name) { (model, error) in
      guard let model = model else {
        let underlyingError = error ?? DownloadError.unknownError
        let compositeError = DownloadError.mlkitError(underlyingError: underlyingError)
        completion(nil, compositeError)
        return
      }
      ModelManager.modelManager().getLatestModelFilePath(model) { (path, pathError) in
        completion(path, pathError.map { DownloadError.mlkitError(underlyingError: $0) })
      }
    }
  }
}
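One detail worth knowing about the block-based observers used in downloadModel: NotificationCenter keeps them registered until they are passed to removeObserver(_:), so setting the stored token to nil does not by itself stop later notifications from reaching the closure. The sketch below shows one way to build an observer that removes itself after its first delivery. The observeOnce function and ObserverToken class are hypothetical helpers, not part of the sample project, and the sketch assumes the Swift 5 language mode.

```swift
import Foundation

/// Holds the observer token so the handler closure can remove its own observer.
/// Access is not synchronized; this is a sketch, not production code.
final class ObserverToken: @unchecked Sendable {
  var value: NSObjectProtocol?
}

/// Registers a block-based observer that unregisters itself the first time
/// the notification is delivered, so the handler runs at most once.
func observeOnce(
  _ name: Notification.Name,
  center: NotificationCenter = .default,
  handler: @escaping @Sendable (Notification) -> Void
) {
  let token = ObserverToken()
  token.value = center.addObserver(forName: name, object: nil, queue: nil) { notification in
    // Unregister before calling the handler so a second post is ignored.
    if let observer = token.value {
      center.removeObserver(observer)
      token.value = nil
    }
    handler(notification)
  }
}
```

The same idea could be applied to the success and failure observers above by calling NotificationCenter.default.removeObserver(_:) on both tokens before clearing them.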

In ViewController.swift's viewDidLoad, replace the call to loadModel() with our new model download method.

// Download the model from Firebase
print("Fetching model...")
ModelDownloader.fetchModel(named: "sentiment_analysis") { (filePath, error) in
  guard let path = filePath else {
    if let error = error {
      print(error)
    }
    return
  }
  print("Model download complete")

  // TODO: Initialize an NLClassifier from the downloaded model
}

Re-run your app. After a few seconds, you should see a log in Xcode indicating the remote model has successfully downloaded. Try typing some text and confirm the behavior of the app has not changed.

TensorFlow Lite Task Library helps you integrate TensorFlow Lite models into your app with just a few lines of code. We will initialize a TFLNLClassifier instance using the TensorFlow Lite model downloaded from Firebase. Then we will use it to classify the text that users enter in the app and show the result in the UI.

Add the dependency

Go to the app's Podfile and add TensorFlow Lite Task Library (Text) in the app's dependencies. Make sure you add the dependency under the target 'TextClassification' declaration.

pod 'TensorFlowLiteTaskText', '~> 0.0.1-nightly'

Run pod install to install the new dependency.

Initialize a text classifier

Then we will load the sentiment analysis model downloaded from Firebase using the Task Library's NLClassifier.

ViewController.swift

Let's declare a TFLNLClassifier instance variable. At the top of the file, import the new dependency:

import TensorFlowLiteTaskText

Find this comment above the method we modified in the last step:

// TODO: Add an TFLNLClassifier property.

Replace the TODO with the following code:

private var classifier: TFLNLClassifier?

Initialize the classifier variable with the sentiment analysis model downloaded from Firebase. Find this comment we added in the last step:

// TODO: Initialize an NLClassifier from the downloaded model

Replace the TODO with the following code:

let options = TFLNLClassifierOptions()
classifier = TFLNLClassifier.nlClassifier(modelPath: path, options: options)

Classify text

Once the classifier instance has been set up, you can run sentiment analysis with a single method call.

ViewController.swift

In the classify(text:) method, find the TODO comment:

// TODO: Run sentiment analysis on the input text

Replace the comment with the following code:

guard let classifier = self.classifier else { return }

// Classify the text
let classifierResults = classifier.classify(text: text)

// Append the results to the list of results
let result = ClassificationResult(text: text, results: classifierResults)
results.append(result)
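To turn the raw classifier output into a single prediction, you can pick the category with the highest score. The sketch below assumes the output is a dictionary from category name to score, typed as [String: NSNumber]. The topCategory function and the "Positive" and "Negative" labels are illustrative assumptions; the actual labels depend on how the model was trained.

```swift
import Foundation

/// Returns the label with the highest score, or nil when the result is empty.
func topCategory(in scores: [String: NSNumber]) -> (label: String, score: Float)? {
  guard let best = scores.max(by: { $0.value.floatValue < $1.value.floatValue }) else {
    return nil
  }
  return (best.key, best.value.floatValue)
}

// Example scores with made-up values, in the shape assumed above.
let sampleScores: [String: NSNumber] = ["Positive": 0.83, "Negative": 0.17]
if let top = topCategory(in: sampleScores) {
  print("Predicted sentiment: \(top.label), score: \(top.score)")
}
```

For a binary model like this one, the two scores are complementary, so comparing one score against a threshold such as 0.5 would give the same prediction.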

You have integrated the sentiment analysis model into the app, so let's test it. Connect your iOS device and click Run in the Xcode toolbar.

The app should be able to correctly predict the sentiment of the movie review that you enter.

Besides hosting your TFLite models, Firebase provides several other features to power up your machine learning use cases:

  • Firebase Performance Monitoring to measure your model's inference speed on users' devices.
  • Firebase Analytics to measure how well your model performs in production by tracking user reactions.
  • Firebase A/B Testing to test multiple versions of your model. Remember that we trained two versions of our TFLite model earlier? A/B testing is a good way to find out which version performs better in production!
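Before wiring up Performance Monitoring, it can help to see what "inference speed" means in code. The sketch below times an arbitrary piece of work with a monotonic clock. The timed function is a hypothetical helper that uses only Foundation; it does not call any Firebase API, and reporting the measured value to Firebase is left out.

```swift
import Foundation

/// Runs `work` and returns its result together with the elapsed wall time
/// in milliseconds, measured with the system uptime clock.
func timed<T>(_ work: () throws -> T) rethrows -> (value: T, milliseconds: Double) {
  let start = ProcessInfo.processInfo.systemUptime
  let value = try work()
  let elapsed = ProcessInfo.processInfo.systemUptime - start
  return (value, elapsed * 1000)
}

// Stand-in workload; in the app this would wrap the call that classifies text.
let measurement = timed { (1...1000).reduce(0, +) }
print("Result \(measurement.value) took \(measurement.milliseconds) ms")
```

Averaging such measurements over many inputs gives a more stable figure than a single run, since the first inference often includes one-time setup cost.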

To learn more about how to leverage these features in your app, check out the codelabs below:

In this codelab, you learned how to train a sentiment analysis TFLite model and deploy it to your mobile app using Firebase. To learn more about TFLite and Firebase, take a look at other TFLite samples and the Firebase getting started guides.

What we've covered

  • TensorFlow Lite
  • Firebase ML

Next Steps

  • Measure your model inference speed with Firebase Performance Monitoring.
  • Deploy the model from Colab directly to Firebase via the Firebase ML Model Management API.
  • Add a mechanism that lets users give feedback on the prediction result, and use Firebase Analytics to track that feedback.
  • A/B test the Average Word Vector model and the MobileBERT model with Firebase A/B Testing.
