1. Прежде чем начать
В этой лаборатории кода вы обновите приложение, созданное в предыдущей лаборатории кода «Начало работы с мобильной классификацией текста».
Предварительные условия
- Эта лаборатория кода предназначена для опытных разработчиков, плохо знакомых с машинным обучением.
- Codelab является частью секвенированного пути. Если вы еще не завершили создание базового приложения для обмена сообщениями или создание модели машинного обучения для спама в комментариях, остановитесь и сделайте это сейчас.
Что вы будете [построить или изучить]
- Вы узнаете, как интегрировать свою пользовательскую модель в свое приложение, созданное на предыдущих шагах.
Что вам понадобится
- Android Studio или CocoaPods для iOS
2. Откройте существующее приложение для Android.
Вы можете получить код для этого, следуя Codelab 1 или клонировав этот репозиторий и загрузив приложение из TextClassificationStep1
.
git clone https://github.com/googlecodelabs/odml-pathways
Вы можете найти это по пути TextClassificationOnMobile->Android
.
Готовый код также доступен вам как TextClassificationStep2
.
Как только он откроется, вы готовы перейти к шагу 2.
3. Импортируйте файл модели и метаданные.
В лаборатории кода «Создание модели машинного обучения для спама в комментариях» вы создали модель .TFLITE.
Вы должны были загрузить файл модели. Если у вас его нет, вы можете получить его из репозитория для этой кодовой лаборатории, а модель доступна здесь .
Добавьте его в свой проект, создав каталог ресурсов.
- Используя навигатор проекта, убедитесь, что вверху выбран Android .
- Щелкните правой кнопкой мыши папку приложения . Выберите Создать > Каталог.
- В диалоговом окне «Новый каталог» выберите src/main/assets .
Вы увидите, что в приложении теперь доступна новая папка с ресурсами .
- Щелкните правой кнопкой мыши активы.
- В открывшемся меню вы увидите (на Mac) «Показать в Finder» . Выберите его. (В Windows будет написано «Показать в проводнике» , в Ubuntu — « Показать в файлах» .)
Запустится Finder , чтобы показать расположение файлов ( Проводник в Windows, Файлы в Linux).
- Скопируйте файлы
labels.txt
,model.tflite
иvocab
в этот каталог.
- Вернитесь в Android Studio, и вы увидите их в папке ресурсов .
4. Обновите build.gradle, чтобы использовать TensorFlow Lite.
Чтобы использовать TensorFlow Lite и поддерживающие его библиотеки задач TensorFlow Lite, вам необходимо обновить файл build.gradle
.
Проекты Android часто имеют более одного, поэтому обязательно найдите первый уровень приложения . В обозревателе проектов в представлении Android найдите его в разделе «Скрипты Gradle» . Правильный вариант будет помечен расширением .app , как показано здесь:
Вам нужно будет внести два изменения в этот файл. Первый находится в разделе зависимостей внизу. Добавьте текстовую implementation
библиотеки задач TensorFlow Lite, например:
implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0'
Номер версии мог измениться с момента его написания, поэтому обязательно проверьте https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier на наличие последней версии.
Для библиотек задач также требуется минимальная версия SDK 21. Найдите этот параметр в android
> default config
и измените его на 21:
Теперь у вас есть все зависимости, так что пора приступить к кодированию!
5. Добавьте вспомогательный класс
Чтобы отделить логику вывода , в которой ваше приложение использует модель, от пользовательского интерфейса, создайте еще один класс для обработки вывода модели. Назовите это «вспомогательным» классом.
- Щелкните правой кнопкой мыши имя пакета, в котором находится ваш код
MainActivity
. - Выберите Создать > Пакет .
- В центре экрана вы увидите диалоговое окно с просьбой ввести имя пакета. Добавьте его в конец текущего имени пакета. (Здесь это называется помощниками .)
- Как только это будет сделано, щелкните правой кнопкой мыши папку помощников в проводнике проекта.
- Выберите «Создать» > «Класс Java» и назовите его
TextClassificationClient
. Вы отредактируете файл на следующем шаге.
Ваш вспомогательный класс TextClassificationClient
будет выглядеть следующим образом (хотя имя вашего пакета может быть другим).
package com.google.devrel.textclassificationstep1.helpers;
public class TextClassificationClient {
}
- Обновите файл с помощью этого кода:
package com.google.devrel.textclassificationstep2.helpers;
import android.content.Context;
import android.util.Log;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.NLClassifier;
public class TextClassificationClient {
private static final String MODEL_PATH = "model.tflite";
private static final String TAG = "CommentSpam";
private final Context context;
NLClassifier classifier;
public TextClassificationClient(Context context) {
this.context = context;
}
public void load() {
try {
classifier = NLClassifier.createFromFile(context, MODEL_PATH);
} catch (IOException e) {
Log.e(TAG, e.getMessage());
}
}
public void unload() {
classifier.close();
classifier = null;
}
public List<Category> classify(String text) {
List<Category> apiResults = classifier.classify(text);
return apiResults;
}
}
Этот класс предоставит оболочку интерпретатору TensorFlow Lite, загружая модель и абстрагируя сложность управления обменом данными между вашим приложением и моделью.
В методе load()
он создаст экземпляр нового типа NLClassifier
из пути к модели. Путь к модели — это просто имя модели model.tflite
. Тип NLClassifier
является частью библиотеки текстовых задач и помогает вам преобразовывать строку в токены, используя правильную длину последовательности, передавая ее в модель и анализируя результаты.
(Подробнее об этом см. в разделе «Создание модели машинного обучения для спама в комментариях».)
Классификация выполняется в методе classify, где вы передаете ему строку, и он возвращает List
. При использовании моделей машинного обучения для классификации контента, в котором вы хотите определить, является ли строка спамом или нет, все ответы обычно возвращаются с назначенными вероятностями. Например, если вы передадите ему сообщение, похожее на спам, вы получите обратно список из двух ответов; один с вероятностью, что это спам, и один с вероятностью, что это не спам. Спам/Не спам являются категориями, поэтому возвращаемый List
будет содержать эти вероятности. Вы разберёте это позже.
Теперь, когда у вас есть вспомогательный класс, вернитесь к MainActivity
и обновите его, чтобы использовать его для классификации вашего текста. Вы увидите это на следующем шаге!
6. Классифицируйте текст
В вашей MainActivity
сначала вам нужно импортировать только что созданных вами помощников!
- В верхней части
MainActivity.kt
, наряду с остальными импортируемыми данными, добавьте:
import com.google.devrel.textclassificationstep2.helpers.TextClassificationClient
import org.tensorflow.lite.support.label.Category
- Далее вам нужно загрузить помощников. В
onCreate
сразу после строкиsetContentView
добавьте эти строки для создания экземпляра и загрузки вспомогательного класса:
val client = TextClassificationClient(applicationContext)
client.load()
На данный момент onClickListener
вашей кнопки должен выглядеть так:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
txtOutput.text = toSend
}
- Обновите его, чтобы он выглядел так:
btnSendText.setOnClickListener {
var toSend:String = txtInput.text.toString()
var results:List<Category> = client.classify(toSend)
val score = results[1].score
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
txtInput.text.clear()
}
Это меняет функциональность с простого вывода введенных пользователем данных на их предварительную классификацию.
- С помощью этой строки вы возьмете введенную пользователем строку и передадите ее модели, получив результаты:
var results:List<Category> = client.classify(toSend)
Есть только 2 категории: False
и True
. (TensorFlow сортирует их в алфавитном порядке, поэтому False будет элементом 0, а True — элементом 1.)
- Чтобы получить оценку вероятности того, что значение равно
True
, вы можете посмотреть результаты[1].score следующим образом:
val score = results[1].score
- Выбрано пороговое значение (в данном случае 0,8), где вы говорите, что если оценка для категории «Истина» превышает пороговое значение (0,8), то сообщение является спамом. В противном случае это не спам и сообщение можно смело отправлять:
if(score>0.8){
txtOutput.text = "Your message was detected as spam with a score of " + score.toString() + " and not sent!"
} else {
txtOutput.text = "Message sent! \nSpam score was:" + score.toString()
}
- Посмотреть модель в действии можно здесь. Сообщение «Посетите мой блог, чтобы купить что-нибудь!» было помечено как спам с высокой вероятностью:
И наоборот: «Эй, интересный урок, спасибо!» считалось, что вероятность быть спамом очень низкая:
7. Обновите приложение iOS, чтобы использовать модель TensorFlow Lite.
Вы можете получить код для этого, следуя Codelab 1 или клонировав этот репозиторий и загрузив приложение из TextClassificationStep1
. Вы можете найти это по пути TextClassificationOnMobile->iOS
.
Готовый код также доступен вам как TextClassificationStep2
.
В кодовой лаборатории «Создание модели машинного обучения для спама в комментариях» вы создали очень простое приложение, которое позволяло пользователю вводить сообщение в UITextView
и передавать его на выходные данные без какой-либо фильтрации.
Теперь вы обновите это приложение, чтобы использовать модель TensorFlow Lite для обнаружения спама в комментариях в тексте перед отправкой. Просто смоделируйте отправку в этом приложении, визуализируя текст в выходной метке (но в реальном приложении может быть доска объявлений, чат или что-то подобное).
Для начала вам понадобится приложение из шага 1, которое вы можете клонировать из репозитория.
Чтобы включить TensorFlow Lite, вы будете использовать CocoaPods. Если они у вас еще не установлены, вы можете сделать это, следуя инструкциям на https://cocoapods.org/ .
- После установки CocoaPods создайте файл с именем Podfile в том же каталоге, что и
.xcproject
для приложения TextClassification. Содержимое этого файла должно выглядеть следующим образом:
target 'TextClassificationStep2' do
use_frameworks!
# Pods for NLPClassifier
pod 'TensorFlowLiteSwift'
end
Имя вашего приложения должно быть в первой строке вместо «TextClassificationStep2».
Используя терминал, перейдите в этот каталог и запустите pod install
. Если все пройдет успешно, у вас появится новый каталог под названием Pods и новый файл .xcworkspace
. В будущем вы будете использовать его вместо .xcproject
.
Если это не удалось, убедитесь, что у вас есть Podfile в том же каталоге, где находился .xcproject
. Обычно главными виновниками являются подфайл в неправильном каталоге или неправильное целевое имя!
8. Добавьте файлы модели и словаря.
Когда вы создали модель с помощью средства создания моделей TensorFlow Lite, вы смогли вывести модель (как model.tflite
) и словарь (как vocab.txt
).
- Добавьте их в свой проект, перетащив их из Finder в окно проекта. Убедитесь, что флажок «Добавить к целям» установлен:
Когда вы закончите, вы должны увидеть их в своем проекте:
- Дважды проверьте, что они добавлены в пакет (чтобы они были развернуты на устройстве), выбрав свой проект (на снимке экрана выше это синий значок TextClassificationStep2 ) и просмотрев вкладку «Фазы сборки» :
9. Загрузите словарь
При классификации НЛП модель обучается с помощью слов, закодированных в векторы. Модель кодирует слова с определенным набором имен и значений, которые изучаются в процессе обучения модели. Обратите внимание, что большинство моделей будут иметь разные словари, и вам важно использовать словарь для вашей модели, который был создан во время обучения. Это файл vocab.txt
, который вы только что добавили в свое приложение.
Вы можете открыть файл в Xcode, чтобы увидеть кодировки. Такие слова, как «песня», закодированы цифрой 6, а «любовь» — цифрой 12. На самом деле этот порядок соответствует частоте , поэтому «я» было самым распространенным словом в наборе данных, за которым следовало «проверка».
Когда ваш пользователь вводит слова, вы захотите закодировать их с помощью этого словаря, прежде чем отправлять их в модель для классификации.
Давайте исследуем этот код. Начните с загрузки словарного запаса.
- Определите переменную уровня класса для хранения словаря:
var words_dictionary = [String : Int]()
- Затем создайте в классе
func
для загрузки словаря в этот словарь:
func loadVocab(){
// This func will take the file at vocab.txt and load it into a has table
// called words_dictionary. This will be used to tokenize the words before passing them
// to the model trained by TensorFlow Lite Model Maker
if let filePath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
do {
let dictionary_contents = try String(contentsOfFile: filePath)
let lines = dictionary_contents.split(whereSeparator: \.isNewline)
for line in lines{
let tokens = line.components(separatedBy: " ")
let key = String(tokens[0])
let value = Int(tokens[1])
words_dictionary[key] = value
}
} catch {
print("Error vocab could not be loaded")
}
} else {
print("Error -- vocab file not found")
}
}
- Вы можете запустить это, вызвав его из
viewDidLoad
:
override func viewDidLoad() {
super.viewDidLoad()
txtInput.delegate = self
loadVocab()
}
10. Превратите строку в последовательность токенов
Ваши пользователи будут вводить слова в виде предложения, которое станет строкой. Каждое слово в предложении, если оно присутствует в словаре, будет закодировано в ключевое значение слова, определенное в словаре.
Модель НЛП обычно принимает фиксированную длину последовательности. Есть исключения для моделей, построенных с использованием ragged tensors
, но по большей части вы увидите, что это исправлено. Когда вы создавали свою модель, вы указали эту длину. Убедитесь, что вы используете одинаковую длину в своем приложении для iOS.
По умолчанию в Colab для TensorFlow Lite Model Maker, который вы использовали ранее, было 20, поэтому установите его и здесь:
let SEQUENCE_LENGTH = 20
Добавьте эту func
, которая будет принимать строку, преобразовывать ее в нижний регистр и удалять все знаки препинания:
func convert_sentence(sentence: String) -> [Int32]{
// This func will split a sentence into individual words, while stripping punctuation
// If the word is present in the dictionary it's value from the dictionary will be added to
// the sequence. Otherwise we'll continue
// Initialize the sequence to be all 0s, and the length to be determined
// by the const SEQUENCE_LENGTH. This should be the same length as the
// sequences that the model was trained for
var sequence = [Int32](repeating: 0, count: SEQUENCE_LENGTH)
var words : [String] = []
sentence.enumerateSubstrings(
in: sentence.startIndex..<sentence.endIndex,options: .byWords) {
(substring, _, _, _) -> () in words.append(substring!) }
var thisWord = 0
for word in words{
if (thisWord>=SEQUENCE_LENGTH){
break
}
let seekword = word.lowercased()
if let val = words_dictionary[seekword]{
sequence[thisWord]=Int32(val)
thisWord = thisWord + 1
}
}
return sequence
}
Обратите внимание, что последовательность будет Int32. Это выбрано намеренно, поскольку когда дело доходит до передачи значений в TensorFlow Lite, вы будете иметь дело с низкоуровневой памятью, а TensorFlow Lite обрабатывает целые числа в последовательности строк как 32-битные целые числа. Это (немного) облегчит вам жизнь, когда дело дойдет до передачи строк в модель.
11. Проведите классификацию
Чтобы классифицировать предложение, его сначала необходимо преобразовать в последовательность токенов на основе слов в предложении. Это будет сделано на шаге 9.
Теперь вы возьмете предложение и передадите его модели, а модель сделает вывод на основе предложения и проанализирует результаты.
При этом будет использоваться интерпретатор TensorFlow Lite, который вам необходимо импортировать:
import TensorFlowLite
Начните с func
, которая принимает вашу последовательность, которая представляет собой массив типов Int32:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
return
}
Это загрузит файл модели из пакета и вызовет с его помощью интерпретатор.
Следующим шагом будет копирование базовой памяти, хранящейся в последовательности, в буфер под названием myData,
чтобы ее можно было передать тензору. При реализации модуля TensorFlow Lite, а также интерпретатора вы получили доступ к типу Tensor.
Запустите код следующим образом (все еще в func
классификации):
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
Не волнуйтесь, если при copyingBufferOf
вы получите ошибку. Позже это будет реализовано как расширение.
Теперь пришло время выделить тензоры в интерпретаторе, скопировать только что созданный буфер данных во входной тензор, а затем вызвать интерпретатор для выполнения вывода:
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
После завершения вызова вы можете просмотреть вывод интерпретатора, чтобы увидеть результаты.
Это будут необработанные значения (4 байта на нейрон), которые вам затем придется считать и преобразовать. Поскольку эта конкретная модель имеет 2 выходных нейрона, вам нужно будет прочитать 8 байтов, которые будут преобразованы в Float32 для анализа. Вы имеете дело с памятью низкого уровня, отсюда и unsafeData
.
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
Теперь сравнительно легко проанализировать данные, чтобы определить качество спама. Модель имеет 2 выхода: первый с вероятностью того, что сообщение не является спамом, второй с вероятностью того, что это спам. Итак, вы можете посмотреть results[1]
, чтобы определить значение спама:
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " + String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
Для удобства приведем полный метод:
func classify(sequence: [Int32]){
// Model Path is the location of the model in the bundle
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
var interpreter: Interpreter
do{
interpreter = try Interpreter(modelPath: modelPath!)
} catch _{
print("Error loading model!")
Return
}
let tSequence = Array(sequence)
let myData = Data(copyingBufferOf: tSequence.map { Int32($0) })
let outputTensor: Tensor
do {
// Allocate memory for the model's input `Tensor`s.
try interpreter.allocateTensors()
// Copy the data to the input `Tensor`.
try interpreter.copy(myData, toInputAt: 0)
// Run inference by invoking the `Interpreter`.
try interpreter.invoke()
// Get the output `Tensor` to process the inference results.
outputTensor = try interpreter.output(at: 0)
// Turn the output tensor into an array. This will have 2 values
// Value at index 0 is the probability of negative sentiment
// Value at index 1 is the probability of positive sentiment
let resultsArray = outputTensor.data
let results: [Float32] = [Float32](unsafeData: resultsArray) ?? []
let positiveSpamValue = results[1]
var outputString = ""
if(positiveSpamValue>0.8){
outputString = "Message not sent. Spam detected with probability: " +
String(positiveSpamValue)
} else {
outputString = "Message sent!"
}
txtOutput.text = outputString
} catch let error {
print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
}
}
12. Добавьте расширения Swift
В приведенном выше коде используется расширение типа Data, позволяющее копировать необработанные биты массива Int32 в Data
. Вот код этого расширения:
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
При работе с памятью низкого уровня вы используете «небезопасные» данные, и приведенный выше код требует инициализации массива небезопасных данных. Это расширение делает это возможным:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
13. Запустите приложение iOS.
Запустите и протестируйте приложение.
Если все прошло хорошо, вы должны увидеть приложение на своем устройстве следующим образом:
Где сообщение «Купите мою книгу, чтобы научиться онлайн-трейдингу!» было отправлено, приложение отправляет обратно предупреждение об обнаружении спама с вероятностью 0,99%!
14. Поздравляем!
Теперь вы создали очень простое приложение, которое фильтрует текст на наличие спама в комментариях, используя модель, обученную на данных, используемых для рассылки спама в блогах.
Следующим шагом типичного жизненного цикла разработчика является изучение того, что потребуется для настройки модели на основе данных, найденных в вашем собственном сообществе. Вы увидите, как это сделать, в следующем упражнении.