TensorFlow.js: Stwórz własną maszynę "Teachable Machine" korzystanie z nauczania transferu z użyciem TensorFlow.js

1. Zanim zaczniesz

W ciągu ostatnich kilku lat wykorzystanie modelu TensorFlow.js zyskuje na popularności, a wielu programistów JavaScriptu poszukuje teraz możliwości wykorzystania dotychczasowych, najnowocześniejszych modeli i trenowania ich do pracy z danymi niestandardowymi, które są unikalne dla ich branży. Wykorzystanie istniejącego modelu (często nazywanego modelem podstawowym) i wykorzystanie go w podobnej, ale innej domenie, to tzw. uczenie się przenoszenia.

Przenoszenie nauki ma wiele zalet w porównaniu z rozpoczynaniem od całkowicie pustego modelu. Możesz wykorzystać wiedzę zdobytą podczas wcześniejszego wytrenowanego modelu i wymagać mniejszej liczby przykładów nowego elementu, który chcesz sklasyfikować. Trenowanie jest też często znacznie szybsze ze względu na konieczność ponownego wytrenowania tylko kilku ostatnich warstw architektury modelu, a nie całej sieci. Z tego powodu uczenie się transferów sprawdza się bardzo dobrze w środowisku przeglądarki, w którym zasoby mogą się różnić w zależności od urządzenia, na którym są wykonywane, ale mają też bezpośredni dostęp do czujników, by łatwo zbierać dane.

Dzięki temu ćwiczeniu w Codelabs dowiesz się, jak stworzyć aplikację internetową, tworząc aplikację internetową, Teachable Machine” witryny. Umożliwia stworzenie funkcjonalnej aplikacji internetowej, za pomocą której każdy użytkownik może rozpoznać obiekt niestandardowy na podstawie zaledwie kilku przykładowych obrazów z kamery internetowej. Zawartość witryny jest ograniczona do minimum, aby można było skupić się na aspektach uczenia maszynowego określonych w tym ćwiczeniu z programowania. Podobnie jak w przypadku oryginalnej witryny Teachable Machine, masz jednak mnóstwo możliwości, by wykorzystać istniejące doświadczenie z programistami stron internetowych, aby poprawić UX.

Wymagania wstępne

To ćwiczenie w Codelabs jest przeznaczone dla programistów stron internetowych, którzy znają wstępnie utworzone modele TensorFlow.js i podstawowe korzystanie z interfejsu API, a także chcą zacząć korzystać z nauki przenoszenia w TensorFlow.js.

  • W tym module zakładamy podstawową znajomość języków TensorFlow.js, HTML5, CSS i JavaScript.

Jeśli nie masz doświadczenia z Tensflow.js, rozważ skorzystanie z bezpłatnego szkolenia „zero do głównych”, które zakłada, że nie ma żadnej wiedzy na temat uczenia maszynowego ani TensorFlow.js. W mniejszej kolejności przedstawi wszystko, co trzeba wiedzieć.

Czego się nauczysz

  • Czym jest TensorFlow.js i dlaczego warto go użyć w swojej następnej aplikacji internetowej.
  • Jak utworzyć uproszczoną stronę HTML/CSS /JS, która powiela interfejs użytkownika Teachable Machine.
  • Jak za pomocą TensorFlow.js wczytać już wytrenowany model podstawowy, a w szczególności MobileNet, do wygenerowania funkcji obrazów, które można wykorzystać w nauczaniu transferowym.
  • Jak zebrać z kamery internetowej użytkownika dane z różnych klas, które chcesz rozpoznać.
  • Jak utworzyć i zdefiniować wielowarstwowy perceptron, który wykorzystuje cechy obrazu i uczy się klasyfikować nowe obiekty przy jego użyciu.

Czas na hakowanie...

Czego potrzebujesz

  • Preferowane jest konto Glitch.com. Możesz też skorzystać z środowiska internetowego, w którym możesz swobodnie edytować i uruchamiać aplikację.

2. Co to jest TensorFlow.js?

54e81d02971f53e8.png

TensorFlow.js to biblioteka systemów uczących się typu open source, która umożliwia uruchamianie JavaScriptu wszędzie. Opiera się on na pierwotnej bibliotece TensorFlow napisanej w języku Python i ma na celu odtworzenie tego środowiska programistycznego oraz zestawu interfejsów API w ekosystemie JavaScriptu.

Gdzie można korzystać z tej funkcji?

JavaScript jest łatwy w obsłudze, dlatego możesz teraz z łatwością pisać w jednym języku i wykonywać uczenie maszynowe na wszystkich poniższych platformach:

  • Po stronie klienta w przeglądarce używamy wbudowanego JavaScriptu
  • Po stronie serwera, a nawet urządzenia IoT, takie jak Raspberry Pi, korzystające z Node.js
  • Aplikacje komputerowe wykorzystujące technologię Electron
  • natywne aplikacje mobilne wykorzystujące komponent React Native,

TensorFlow.js obsługuje też wiele backendów w każdym z tych środowisk (rzeczywistych środowisk sprzętowych, w których można go uruchamiać, np. CPU lub WebGL. „backend” w tym kontekście nie oznacza to środowiska po stronie serwera – backend do wykonania może być na przykład po stronie klienta w WebGL), żeby zapewnić zgodność i jednocześnie zapewnić szybkie działanie. Obecnie TensorFlow.js obsługuje:

  • Wykonanie WebGL na karcie graficznej urządzenia (GPU) – to najszybszy sposób na uruchamianie większych modeli (o rozmiarach powyżej 3 MB) z akceleracją GPU.
  • Wykonywanie narzędzi Web Assembly (WASM) na CPU – w celu poprawy wydajności procesora na różnych urządzeniach, na przykład na telefonach komórkowych starszej generacji. Sprawdza się to lepiej w mniejszych modelach (poniżej 3 MB), które w WASM działają szybciej na procesorach niż w WebGL ze względu na wymagania związane z przesyłaniem materiałów do procesora graficznego.
  • Wykonanie procesora – środowisko zastępcze nie powinno być dostępne. To najwolniejszy z trzech, ale zawsze gotowy.

Uwaga: możesz wymusić stosowanie jednego z tych backendów, jeśli wiesz, na którym urządzeniu będzie wykonywane działanie. Jeśli nie określisz tego, możesz po prostu pozwolić TensorFlow.js zdecydować za Ciebie.

Supermocy po stronie klienta

Uruchomienie kodu TensorFlow.js w przeglądarce na komputerze klienckim może przynieść szereg korzyści, które warto rozważyć.

Prywatność

Możesz trenować i klasyfikować dane na komputerze klienckim bez konieczności wysyłania ich na serwer internetowy firmy zewnętrznej. W niektórych przypadkach może to być wymagane w celu zachowania zgodności z przepisami obowiązującymi w danym kraju, np. RODO, lub w przypadku przetwarzania danych, które użytkownik chce zachować na swoim komputerze, a nie wysyłać do innych firm.

Szybkość

Dzięki temu, że nie trzeba wysyłać danych na serwer zdalny, wnioskowanie (czynność klasyfikowania danych) może być szybsze. Co więcej, po przyznaniu dostępu przez użytkownika będziesz mieć bezpośredni dostęp do czujników urządzenia, takich jak aparat, mikrofon, GPS, akcelerometr i inne.

Zasięg i skala

Wystarczy jedno kliknięcie, aby każdy na całym świecie mógł kliknąć wysłany przez Ciebie link, otworzyć stronę internetową w przeglądarce i wykorzystać to, co udało Ci się osiągnąć. Użycie systemu uczącego się nie wymaga skomplikowanej konfiguracji systemu Linux po stronie serwera ze sterownikami CUDA i nie tylko.

Koszt

Brak serwerów oznacza, że jedyną rzeczą, za którą musisz zapłacić, jest sieć CDN do przechowywania plików HTML, CSS, JS i modeli. Koszt sieci CDN jest znacznie wyższy niż w przypadku serwera (potencjalnie z podłączoną kartą graficzną) działającym przez całą dobę.

Funkcje po stronie serwera

Poniższe funkcje są dostępne dzięki wdrożeniu TensorFlow.js w Node.js.

Pełna obsługa CUDA

Aby włączyć akcelerację karty graficznej, po stronie serwera musisz zainstalować sterowniki NVIDIA CUDA, aby umożliwić TensorFlow współpracę z kartą graficzną (inaczej niż w przeglądarce, która używa WebGL – nie trzeba instalować). Pełna obsługa technologii CUDA pozwala jednak w pełni wykorzystać możliwości karty graficznej niższego poziomu, co pozwala skrócić czas trenowania i wnioskowania. Wydajność jest porównywalna z implementacją TensorFlow w języku Python, ponieważ obie korzystają z tego samego backendu w C++.

Rozmiar modelu

Aby tworzyć najbardziej zaawansowane modele na podstawie badań, możesz pracować z bardzo dużymi modelami, nawet o wielkości gigabajtów. Tych modeli nie można obecnie uruchamiać w przeglądarce ze względu na ograniczenia wykorzystania pamięci przez poszczególne karty przeglądarki. Aby móc uruchamiać te większe modele, możesz użyć środowiska Node.js na własnym serwerze ze specyfikacjami sprzętowymi, które są wymagane do wydajnego działania takiego modelu.

IOT

Node.js jest obsługiwany na popularnych komputerach jednopłytkowych, takich jak Raspberry Pi, co z kolei oznacza, że na takich urządzeniach możesz uruchamiać modele TensorFlow.js.

Szybkość

Node.js jest napisany w języku JavaScript, co oznacza, że korzysta z kompilacji w odpowiednim momencie. Oznacza to, że gdy korzystasz z Node.js, możesz często zauważyć wzrost wydajności, ponieważ jest on optymalizowany w czasie działania, a zwłaszcza w przypadku wstępnego przetwarzania danych. Świetnym przykładem tego jest to studium przypadku, które pokazuje, jak firma Hugging Face wykorzystała środowisko Node.js, aby dwukrotnie zwiększyć wydajność modelu przetwarzania języka naturalnego.

Znasz już podstawy środowiska TensorFlow.js, wiesz, gdzie może być uruchamiany i jakie są jego zalety, więc możesz teraz zacząć robić z nim przydatne rzeczy.

3. Przenieś naukę

Czym dokładnie jest transfer learning?

Przekazywanie wiedzy oznacza wykorzystanie zdobytej wiedzy do nauczenia się czegoś innego, ale czegoś podobnego.

Ludzie ciągle to robią. W głowie kryje się wiele niezliczonych doświadczeń, które mogą pomóc Ci rozpoznać nowe rzeczy. Weźmy na przykład tę wierzbę:

e28070392cd4afb9.png

W zależności od Twojego położenia na świecie możesz jeszcze nie widzieć tego rodzaju drzewa.

Jeśli jednak zapytam, czy na nowym zdjęciu poniżej są jakieś wierzby, pewnie można je szybko rozpoznać, mimo że są pod innym kątem i nieco inne niż te, które pokazałem.

d9073a0d5df27222.png

Masz już w mózgu cały szereg neuronów, które potrafią rozpoznawać obiekty przypominające drzewa, oraz inne neurony, które radzą sobie dobrze w znajdowaniu długich linii prostych. Możesz ponownie wykorzystać tę wiedzę do szybkiego sklasyfikowania wierzby, czyli przypominającego drzewa obiektu z wieloma długimi, prostymi gałęziami.

Podobnie, jeśli masz model systemów uczących się, który został już wytrenowany w domenie (np. przez rozpoznawanie obrazów), możesz go użyć do wykonania innego, ale powiązanego zadania.

To samo można zrobić przy użyciu zaawansowanego modelu, takiego jak MobileNet, czyli bardzo popularny model badawczy, który potrafi rozpoznawać obraz na 1000 różnych typów obiektów. Trenowano na ogromnym zbiorze danych o nazwie ImageNet, który zawiera miliony obrazów oznaczonych etykietami – od psów po samochody.

Na tej animacji widać ogromną liczbę warstw tego modelu MobileNet V1:

7d4e1e35c1a89715.gif

Podczas trenowania ten model nauczył się, jak wyodrębniać typowe cechy, które mają znaczenie dla wszystkich tych 1000 obiektów, a wiele z funkcji niższych poziomów używanych do identyfikowania takich obiektów może być przydatnych do wykrywania nowych obiektów, których wcześniej nie widział. W końcu wszystko jest w ostatecznym rozrachunku tylko kombinacją linii, tekstur i kształtów.

Przyjrzyjmy się tradycyjnej architekturze konwolucyjnej sieci neuronowej (CNN) (podobnej do MobileNet) i zobaczmy, w jaki sposób uczenie się transferowe może wykorzystać tę wyszkoloną sieć do nauczenia się czegoś nowego. Poniższy obraz przedstawia typową architekturę modelu CNN, która w tym przypadku została wytrenowana do rozpoznawania odręcznych cyfr z zakresu od 0 do 9:

baf4e3d434576106.png

Gdyby można było oddzielić już wytrenowane warstwy niższego poziomu istniejącego wytrenowanego modelu w ten sposób, jak pokazano po lewej stronie, od warstw klasyfikacji widocznych po prawej stronie (nazywanej też warstwą klasyfikacji modelu), można użyć warstw niższego poziomu do wygenerowania cech wyjściowych dowolnego obrazu na podstawie pierwotnych danych, na których został wytrenowany. Oto ta sama sieć z usuniętym nagłówkiem klasyfikacji:

369a8a9041c6917d.png

Zakładając, że nowa funkcja, którą próbuje rozpoznać, może również wykorzystać funkcje wyjściowe omówione przez poprzedni model, jest spora szansa, że będą mogły zostać użyte ponownie do nowych celów.

Na powyższym diagramie ten hipotetyczny model został wytrenowany na cyfrach, więc być może informacje o cyfrach można zastosować również do liter, np. a, b i c.

Teraz możesz dodać nowy nagłówek klasyfikacji, który próbuje przewidzieć a, b lub c, jak w poniższym przykładzie:

db97e5e60ae73bbd.png

W tym przypadku warstwy niższego poziomu są zablokowane i nie zostały wytrenowane. Tylko nowa główka klasyfikacji zaktualizuje się, aby uczyć się na podstawie funkcji udostępnianych z wytrenowanego wstępnie wytrenowanego modelu po lewej stronie.

Działanie tego systemu określa się jako uczenie się transferowe, a Teachable Machine działa w tle.

Widać też, że trenowanie wielowarstwowego punktu widzenia tylko na samym końcu sieci sprawia, że trenuje się znacznie szybciej niż przy trenowaniu całej sieci od zera.

Jak możesz jednak trafić na poszczególne części modelu? Aby dowiedzieć się więcej, przejdź do następnej sekcji.

4. TensorFlow Hub – modele podstawowe

Znajdź odpowiedni model podstawowy do użycia

W przypadku bardziej zaawansowanych i popularnych modeli badawczych, takich jak MobileNet, możesz przejść do centrum TensorFlow, a następnie odfiltrować modele odpowiednie dla TensorFlow.js, które wykorzystują architekturę MobileNet v3, aby znaleźć wyniki takie jak te poniżej:

c5dc1420c6238c14.png

Pamiętaj, że niektóre z tych wyników są typu „klasyfikacja obrazów” (szczegóły w lewym górnym rogu każdego wyniku karty modelu), a inne są typu „wektor cech obrazu”.

Wyniki wektorów cech obrazu to zasadniczo wstępnie rozdzielone wersje MobileNet, których można użyć do pobrania wektorów cech obrazu zamiast ostatecznej klasyfikacji.

Tego typu modele są często nazywane „modelami podstawowymi”, Możesz go później wykorzystać do uczenia się przenoszenia w sposób opisany w poprzedniej sekcji, dodając nowy nagłówek klasyfikacji i trenując go z wykorzystaniem własnych danych.

Kolejnym krokiem jest sprawdzenie, w jakim formacie TensorFlow.js jest udostępniany model podstawowy. Jeśli otworzysz stronę jednego z modeli MobileNet v3 korzystających z wektorów funkcji, zobaczysz w dokumentacji JS, że ma ona postać modelu grafowego opartego na przykładowym fragmencie kodu w dokumentacji, w której użyto tf.loadGraphModel().

f97d903d2e46924b.png

Pamiętaj też, że jeśli znajdziesz model w formacie warstw, a nie wykresu, możesz wybrać, które warstwy zostaną zablokowane, a które nie będą zablokowane podczas trenowania. Może to być bardzo przydatne podczas tworzenia modelu dla nowego zadania, często nazywanego „modelem przeniesienia”. Na razie jednak w tym samouczku będziesz używać domyślnego typu modelu grafu, w ramach którego wdrażana jest większość modeli centrum TF. Aby dowiedzieć się więcej o pracy z modelami warstw, zapoznaj się z kursem „zero do bohatera” TensorFlow.js.

Zalety systemów uczących się

Jakie są zalety korzystania z systemu uczącego się zamiast trenowania całej architektury modelu od zera?

Po pierwsze, kluczową zaletą zastosowania metody uczenia się transferu jest czas trenowania, ponieważ masz już wytrenowany model podstawowy.

Po drugie, możesz dzięki temu widzieć znacznie mniej przykładów nowych elementów, które próbujesz sklasyfikować ze względu na już przeprowadzone szkolenie.

To świetne rozwiązanie, gdy masz ograniczony czas i zasoby na zebranie przykładowych danych, które chcesz sklasyfikować, i potrzebujesz szybkiego stworzenia prototypu, zanim zbierzesz więcej danych do trenowania, aby ulepszyć jego działanie.

Biorąc pod uwagę zapotrzebowanie na mniej danych i szybkość trenowania mniejszej sieci, uczenie się transferów nie intensywnie korzysta z zasobów. Dzięki temu doskonale nadaje się do działania w środowisku przeglądarek, a pełne trenowanie modelu zajmuje zaledwie kilkadziesiąt sekund na nowoczesnym komputerze, a nie godziny, dni czy tygodnie.

Świetnie! Wiesz już, czym jest Transfer Learning. Czas utworzyć własną wersję Teachable Machine. Zaczynajmy!

5. Skonfiguruj kod

Czego potrzebujesz

  • Nowoczesna przeglądarka.
  • Podstawowa znajomość języka HTML, CSS i JavaScript oraz Narzędzi deweloperskich w Chrome (wyświetlanie danych wyjściowych konsoli).

Rozpocznij kodowanie

Dla Glitch.com lub Codepen.io utworzono standardowe szablony, od których można zacząć. Wystarczy jedno kliknięcie, aby skopiować dowolny szablon jako stan podstawowy w tym module kodu.

W przypadku glitcha kliknij przycisk „Remiksuj to”, aby utworzyć rozwidlenie i utworzyć nowy zestaw plików, które możesz edytować.

Możesz też w Codepen kliknąć fork" w prawym dolnym rogu ekranu.

Ten prosty szkielet zawiera następujące pliki:

  • Strona HTML (index.html)
  • Arkusz stylów (style.css)
  • Plik do pisania kodu JavaScript (script.js)

Dla Twojej wygody do biblioteki TensorFlow.js dodaliśmy do pliku HTML importowany plik. Wygląda on następująco:

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

Alternatywnie: użyj preferowanego edytora stron internetowych lub pracuj lokalnie

Jeśli chcesz pobrać kod i pracować lokalnie lub w innym edytorze online, po prostu utwórz 3 podane wyżej pliki w tym samym katalogu, a następnie skopiuj i wklej do każdego z nich kod z naszego błędu.

6. Szablon HTML aplikacji

Od czego zacząć?

Wszystkie prototypy wymagają podstawowej struktury HTML, na której możesz renderować swoje wyniki. Skonfiguruj go teraz. Dodajesz:

  • Tytuł strony.
  • Jakiś tekst opisowy.
  • Akapit stanu.
  • Film z kameralnym obrazem z kamery internetowej.
  • Kilka przycisków do uruchamiania kamery, zbierania danych i resetowania interfejsu.
  • Importowanie plików TensorFlow.js i JS, które zakodujesz później.

Otwórz plik index.html i wklej na nim istniejący kod z podanymi niżej informacjami, aby skonfigurować powyższe funkcje:

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

Dokonaj podziału

Przeanalizujmy fragment kodu HTML powyżej, aby podkreślić kilka kluczowych dodanych przez Ciebie elementów.

  • Dodano tag <h1> dotyczący tytułu strony wraz z tagiem <p> z identyfikatorem „status”, który służy do drukowania informacji, gdy korzystasz z różnych części systemu do wyświetlania wyników.
  • Dodano element <video> z identyfikatorem „webcam”, na który będziesz później renderować transmisję z kamery internetowej.
  • Dodano 5 elementów <button>. Pierwsza, o identyfikatorze „enableCam”, włącza aparat. Następne dwa przyciski mają klasę „dataCollector”, który pozwala zebrać przykładowe zdjęcia obiektów, które chcesz rozpoznać. Kod napisany później zostanie tak zaprojektowany, aby można było dodać dowolną liczbę tych przycisków, a przyciski te będą działać zgodnie z oczekiwaniami.

Zwróć uwagę, że przyciski te mają również specjalny atrybut zdefiniowany przez użytkownika o nazwie data-1hot, którego wartość całkowita zaczyna się od 0 dla pierwszej klasy. Jest to indeks numeryczny, którego należy używać do przedstawienia danych określonej klasy. Indeks będzie używany do poprawnego kodowania klas wyjściowych za pomocą liczbowej reprezentacji zamiast ciągu znaków, ponieważ modele ML mogą działać tylko z liczbami.

Istnieje też atrybut data-name zawierający zrozumiałą dla człowieka nazwę, której chcesz użyć dla danej klasy. Dzięki temu możesz nadać użytkownikowi bardziej zrozumiałą nazwę zamiast wartości indeksu liczbowego pochodzącego z 1 kodowania „gorąco”.

Dodatkowo przycisk trenowania i resetowania umożliwia rozpoczęcie procesu trenowania po zebraniu danych lub zresetowanie aplikacji.

  • Dodałeś też 2 importy typu <script>. Jedna na potrzeby TensorFlow.js, a druga dla Script.js, którą niedługo zdefiniujesz.

7. Dodaj styl

Wartości domyślne elementu

Dodaj style do dodanych przed chwilą elementów HTML, aby mieć pewność, że będą się prawidłowo renderować. Oto kilka stylów, które zostały dodane do prawidłowych pozycji i rozmiaru elementów. Nic nadzwyczajnego. Możesz dodać te informacje później, aby jeszcze bardziej poprawić wygodę użytkowników, tak jak w filmie o systemach uczących się.

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

Świetnie. To wszystko. Jeśli wyświetlisz teraz podgląd danych wyjściowych, powinny one wyglądać mniej więcej tak:

81909685d7566dcb.png

8. JavaScript: stałe kluczowe i detektory

Definiowanie stałych kluczowych

Najpierw dodaj kluczowe stałe, których będziesz używać w aplikacji. Zacznij od zastąpienia zawartości script.js tymi stałymi:

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

Zobaczmy, do czego służą:

  • Komponent STATUS po prostu zawiera odwołanie do tagu akapitu, w którym będziesz zapisywać zmiany stanu.
  • VIDEO zawiera odwołanie do elementu wideo HTML, który będzie renderować obraz z kamery internetowej.
  • ENABLE_CAM_BUTTON, RESET_BUTTON i TRAIN_BUTTON pobierają odwołania DOM do wszystkich kluczowych przycisków na stronie HTML.
  • MOBILE_NET_INPUT_WIDTH i MOBILE_NET_INPUT_HEIGHT określają odpowiednio oczekiwaną szerokość i wysokość modelu MobileNet. Przechowywanie tej wartości na stałe u góry pliku w ten sposób. Jeśli zdecydujesz się później użyć innej wersji, łatwiej będzie zaktualizować wartości raz, zamiast zastępować je w wielu różnych miejscach.
  • STOP_DATA_GATHER ma wartość – 1. Zapisuje wartość stanu, dzięki czemu wiesz, kiedy użytkownik przestał kliknąć przycisk, aby zebrać dane z kamery internetowej. Nadanie numerowi bardziej rozpoznawalnej nazwy sprawia, że później kod jest bardziej czytelny.
  • CLASS_NAMES działa jako wyszukiwanie i zawiera czytelne dla człowieka nazwy możliwych prognoz klas. Ta tablica zostanie uzupełniona później.

Skoro masz już odwołania do kluczowych elementów, nadszedł czas na powiązanie z nimi kilku detektorów zdarzeń.

Dodawanie detektorów kluczowych zdarzeń

Zacznij od dodania modułów obsługi zdarzeń kliknięcia do przycisków kluczy w następujący sposób:

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON – po kliknięciu wywołuje funkcję allowCam.

TRAIN_BUTTON – po kliknięciu wywołuje metodę TrainAndPredict.

RESET_BUTTON – połączenia są resetowane po kliknięciu.

W tej sekcji znajdziesz wszystkie przyciski z klasą „dataCollector” za pomocą funkcji document.querySelectorAll(). Zwraca ono tablicę elementów znalezionych z dokumentu, które pasują do tych elementów:

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

Wyjaśnienie kodu:

Następnie przejrzyj znalezione przyciski i przypisz do każdego z nich 2 detektory zdarzeń. jedno na słowo „mysz”, a drugie – „mysz”. Dzięki temu możesz rejestrować próbki, dopóki naciśniesz przycisk, co jest przydatne przy zbieraniu danych.

Oba zdarzenia wywołują funkcję gatherDataForClass, którą określisz później.

W tym momencie możesz też przekazać do tablicy CLASS_NAMES znalezione nazwy klas znalezione w postaci czytelnej dla człowieka z atrybutu data-name atrybutu przycisku HTML.

Następnie dodaj zmienne do przechowywania kluczowych rzeczy, które będą potrzebne później.

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

Przyjrzyjmy się im.

Pierwsza zawiera zmienną mobilenet do przechowywania wczytanego modelu mobilenet. Początkowo nie określono ustawienia.

Dalej masz zmienną o nazwie gatherDataState. Jeśli atrybut „dataCollector” przycisk ten zmienia się na jeden identyfikator tego przycisku, zgodnie z definicją w kodzie HTML, dzięki czemu wiesz, jakiej klasy danych są zbierane. Początkowo ma wartość STOP_DATA_GATHER, dzięki czemu pętla zbierania danych, którą zapiszesz później, nie będzie zbierać żadnych danych, gdy nie naciśniesz żadnych przycisków.

videoPlaying śledzi, czy transmisja z kamery internetowej została załadowana i odtwarzana oraz czy jest gotowa do użycia. Początkowo ustawienie ma wartość false, ponieważ kamera internetowa nie jest włączona, dopóki nie naciśniesz przycisku ENABLE_CAM_BUTTON..

Następnie zdefiniuj 2 tablice: trainingDataInputs i trainingDataOutputs. Przechowują one wartości zebranych danych treningowych po kliknięciu zbioru danych „dataCollector” przycisków funkcji wejściowych wygenerowanych odpowiednio przez model podstawowy MobileNet i próbkowanej klasy wyjściowej.

Następnie zdefiniowano jedną ostateczną tablicę, examplesCount,, aby śledzić liczbę przykładów zawartych w poszczególnych klasach po rozpoczęciu ich dodawania.

Masz też zmienną o nazwie predict, która steruje pętlą prognoz. Początkowo jest ustawiona wartość false. Dopóki nie zostanie ustawiona na true, prognozy nie będą działać.

Po zdefiniowaniu wszystkich kluczowych zmiennych możemy przejść do załadowanego wstępnie modelu podstawowego MobileNet w wersji 3, który zamiast klasyfikacji zawiera wektory cech obrazu.

9. Wczytaj model podstawowy MobileNet

Najpierw zdefiniuj nową funkcję o nazwie loadMobileNetFeatureModel, jak pokazano poniżej. Musi to być funkcja asynchroniczna, ponieważ wczytywanie modelu jest asynchroniczne:

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

W tym kodzie definiujesz URL, w którym znajduje się model do wczytania, znajdujący się w dokumentacji TFHub.

Następnie możesz wczytać model za pomocą funkcji await tf.loadGraphModel(), pamiętając o ustawieniu właściwości specjalnej fromTFHub na true podczas wczytywania modelu z tej witryny Google. Jest to specjalny przypadek tylko w przypadku używania modeli hostowanych w Centrum TF, w których trzeba ustawić tę dodatkową właściwość.

Po zakończeniu wczytywania możesz ustawić dla elementu STATUS innerText komunikat, aby zobaczyć, czy został załadowany poprawnie, i możesz zacząć zbierać dane.

Teraz musisz tylko rozgrzać model. Przy większych modelach, takich jak ten, po pierwszym uruchomieniu modelu konfiguracja może zająć trochę czasu. Pomaga więc przekazywać zera w modelu, aby uniknąć czekania w przyszłości, w którym kluczowe znaczenie może mieć czas.

Możesz użyć funkcji tf.zeros() opakowanej w element tf.tidy(), aby mieć pewność, że tensory są prawidłowo rozmieszczone, rozmiar wsadu wynosi 1 oraz prawidłową wysokość i szerokość, które zostały określone w stałych na początku. Na koniec musisz określić kanały kolorów, czyli w tym przypadku 3 kanały, ponieważ model oczekuje obrazów RGB.

Następnie zapisz wynikowy kształt tensora zwróconego za pomocą funkcji answer.shape(), aby lepiej zrozumieć rozmiar obrazu generowanego przez ten model.

Po zdefiniowaniu tej funkcji możesz ją wywołać natychmiast, aby rozpocząć pobieranie modelu podczas wczytywania strony.

Jeśli wyświetlasz podgląd na żywo teraz, po kilku chwilach tekst stanu zmieni się z „Oczekuje na wczytanie pliku TF.js”. na „MobileNet v3 załadowano poprawnie!” jak pokazano poniżej. Zanim przejdziesz dalej, upewnij się, że wszystko działa.

a28b734e190afff.png

Możesz też sprawdzić w danych wyjściowych konsoli rozmiar wydrukowanych funkcji wyjściowych generowanych przez ten model. Po przeciągnięciu zer w modelu MobileNet zobaczysz wydrukowany kształt [1, 1024]. Pierwszy element to rozmiar wsadu, który wynosi 1 – jak widać, zwraca on 1024 funkcje, które można następnie wykorzystać do klasyfikowania nowych obiektów.

10. Zdefiniuj nową głowę modelu

Teraz nadszedł czas na zdefiniowanie głowicy modelu, czyli w zasadzie minimalnej wielowarstwowej perceptronu.

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

Przeanalizujmy ten kod. Zacznij od zdefiniowania modelu tf.sekwencyjnego, do którego chcesz dodać warstwy modelu.

Następnie dodaj gęstą warstwę jako warstwę wejściową do tego modelu. Dane wejściowe tego typu mają kształt 1024, ponieważ dane wyjściowe funkcji MobileNet w wersji 3 są w tym rozmiarze. Udało Ci się to odkryć w poprzednim kroku po zapoznaniu się z modelem. Ta warstwa zawiera 128 neuronów, które korzystają z funkcji aktywacji ReLU.

Jeśli dopiero zaczynasz korzystać z funkcji aktywacyjnych i warstw modelu, zastanów się nad ukończeniem szkolenia opisanego na początku tych warsztatów, aby zrozumieć działanie tych właściwości za kulisami.

Kolejna warstwa, którą należy dodać, to warstwa wyjściowa. Liczba neuronów powinna być równa liczbie klas, które próbujesz przewidzieć. W tym celu użyj parametru CLASS_NAMES.length, aby określić, ile klas planujesz sklasyfikować – odpowiada liczbie przycisków zbierania danych dostępnych w interfejsie. Jest to problem z klasyfikacją, dlatego w tej warstwie wyjściowej należy użyć aktywacji softmax, której trzeba użyć podczas tworzenia modelu do rozwiązywania problemów z klasyfikacją, a nie do regresji.

Teraz wydrukuj dokument model.summary(), aby wydrukować w konsoli przegląd nowo zdefiniowanego modelu.

Na koniec skompiluj model, aby był gotowy do trenowania. Tutaj optymalizator jest ustawiony na adam, a strata będzie wynosić binaryCrossentropy, jeśli CLASS_NAMES.length ma wartość 2, albo użyje categoricalCrossentropy, jeśli istnieją co najmniej 3 klasy do sklasyfikowania. Żądane są także wskaźniki dokładności, które można później monitorować w logach na potrzeby debugowania.

W konsoli powinno być widoczne coś takiego:

22eaf32286fea4bb.png

Pamiętaj, że ma on ponad 130 tys. parametrów z możliwością trenowania. Ponieważ jest to prosta, gęsta warstwa zwykłych neuronów, trenowanie zajmie dość szybko.

Po zakończeniu projektu możesz spróbować zmienić liczbę neuronów w pierwszej warstwie, aby sprawdzić, jak nisko udało Ci się go osiągnąć, a jednocześnie nadal uzyskiwać dobre wyniki. W przypadku uczenia maszynowego prowadzona jest często metoda prób i błędów w celu znalezienia optymalnych wartości parametrów, które pozwalają osiągnąć najlepszy kompromis między wykorzystaniem zasobów a szybkością.

11. Włącz kamerę internetową

Czas dopracować zdefiniowaną wcześniej funkcję enableCam(). Dodaj nową funkcję o nazwie hasGetUserMedia(), jak pokazano poniżej, a następnie zastąp zawartość wcześniej zdefiniowanej funkcji enableCam() odpowiednim kodem poniżej.

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

Najpierw utwórz funkcję o nazwie hasGetUserMedia(), aby sprawdzić, czy przeglądarka obsługuje getUserMedia(), sprawdzając istnienie kluczowych właściwości interfejsów API przeglądarki.

W funkcji enableCam() użyj zdefiniowanej wcześniej funkcji hasGetUserMedia(), aby sprawdzić, czy jest obsługiwana. Jeśli tak nie jest, wydrukuj ostrzeżenie w konsoli.

Jeśli ją obsługuje, określ pewne ograniczenia dla wywołania getUserMedia(), na przykład chcesz, by miał on tylko strumień wideo, i określ, że width filmu ma rozmiar 640 piks., a height480. Dlaczego? Większy obraz nie ma sensu, ponieważ w celu umieszczenia go w modelu MobileNet trzeba go zmienić do rozmiaru 224 na 224 piksele. Możesz też zaoszczędzić trochę zasobów obliczeniowych, prosząc o mniejszą rozdzielczość. Większość aparatów obsługuje rozdzielczość tego rozmiaru.

Następnie zadzwoń do navigator.mediaDevices.getUserMedia(), używając podanego wyżej constraints, a następnie poczekaj na zwrócenie kodu stream. Po zwróceniu obiektu stream element VIDEO może odtworzyć element stream, ustawiając go jako wartość srcObject.

Dodaj do elementu VIDEO element eventListener, aby wiedzieć, kiedy stream został załadowany i odtwarza się prawidłowo.

Po załadowaniu pary można ustawić videoPlaying na „true” i usunąć ENABLE_CAM_BUTTON, aby zapobiec ponownemu kliknięciu go, ustawiając jego klasę na „removed”.

Teraz uruchom kod, kliknij przycisk Włącz kamerę i zezwól na dostęp do kamery. Jeśli robisz to po raz pierwszy, element wideo na stronie powinien być renderowany w taki sposób:

b378eb1affa9b883.png

Teraz trzeba dodać funkcję obsługi kliknięć przycisku dataCollector.

12. Moduł obsługi zdarzeń przycisku do zbierania danych

Teraz nadszedł czas na uzupełnienie obecnie pustej funkcji o nazwie gatherDataForClass().. Oto co zostało przez Ciebie przypisane jako funkcja obsługi zdarzeń w przypadku przycisków dataCollector na początku ćwiczeń z programowania.

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

Najpierw sprawdź atrybut data-1hot aktualnie klikniętego przycisku, wywołując this.getAttribute() z nazwą atrybutu (w tym przypadku data-1hot jako parametr). Jest to ciąg znaków, możesz więc przy użyciu funkcji parseInt() rzutować go na liczbę całkowitą i przypisać ten wynik do zmiennej o nazwie classNumber.

Następnie ustaw odpowiednio zmienną gatherDataState. Jeśli bieżąca wartość gatherDataState ma wartość STOP_DATA_GATHER (która ustawiana jest jako -1), oznacza to, że obecnie nie gromadzisz żadnych danych i zostało uruchomione zdarzenie mousedown. Ustaw gatherDataState na znaleziony przed chwilą classNumber.

Inaczej oznacza to, że obecnie zbierasz dane, a wywołane zdarzenie było zdarzeniem mouseup, i chcesz przestać zbierać dane dotyczące tej klasy. Po prostu ustaw go z powrotem na STOP_DATA_GATHER, aby zakończyć pętlę zbierania danych, którą niedługo zdefiniujesz.

Na koniec rozpocznij rozmowę z funkcją dataGatherLoop(),, która faktycznie rejestruje dane zajęć.

13. Zbieranie danych

Teraz zdefiniuj funkcję dataGatherLoop(). Ta funkcja odpowiada za pobieranie obrazów z filmu z kamery internetowej, przekazywanie ich przez model MobileNet i rejestrowanie danych wyjściowych tego modelu (wektory 1024 cech).

Następnie zapisuje je razem z identyfikatorem gatherDataState wciśniętego przycisku, dzięki czemu wiadomo, jakiej klasy reprezentują te dane.

Przyjrzyjmy się im.

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

Wykonywanie tej funkcji będzie kontynuowane tylko wtedy, gdy zasada videoPlaying ma wartość Prawda, co oznacza, że kamera internetowa jest aktywna, a gatherDataState nie jest równa STOP_DATA_GATHER, a przycisk gromadzenia danych klas jest w tym momencie naciśnięty.

Następnie umieść kod w polu tf.tidy(), aby usunąć wszystkie utworzone tensory w podanym niżej kodzie. Wynik tego wykonania kodu w usłudze tf.tidy() jest przechowywany w zmiennej o nazwie imageFeatures.

Możesz teraz zarejestrować klatkę z kamery internetowej VIDEO za pomocą aplikacji tf.browser.fromPixels(). Wynikowy tensor zawierający dane obrazu jest przechowywany w zmiennej o nazwie videoFrameAsTensor.

Następnie zmień rozmiar zmiennej videoFrameAsTensor, tak aby miała kształt właściwy dla danych wejściowych modelu MobileNet. Użyj wywołania tf.image.resizeBilinear() z tensorem, którego kształt chcesz zmienić, jako pierwszy parametr, a następnie kształtem określającym nową wysokość i szerokość zdefiniowanymi wcześniej przez stałe. Na koniec ustaw wyrównanie rogów na wartość „true” (prawda), przekazując trzeci parametr, aby uniknąć problemów z wyrównaniem przy zmianie rozmiaru. Wynik tej zmiany rozmiaru jest przechowywany w zmiennej o nazwie resizedTensorFrame.

Pamiętaj, że ta podstawowa zmiana rozmiaru powoduje rozciąganie obrazu, ponieważ obraz z kamery internetowej ma wymiary 640 x 480 pikseli, a model wymaga kwadratowego obrazu o wymiarach 224 x 224 piksele.

Na potrzeby tej wersji demonstracyjnej to rozwiązanie powinno działać prawidłowo. Jednak po ukończeniu tego ćwiczenia w programie możesz spróbować przyciąć kwadrat z tego obrazu, aby uzyskać jeszcze lepsze wyniki w każdym systemie produkcyjnym, który utworzysz później.

Następnie znormalizuj dane obrazu. Podczas korzystania z funkcji tf.browser.frompixels() dane obrazu zawsze mieszczą się w zakresie od 0 do 255. Można więc po prostu podzielić rozmiar interfejsu TensorFrame przez 255, aby wszystkie wartości mieściły się w przedziale od 0 do 1, co model MobileNet oczekuje jako danych wejściowych.

Na koniec w sekcji tf.tidy() kodu przekaż ten znormalizowany tensor przez wczytany model, wywołując metodę mobilenet.predict(), do której przekazujesz rozszerzoną wersję normalizedTensorFrame przy użyciu funkcji expandDims(), tak aby była to grupa 1, ponieważ model oczekuje grupy danych wejściowych do przetworzenia.

Po zwróceniu wyniku możesz od razu wywołać funkcję squeeze() przy tym zwróconym wyniku, aby zredukować go do tensora 1D, który następnie zwrócisz i przypisz do zmiennej imageFeatures, która przechwytuje wynik funkcji tf.tidy().

Teraz, gdy masz już imageFeatures z modelu MobileNet, możesz je zarejestrować, umieszczając je w zdefiniowanej wcześniej tablicy trainingDataInputs.

Możesz również rejestrować to, co oznaczają dane wejściowe, przesyłając też bieżącą wartość gatherDataState do tablicy trainingDataOutputs.

Pamiętaj, że zmienna gatherDataState zostałaby ustawiona na identyfikator numeryczny bieżącej klasy, którego dotyczy rejestrowane dane, gdy użytkownik kliknął przycisk w zdefiniowanej wcześniej funkcji gatherDataForClass().

W tym momencie możesz też zwiększyć liczbę przykładów dla danej klasy. Aby to zrobić, sprawdź najpierw, czy indeks w tablicy examplesCount został już zainicjowany. Jeśli wartość nie jest określona, ustaw ją na 0, aby zainicjować licznik dla liczbowego identyfikatora klasy, a następnie możesz zwiększyć examplesCount dla bieżącej wartości gatherDataState.

Teraz zaktualizuj tekst elementu STATUS na stronie internetowej, aby wyświetlać bieżące liczby dla poszczególnych klas w miarę ich rejestrowania. Aby to zrobić, zapętlaj tablicę CLASS_NAMES i wydrukuj czytelną dla człowieka nazwę w połączeniu z liczbą danych z tym samym indeksem w examplesCount.

Na koniec wywołaj funkcję window.requestAnimationFrame() z parametrem dataGatherLoop przekazanym jako parametr, aby ponownie wywoływać tę funkcję cyklicznie. Klatki z filmu będą kontynuowane do momentu wykrycia parametru mouseup przycisku, a wartość gatherDataState zostanie ustawiona na STOP_DATA_GATHER,, co oznacza zakończenie pętli zbierania danych.

Po uruchomieniu kodu możesz kliknąć przycisk Włącz kamerę, poczekać na wczytanie kamery internetowej, a następnie kliknąć i przytrzymać każdy z przycisków zbierania danych, aby zebrać przykłady dla poszczególnych klas danych. Tutaj widzimy dane zebrane przeze mnie dla telefonu komórkowego i dłoni.

541051644a45131f.gif

Tekst stanu powinien zostać zaktualizowany, ponieważ przechowuje wszystkie tensory w pamięci, tak jak na zrzucie ekranu powyżej.

14. Wytrenuj i utwórz prognozę

Następnym krokiem jest implementacja kodu dla obecnie pustej funkcji trainAndPredict(), w której odbywa się proces uczenia się transferu. Spójrzmy na kod:

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

Przede wszystkim zakończ działanie bieżących prognoz, ustawiając predict na false.

Następnie przetasuj tablice wejściowe i wyjściowe przy użyciu interfejsu tf.util.shuffleCombo(), aby mieć pewność, że kolejność nie powoduje problemów podczas trenowania.

Przekonwertuj tablicę wyjściową (trainingDataOutputs,) na format tensor1d typu int32, aby był gotowy do użycia w jednym kodowaniu z gorącą. Jest on przechowywany w zmiennej o nazwie outputsAsTensor.

Użyj funkcji tf.oneHot() z tą zmienną outputsAsTensor oraz maksymalną liczbą klas do zakodowania – to tylko CLASS_NAMES.length. Twoje dane wyjściowe zakodowane na gorąco są teraz przechowywane w nowym tensorze o nazwie oneHotOutputs.

Pamiętaj, że obecnie trainingDataInputs jest tablicą zarejestrowanych tensorów. Aby ich używać do trenowania, trzeba przekonwertować tablicę tensorów na zwykły tensor 2D.

W tym celu w bibliotece TensorFlow.js znajduje się doskonała funkcja o nazwie tf.stack(),

który obejmuje tablicę tensorów i składa je razem, aby uzyskać dane wyjściowe o wyższym tensorze. W tym przypadku zwracany jest Tensor 2D, czyli zbiór 1 wymiarowych danych wejściowych, składających się z pojedynczych danych wejściowych o długości 1024 pikseli zawierających zarejestrowane cechy, co jest potrzebne do trenowania.

Następnie await model.fit(), aby wytrenować nagłówek modelu niestandardowego. W tym miejscu przekazujesz zmienną inputsAsTensor wraz z elementem oneHotOutputs reprezentującym dane treningowe i służą do wykorzystania odpowiednio przykładowych danych wejściowych i docelowych danych wyjściowych. W obiekcie konfiguracji trzeciego parametru ustaw shuffle na true, użyj wartości batchSize z 5 z wartością epochs ustawioną na 10, a następnie podaj callback dla onEpochEnd dla funkcji logProgress, którą niedługo zdefiniujesz.

Na koniec możesz usunąć utworzone tensory w trakcie trenowania modelu. Następnie możesz ustawić predict z powrotem na true, aby umożliwić powtórne prognozowanie, a następnie wywołać funkcję predictLoop(), aby rozpocząć przewidywanie obrazów z kamery internetowej na żywo.

Możesz też zdefiniować funkcję logProcess() do rejestrowania stanu trenowania, który jest używany w zasadzie model.fit() powyżej i wyświetla wyniki w konsoli po każdej rundzie trenowania.

Prawie Ci się udało Czas dodać funkcję predictLoop() do prognozowania.

Pętla prognozy podstawowej

W tym przypadku wdrażana jest główna pętla prognozowania, która wykorzystuje próbki klatek z kamery internetowej i stale przewiduje, co znajduje się w każdej klatce, na podstawie wyników w czasie rzeczywistym w przeglądarce.

Sprawdźmy kod:

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

Najpierw sprawdź, czy predict ma wartość prawda, aby prognozy były generowane dopiero wtedy, gdy model jest wytrenowany i jest dostępny do użycia.

Następnie możesz pobrać funkcje bieżącego obrazu, tak jak w przypadku funkcji dataGatherLoop(). Krótko mówiąc, rejestrujesz ramkę z kamery internetowej za pomocą programu tf.browser.from pixels(), normalizujesz ją, zmieniasz rozmiar do 224 x 224 piksele, a następnie przekazujesz tę informację za pomocą modelu MobileNet, aby uzyskać wymagane funkcje obrazu.

Teraz jednak możesz użyć nowo wytrenowanego modelu nagłówka, aby wykonać prognozę, przekazując wynikowy wynik imageFeatures znaleziony przed chwilą za pomocą funkcji predict() wytrenowanego modelu. Następnie możesz ściągnąć uzyskany tensor, aby znów miał 1 wymiar, i przypisać go do zmiennej o nazwie prediction.

Funkcja prediction pozwala znaleźć indeks o najwyższej wartości za pomocą parametru argMax(), a następnie przekonwertować tensor wynikowy na tablicę przy użyciu funkcji arraySync(), aby uzyskać dane bazowe w kodzie JavaScript i odkryć pozycję najbardziej wartościowego elementu. Ta wartość jest przechowywana w zmiennej o nazwie highestIndex.

Możesz też w ten sam sposób uzyskać rzeczywiste wskaźniki ufności prognozy, wywołując bezpośrednio funkcję arraySync() w tensorze prediction.

Masz już wszystko, czego potrzebujesz, aby zaktualizować tekst STATUS przy użyciu danych z: prediction. Aby uzyskać zrozumiały dla człowieka ciąg tekstowy klasy, wystarczy wyszukać highestIndex w tablicy CLASS_NAMES, a następnie pobrać wartość ufności z predictionArray. Aby była bardziej czytelna w procentach, pomnóż wynik przez 100 i math.floor().

Za pomocą window.requestAnimationFrame() możesz też ponownie wywołać funkcję predictionLoop(), gdy wszystko będzie gotowe, aby uzyskać klasyfikację w czasie rzeczywistym w strumieniu wideo. Potrwa to do momentu, gdy zasada predict będzie miała wartość false, jeśli zdecydujesz się wytrenować nowy model z użyciem nowych danych.

To prowadzi do ostatniego fragmentu łamigłówki. Wdrażanie przycisku resetowania.

15. Wdróż przycisk resetowania

Prawie gotowe. Ostatnim elementem łamigłówki jest umieszczenie przycisku resetowania, aby zacząć od początku. Poniżej znajduje się kod Twojej obecnie pustej funkcji reset(). Zaktualizuj go w ten sposób:

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

Najpierw zatrzymaj wszystkie uruchomione pętle prognozowania, ustawiając predict na false. Następnie usuń całą zawartość tablicy examplesCount, ustawiając jej długość na 0, co jest użytecznym sposobem na usunięcie całej zawartości tablicy.

Teraz przejrzyj wszystkie aktualnie zarejestrowane trainingDataInputs i upewnij się, że masz dispose() każdego tensora, który się w nim znajduje, aby ponownie zwolnić pamięć, ponieważ Tensory nie są czyszczone przez funkcję czyszczenia pamięci JavaScript.

Następnie możesz bezpiecznie ustawić długość tablicy na 0 w tablicach trainingDataInputs i trainingDataOutputs, aby je też usunąć.

Na koniec nadaj tekstowi STATUS coś rozsądnego i wydrukuj tensory pozostawione w pamięci w ramach kontroli poprawności.

Pamiętaj, że w pamięci pozostanie kilkaset tensorów, ponieważ nie pozbywa się ani modelu MobileNet, ani wielowarstwowego perceptrona. Jeśli zdecydujesz się na kolejny trenowanie po zresetowaniu, będzie trzeba użyć ich ponownie z nowymi danymi treningowymi.

16. Wypróbuj

Czas przetestować własną wersję Teachable Machine.

Otwórz podgląd na żywo, włącz kamerę internetową, zbierz co najmniej 30 przykładów dla klasy 1 dla obiektu w sali, a następnie wykonaj to samo dla klasy 2 dla innego obiektu, kliknij Trenuj i sprawdź postęp w dzienniku konsoli. Powinien szybko się trenować:

bf1ac3cc5b15740.gif

Po wytrenowaniu pokaż obiekty kamery, aby uzyskać prognozy na żywo, które zostaną wydrukowane w obszarze tekstowym stanu w górnej części strony internetowej. W razie problemów sprawdź mój gotowy kod, aby zobaczyć, czy coś nie zostało skopiowane.

17. Gratulacje

Gratulacje! Właśnie udało Ci się ukończyć pierwszy przykład nauki z przenoszeniem przy użyciu TensorFlow.js na żywo w przeglądarce.

Wypróbuj go i przetestuj na różnych obiektach. Być może zauważysz, że niektóre rzeczy są trudniejsze do rozpoznania niż inne, zwłaszcza jeśli są podobne do innych. Aby je rozróżnić, konieczne może być dodanie większej liczby zajęć lub danych szkoleniowych.

Podsumowanie

Dzięki temu ćwiczeniu w Codelabs omówiliśmy:

  1. Czym jest uczenie się przenoszenia i jego zalety w porównaniu z trenowaniem pełnego modelu.
  2. Jak uzyskać modele do ponownego wykorzystania z TensorFlow Hub.
  3. Jak skonfigurować aplikację internetową do nauki przenoszenia.
  4. Jak wczytać model podstawowy i używać go do generowania funkcji obrazu.
  5. Jak wytrenować nową głowę przewidywania, która będzie rozpoznawać obiekty niestandardowe na podstawie zdjęć z kamery internetowej.
  6. Jak używać uzyskanych modeli do klasyfikowania danych w czasie rzeczywistym.

Co dalej?

Skoro masz już gotową bazę, możesz zacząć. Jakie pomysły możesz wpaść, aby wykorzystać stały model systemów uczących się w praktyce w rzeczywistym przypadku użycia, nad którym pracujesz? Może uda Ci się zrewolucjonizować branżę, w której obecnie pracujesz, i pomóc pracownikom w trenowaniu modeli do klasyfikowania rzeczy ważnych w ich codziennej pracy? Możliwości są nieograniczone.

Aby przejść dalej, możesz bezpłatnie wziąć udział w tym bezpłatnym kursie. W jego ramach dowiesz się, jak połączyć 2 modele używane obecnie w tym ćwiczeniu w programie, aby zwiększyć ich efektywność.

Jeśli chcesz dowiedzieć się więcej o teorii stojącej za oryginalną aplikacją do nauczania, przeczytaj ten samouczek.

Podziel się z nami tym, co udało Ci się stworzyć

Możesz z łatwością wykorzystać swoje dzisiejsze materiały również w innych kreatywnych celach, więc zachęcamy do kreatywnego myślenia i dalszego hakowania.

Pamiętaj, aby oznaczyć nas w mediach społecznościowych hashtagiem #MadeWithTFJS, aby mieć szansę na zaprezentowanie Twojego projektu na blogu TensorFlow lub nawet na przyszłych wydarzeniach. Chętnie zobaczymy, co stworzysz.

Strony, na których można płacić za zakupy