TensorFlow.js: สร้าง "Teachable Machine" โดยใช้การเรียนรู้การโอนด้วย TensorFlow.js

1. ก่อนเริ่มต้น

การใช้โมเดล TensorFlow.js เติบโตขึ้นแบบทวีคูณในช่วง 2-3 ปีที่ผ่านมาและตอนนี้นักพัฒนาซอฟต์แวร์ JavaScript จำนวนมากกำลังมองหาโมเดลที่จะนำโมเดลล้ำสมัยที่มีอยู่มาใช้และฝึกให้โมเดลนี้ทำงานกับข้อมูลที่กำหนดเองซึ่งเป็นเอกลักษณ์เฉพาะในอุตสาหกรรมของตน การนำโมเดลที่มีอยู่ (มักเรียกว่าโมเดลพื้นฐาน) มาใช้ในโดเมนที่คล้ายกันแต่ต่างออกไปเรียกว่า "การเรียนรู้การโอน"

การเรียนรู้แบบโอนมีข้อดีมากกว่าการเริ่มจากโมเดลเปล่าทั้งหมด คุณนำความรู้ที่ได้เรียนรู้จากโมเดลที่ฝึกไปแล้วมาใช้ใหม่ได้โดยไม่ต้องขอตัวอย่างรายการใหม่ที่ต้องการแยกประเภท นอกจากนี้ การฝึกมักรวดเร็วกว่ามากเนื่องจากต้องฝึกเพียง 2-3 เลเยอร์สุดท้ายของสถาปัตยกรรมโมเดลอีกครั้ง แทนที่จะเป็นทั้งเครือข่าย ด้วยเหตุนี้ การเรียนรู้การโอนจึงเหมาะอย่างยิ่งสำหรับสภาพแวดล้อมของเว็บเบราว์เซอร์ที่ทรัพยากรอาจแตกต่างกันไปตามอุปกรณ์ที่ใช้งาน แต่ก็เข้าถึงเซ็นเซอร์ได้โดยตรงเพื่อให้รับข้อมูลได้ง่าย

Codelab นี้จะแสดงวิธีสร้างเว็บแอปจากผืนผ้าใบเปล่า เพื่อสร้างประสบการณ์ Teachable Machine" เว็บไซต์ของคุณ เว็บไซต์นี้จะช่วยให้คุณสร้างเว็บแอปที่ทำงานซึ่งผู้ใช้ทุกคนสามารถใช้เพื่อจดจำออบเจ็กต์ที่กำหนดเองด้วยรูปภาพตัวอย่างจากเว็บแคม 2-3 ภาพ เราตั้งใจให้เว็บไซต์มีขนาดเล็กที่สุดเพื่อให้คุณมุ่งเน้นด้านแมชชีนเลิร์นนิงของ Codelab นี้ได้ อย่างไรก็ตาม เรามีขอบเขตมากมายที่จะนำไปใช้กับนักพัฒนาเว็บที่มีอยู่เพื่อปรับปรุง UX ได้เช่นเดียวกับเว็บไซต์เดิมของ Teachable Machine

ข้อกำหนดเบื้องต้น

Codelab นี้เขียนขึ้นสำหรับนักพัฒนาเว็บที่ค่อนข้างคุ้นเคยกับโมเดลสำเร็จรูปของ TensorFlow.js และการใช้งาน API พื้นฐาน และผู้ที่ต้องการเริ่มต้นใช้งานการโอนการเรียนรู้ inTensorFlow.js

  • เราถือว่ามีความคุ้นเคยกับ TensorFlow.js, HTML5, CSS และ JavaScript ระดับพื้นฐานสำหรับห้องทดลองนี้

หากคุณเพิ่งเริ่มใช้ Tensorflow.js ให้ลองเรียนหลักสูตร Zero to hero แบบไม่เสียค่าใช้จ่ายนี้ก่อน ซึ่งจะถือว่าไม่มีพื้นฐานมาจากแมชชีนเลิร์นนิงหรือ TensorFlow.js และสอนทุกเรื่องที่คุณจำเป็นต้องทราบในขั้นตอนย่อย

สิ่งที่คุณจะได้เรียนรู้

  • TensorFlow.js คืออะไรและเหตุผลที่คุณควรใช้ในเว็บแอปถัดไป
  • วิธีสร้างหน้าเว็บ HTML/CSS /JS แบบง่ายที่จำลองประสบการณ์ของผู้ใช้ Teachable Machine
  • วิธีใช้ TensorFlow.js เพื่อโหลดโมเดลฐานที่ฝึกไว้แล้วล่วงหน้า โดยเฉพาะ MobileNet เพื่อสร้างฟีเจอร์รูปภาพที่ใช้ในการเรียนรู้การโอนได้
  • วิธีรวบรวมข้อมูลจากเว็บแคมของผู้ใช้สำหรับข้อมูลหลายคลาสที่คุณต้องการจดจำ
  • วิธีสร้างและกำหนด Perceptron หลายชั้นที่ใช้ฟีเจอร์รูปภาพและเรียนรู้วิธีแยกประเภทวัตถุใหม่โดยใช้วัตถุเหล่านั้น

มาเริ่มแฮ็กกันเลย...

สิ่งที่คุณต้องมี

  • เราขอแนะนำให้ศึกษาบัญชี Glitch.com ควบคู่กันไป หรือคุณสามารถใช้สภาพแวดล้อมการให้บริการบนเว็บที่คุณสะดวกในการแก้ไขและเรียกใช้ได้

2. TensorFlow.js คืออะไร

54e81d02971f53e8.png

TensorFlow.js เป็นไลบรารีแมชชีนเลิร์นนิงโอเพนซอร์สที่เรียกใช้ JavaScript ได้ทุกที่ API ดังกล่าวอิงตามไลบรารี TensorFlow เดิมที่เขียนด้วย Python และมีเป้าหมายที่จะสร้างประสบการณ์สำหรับนักพัฒนาซอฟต์แวร์และชุด API ใหม่สำหรับระบบนิเวศของ JavaScript

ใช้ได้ที่ไหนบ้าง

เนื่องจาก JavaScript สามารถถ่ายโอนได้ จึงช่วยให้คุณเขียนในภาษาเดียวและดำเนินการแมชชีนเลิร์นนิงในแพลตฟอร์มทั้งหมดต่อไปนี้ได้อย่างง่ายดาย

  • ฝั่งไคลเอ็นต์ในเว็บเบราว์เซอร์ที่ใช้ JavaScript แบบวานิลลา
  • ฝั่งเซิร์ฟเวอร์และอุปกรณ์ IoT อย่าง Raspberry Pi ที่ใช้ Node.js
  • แอปบนเดสก์ท็อปที่ใช้ Electron
  • แอปที่มาพร้อมเครื่องที่ใช้ React Native

TensorFlow.js ยังสนับสนุนแบ็กเอนด์หลายชุดภายในแต่ละสภาพแวดล้อมเหล่านี้ (สภาพแวดล้อมจริงที่ใช้ฮาร์ดแวร์ซึ่งสามารถทำงานภายในได้ เช่น CPU หรือ WebGL เป็นต้น "แบ็กเอนด์" ในบริบทนี้ไม่ได้หมายถึงสภาพแวดล้อมฝั่งเซิร์ฟเวอร์ ตัวอย่างเช่น แบ็กเอนด์สำหรับการดำเนินการอาจเป็นฝั่งไคลเอ็นต์ใน WebGL) เพื่อรองรับความเข้ากันได้และทำให้ทุกอย่างทำงานได้อย่างรวดเร็ว ปัจจุบัน TensorFlow.js รองรับ

  • การดำเนินการของ WebGL บนการ์ดแสดงผลของอุปกรณ์ (GPU) - นี่เป็นวิธีที่เร็วที่สุดในการใช้งานโมเดลที่มีขนาดใหญ่ขึ้น (มีขนาดใหญ่กว่า 3 MB) ด้วยการเร่งความเร็ว GPU
  • การดำเนินการ Web Assembly (WASM) บน CPU เพื่อปรับปรุงประสิทธิภาพของ CPU ในอุปกรณ์ต่างๆ เช่น โทรศัพท์มือถือรุ่นเก่า วิธีนี้เหมาะกับโมเดลขนาดเล็ก (มีขนาดเล็กกว่า 3 MB) ซึ่งทำงานบน CPU ด้วย WASM ได้เร็วกว่า WebGL เนื่องจากค่าใช้จ่ายในการอัปโหลดเนื้อหาไปยังโปรเซสเซอร์กราฟิก
  • การดำเนินการของ CPU - ตัวเลือกสำรองต้องไม่มีสภาพแวดล้อมอื่นๆ พร้อมใช้งาน นี่คือรายการที่ช้าที่สุดใน 3 อันดับ แต่เราพร้อมช่วยเหลือคุณเสมอ

หมายเหตุ: คุณเลือกที่จะบังคับใช้หนึ่งในแบ็กเอนด์เหล่านี้ได้หากทราบว่าจะใช้อุปกรณ์ใด หรือจะให้ TensorFlow.js ตัดสินใจแทนก็ได้หากไม่ระบุ

พลังพิเศษฝั่งไคลเอ็นต์

การเรียกใช้ TensorFlow.js ในเว็บเบราว์เซอร์บนเครื่องไคลเอ็นต์อาจก่อให้เกิดประโยชน์มากมายที่ควรค่าแก่การพิจารณา

ความเป็นส่วนตัว

คุณจะฝึกและแยกประเภทข้อมูลในเครื่องไคลเอ็นต์ได้โดยไม่ต้องส่งข้อมูลไปยังเว็บเซิร์ฟเวอร์ของบุคคลที่สาม บางครั้งระบบอาจกำหนดให้ต้องปฏิบัติตามข้อกำหนดเพื่อปฏิบัติตามกฎหมายท้องถิ่น เช่น GDPR หรือเมื่อประมวลผลข้อมูลที่ผู้ใช้อาจต้องการเก็บไว้ในเครื่องและไม่ส่งไปยังบุคคลที่สาม

ความเร็ว

เนื่องจากคุณไม่ต้องส่งข้อมูลไปยังเซิร์ฟเวอร์ระยะไกล การอนุมาน (การแยกประเภทข้อมูล) จึงทำงานได้เร็วขึ้น ยิ่งไปกว่านั้น คุณยังเข้าถึงเซ็นเซอร์ของอุปกรณ์ได้โดยตรง เช่น กล้อง, ไมโครโฟน, GPS, ตัวตรวจวัดความเร่ง และอื่นๆ หากผู้ใช้ให้สิทธิ์เข้าถึงแก่คุณ

การเข้าถึงและการปรับขนาด

เพียงคลิกเดียว ทุกคนในโลกก็สามารถคลิกลิงก์ที่คุณส่งไป เปิดหน้าเว็บในเบราว์เซอร์ของตน และใช้สิ่งที่คุณสร้างขึ้นมาได้ ไม่ต้องมีการตั้งค่า Linux ฝั่งเซิร์ฟเวอร์ที่ซับซ้อนด้วยไดรเวอร์ CUDA และอื่นๆ อีกมากมายเพื่อใช้ระบบแมชชีนเลิร์นนิงเพียงอย่างเดียว

ค่าใช้จ่าย

หากไม่มีเซิร์ฟเวอร์ สิ่งเดียวที่คุณต้องเสียค่าใช้จ่ายคือ CDN เพื่อโฮสต์ไฟล์ HTML, CSS, JS และไฟล์โมเดล ค่าใช้จ่ายของ CDN ถูกกว่าการเก็บเซิร์ฟเวอร์ (โดยอาจมีการ์ดแสดงผลต่ออยู่) ที่ทำงานตลอด 24 ชั่วโมงอย่างมาก

ฟีเจอร์ฝั่งเซิร์ฟเวอร์

การใช้ประโยชน์จากการติดตั้งใช้งาน Node.js ของ TensorFlow.js จะเป็นการเปิดใช้ฟีเจอร์ต่อไปนี้

การรองรับ CUDA อย่างเต็มรูปแบบ

สำหรับฝั่งเซิร์ฟเวอร์ สำหรับการเร่งการแสดงผลการ์ดกราฟิก คุณต้องติดตั้งไดรเวอร์ NVIDIA CUDA เพื่อให้ TensorFlow ทำงานกับการ์ดแสดงผล (ต่างจากเบราว์เซอร์ที่ใช้ WebGL ไม่ต้องติดตั้ง) อย่างไรก็ตาม ด้วยการสนับสนุน CUDA เต็มรูปแบบ คุณจะสามารถใช้ประโยชน์จากความสามารถระดับต่ำกว่าของการ์ดแสดงผลได้อย่างเต็มที่ ทำให้การฝึกใช้เวลาและการอนุมานรวดเร็วขึ้นได้ ประสิทธิภาพการทำงานจะเทียบเท่ากับการติดตั้งใช้งาน Python TensorFlow เนื่องจากทั้งคู่ใช้แบ็กเอนด์ C++ เดียวกัน

ขนาดรุ่น

สำหรับโมเดลล้ำสมัยจากการวิจัย คุณอาจทำงานกับโมเดลขนาดใหญ่มากซึ่งอาจมีขนาดกิกะไบต์ โมเดลเหล่านี้ไม่สามารถทำงานในเว็บเบราว์เซอร์ได้ในขณะนี้ เนื่องจากมีข้อจำกัดในการใช้หน่วยความจำต่อแท็บของเบราว์เซอร์ หากต้องการใช้งานโมเดลที่มีขนาดใหญ่ขึ้น คุณสามารถใช้ Node.js บนเซิร์ฟเวอร์ของคุณเองที่มีข้อมูลจำเพาะของฮาร์ดแวร์ที่คุณจำเป็นต้องใช้เพื่อเรียกใช้โมเดลดังกล่าวอย่างมีประสิทธิภาพ

IOT

Node.js ได้รับการสนับสนุนบนคอมพิวเตอร์บอร์ดเดี่ยวที่ได้รับความนิยมอย่าง Raspberry Pi ซึ่งก็หมายความว่าคุณสามารถเรียกใช้โมเดล TensorFlow.js บนอุปกรณ์ดังกล่าวได้เช่นกัน

ความเร็ว

Node.js เขียนขึ้นด้วย JavaScript ซึ่งหมายความว่าจะได้รับประโยชน์จากการรวบรวมข้อมูลในเวลาเท่านั้น ซึ่งหมายความว่าคุณอาจเห็นประสิทธิภาพเพิ่มขึ้นเมื่อใช้ Node.js เนื่องจาก Node.js จะได้รับการเพิ่มประสิทธิภาพขณะรันไทม์ โดยเฉพาะสำหรับการประมวลผลล่วงหน้าที่คุณอาจกำลังทำอยู่ ดูตัวอย่างที่ยอดเยี่ยมได้ในกรณีศึกษานี้ ซึ่งแสดงให้เห็นว่า Hugging Face ใช้ Node.js ในการเพิ่มประสิทธิภาพ 2 เท่าสำหรับโมเดลการประมวลผลภาษาธรรมชาติอย่างไร

ตอนนี้คุณเข้าใจข้อมูลพื้นฐานของ TensorFlow.js ซึ่งทำงานได้รวมถึงข้อดีบางประการแล้ว มาเริ่มทำสิ่งที่เป็นประโยชน์ด้วยกันเลย

3. ถ่ายทอดการเรียนรู้

การโอนการเรียนรู้คืออะไร

การถ่ายทอดการเรียนรู้เกี่ยวข้องกับการนำความรู้ที่ได้เรียนรู้ไปใช้ในการเรียนรู้สิ่งที่ต่างออกไป

มนุษย์เราทำสิ่งนี้ตลอดเวลา คุณมีประสบการณ์ที่จะเก็บอยู่ในสมองมาตลอดชีวิต ซึ่งนำมาใช้ช่วยให้จดจำสิ่งใหม่ๆ ที่ไม่เคยเห็นมาก่อนได้ ลองดูต้นหลิวนี้

e28070392cd4afb9.png

ซึ่งคุณอาจไม่เคยเห็นต้นไม้ประเภทนี้มาก่อน ทั้งนี้ขึ้นอยู่กับว่าคุณอยู่ที่ไหนในโลก

อย่างไรก็ตาม ถ้าฉันขอให้คุณบอกฉันว่ามีต้นหลิวในภาพใหม่ด้านล่างบ้างไหม คุณอาจจะสังเกตเห็นได้ค่อนข้างเร็ว แม้ว่าจะอยู่คนละมุมกัน และต่างกับต้นฉบับที่ฉันแสดงให้ดูเล็กน้อยก็ตาม

d9073a0d5df27222.png

คุณมีเซลล์ประสาทในสมองที่รู้วิธีระบุวัตถุที่คล้ายต้นไม้ และเซลล์ประสาทอื่นๆ ที่หาเส้นตรงยาวๆ ได้ดีแล้ว คุณสามารถนำความรู้ดังกล่าวมาใช้ใหม่เพื่อจำแนกต้นหลิว ซึ่งเป็นวัตถุคล้ายต้นไม้ที่มีกิ่งก้านยาวตรงหลายท่อน

ในทำนองเดียวกัน หากมีโมเดลแมชชีนเลิร์นนิงที่ได้รับการฝึกในโดเมนแล้ว เช่น การจดจำรูปภาพ คุณก็นำโมเดลดังกล่าวมาใช้ซ้ำเพื่อทำงานอื่นที่เกี่ยวข้องได้

คุณสามารถทำเช่นเดียวกันนี้ได้กับโมเดลขั้นสูงอย่าง MobileNet ซึ่งเป็นโมเดลการวิจัยที่ได้รับความนิยมอย่างมากที่สามารถจดจำรูปภาพบนวัตถุประเภทต่างๆ กว่า 1, 000 ประเภท ตั้งแต่สุนัขไปจนถึงรถยนต์ เทคโนโลยีได้รับการฝึกด้วยชุดข้อมูลขนาดใหญ่ที่เรียกว่า ImageNet ซึ่งมีรูปภาพที่ติดป้ายกำกับหลายล้านภาพ

ในภาพเคลื่อนไหวนี้ คุณจะเห็นเลเยอร์จำนวนมากในโมเดล MobileNet V1 นี้

7d4e1e35c1a89715.gif

ในระหว่างการฝึก โมเดลนี้ได้เรียนรู้วิธีแยกฟีเจอร์ทั่วไปที่สำคัญสำหรับวัตถุทั้งหมด 1,000 ชิ้นเหล่านั้น และฟีเจอร์ระดับล่างอีกหลายรายการที่ใช้ระบุวัตถุดังกล่าวอาจเป็นประโยชน์ในการตรวจหาวัตถุใหม่ๆ ที่ไม่เคยเห็นมาก่อนด้วยเช่นกัน ท้ายที่สุดแล้ว ทุกอย่างก็เป็นเพียงการผสมเส้น พื้นผิว และรูปทรงเท่านั้น

ลองมาดูสถาปัตยกรรมโครงข่ายระบบประสาทเทียมแบบ Convolutional (CNN) แบบดั้งเดิม (คล้ายกับ MobileNet) และดูว่าการเรียนรู้แบบโอนจะใช้ประโยชน์จากเครือข่ายที่ได้รับการฝึกอบรมนี้เพื่อเรียนรู้สิ่งใหม่ๆ ได้อย่างไร รูปภาพด้านล่างแสดงสถาปัตยกรรมโมเดลทั่วไปของ CNN ซึ่งในกรณีนี้ได้รับการฝึกให้รู้จำตัวเลขที่เขียนด้วยลายมือตั้งแต่ 0 ถึง 9

baf4e3d434576106.png

หากคุณสามารถแยกเลเยอร์ระดับล่างที่ได้รับการฝึกล่วงหน้าของโมเดลที่มีอยู่ฝึกแล้วดังเช่นที่แสดงทางด้านซ้าย จากเลเยอร์การแยกประเภทบริเวณใกล้สุดของโมเดลที่แสดงทางด้านขวา (บางครั้งเรียกว่าส่วนหัวการจัดประเภทของโมเดล) คุณสามารถใช้เลเยอร์ระดับล่างเพื่อสร้างฟีเจอร์เอาต์พุตสำหรับรูปภาพหนึ่งๆ โดยอิงตามข้อมูลต้นฉบับที่ฝึกโมเดลดังกล่าว เครือข่ายเดียวกันซึ่งนำส่วนหัวการแยกประเภทออกมีดังนี้

369a8a9041c6917d.png

สมมติว่าสิ่งใหม่ที่คุณพยายามจดจำสามารถใช้ประโยชน์จากคุณลักษณะเอาต์พุตที่โมเดลก่อนหน้านี้เรียนรู้มาได้ ก็มีโอกาสสูงที่จะสามารถนำข้อมูลเหล่านั้นมาใช้กับวัตถุประสงค์ใหม่ได้

ในแผนภาพด้านบน แบบจำลองสมมตินี้ได้รับการฝึกด้วยตัวเลข ดังนั้นสิ่งที่ได้เรียนรู้เกี่ยวกับตัวเลขยังสามารถประยุกต์ใช้กับตัวอักษร เช่น a, b และ c ได้ด้วย

ตอนนี้คุณจึงสามารถเพิ่มส่วนหัวการแยกประเภทใหม่ที่พยายามคาดการณ์ a, b หรือ c แทน ดังที่แสดงด้านล่าง

db97e5e60ae73bbd.png

เลเยอร์ระดับล่างจะถูกตรึงไว้และไม่ได้รับการฝึก เฉพาะส่วนหัวการแยกประเภทใหม่เท่านั้นที่จะอัปเดตตัวเองเพื่อเรียนรู้จากฟีเจอร์ที่มีให้จากโมเดลที่แยกส่วนก่อนการฝึกทางด้านซ้าย

กระบวนการนี้เรียกว่า "การเรียนรู้แบบโอน" และเป็นสิ่งที่ Teachable Machine ทำอยู่เบื้องหลัง

นอกจากนี้ คุณยังจะเห็นได้ว่าเพียงต้องฝึกการรับรู้แบบหลายเลเยอร์ในตอนท้ายของเครือข่าย ก็ทำให้ฝึกได้เร็วกว่าที่ต้องฝึกทั้งเครือข่ายตั้งแต่ต้น

แต่คุณจะทำความเข้าใจเกี่ยวกับส่วนย่อยของโมเดลได้อย่างไร ไปที่ส่วนถัดไปเพื่อดูคำตอบ

4. TensorFlow Hub - โมเดลฐาน

ค้นหาโมเดลฐานที่เหมาะสมที่จะใช้

สำหรับโมเดลการวิจัยขั้นสูงและได้รับความนิยมมากขึ้น เช่น MobileNet ให้ไปที่ TensorFlow Hub จากนั้นกรองโมเดลที่เหมาะสำหรับ TensorFlow.js ที่ใช้สถาปัตยกรรม MobileNet v3 เพื่อดูผลลัพธ์แบบเดียวกับที่แสดงที่นี่

c5dc1420c6238c14.png

โปรดทราบว่าผลการค้นหาบางส่วนจัดอยู่ในประเภท "การจำแนกประเภทรูปภาพ" (รายละเอียดอยู่ที่ด้านซ้ายบนของผลลัพธ์การ์ดโมเดลแต่ละรายการ) และรายการอื่นๆ เป็นประเภท "เวกเตอร์ฟีเจอร์ภาพ"

ผลลัพธ์เวกเตอร์ของคุณลักษณะรูปภาพเหล่านี้เป็น MobileNet เวอร์ชันก่อนการตัดทอนที่คุณสามารถใช้เพื่อรับเวกเตอร์คุณลักษณะภาพแทนการแยกประเภทขั้นสุดท้ายได้

โมเดลแบบนี้มักจะเรียกว่า "โมเดลฐาน" ซึ่งคุณจะใช้เพื่อดำเนินการเรียนรู้การโอนในลักษณะเดียวกับที่แสดงในส่วนก่อนหน้านี้ได้โดยการเพิ่มส่วนหัวการแยกประเภทใหม่และฝึกด้วยข้อมูลของคุณเอง

สิ่งต่อไปที่ต้องตรวจสอบคือโมเดลฐานที่สนใจซึ่งเผยแพร่โมเดล TensorFlow.js ในรูปแบบดังกล่าว หากคุณเปิดหน้าเว็บสำหรับโมเดล MobileNet v3 ของฟีเจอร์ใดๆ แบบเวกเตอร์เหล่านี้ คุณจะดูได้จากเอกสาร JS ว่าโมเดลดังกล่าวอยู่ในรูปแบบกราฟตามตัวอย่างข้อมูลโค้ดในเอกสารประกอบที่ใช้ tf.loadGraphModel()

f97d903d2e46924b.png

และควรทราบด้วยว่า ถ้าคุณพบโมเดลในรูปแบบเลเยอร์แทนที่จะเป็นรูปแบบกราฟ คุณสามารถเลือกได้ว่าจะให้เลเยอร์ใดตรึงหรือตรึงเลเยอร์ใดสำหรับการฝึก ซึ่งจะมีประโยชน์มากเมื่อสร้างโมเดลสำหรับงานใหม่ ซึ่งมักเรียกว่า "รูปแบบการโอน" ในตอนนี้ คุณจะใช้ประเภทโมเดลกราฟเริ่มต้นในบทแนะนำนี้ ซึ่งโมเดล TF Hub ส่วนใหญ่จะใช้งาน หากต้องการดูข้อมูลเพิ่มเติมเกี่ยวกับการใช้งานโมเดลเลเยอร์ ให้ดูหลักสูตร TensorFlow.js จาก Zero to Hero

ข้อดีของการเรียนรู้การโอน

ประโยชน์ของการใช้การเรียนรู้แบบโอนแทนการฝึกสถาปัตยกรรมโมเดลทั้งหมดตั้งแต่ต้นคืออะไร

ข้อแรก เวลาฝึกอบรมเป็นข้อดีที่สำคัญของการใช้แนวทางการเรียนรู้แบบโอน เนื่องจากคุณมีโมเดลพื้นฐานที่ผ่านการฝึกอบรมเพื่อต่อยอดแล้ว

ประการที่ 2 คุณสามารถหลีกเลี่ยงการแสดงตัวอย่างของสิ่งใหม่ที่คุณพยายามแยกประเภทได้น้อยลงเนื่องจากการฝึกที่ได้ไปแล้ว

ซึ่งมีประโยชน์มากหากคุณมีเวลาและทรัพยากรจำกัดในการรวบรวมข้อมูลตัวอย่างของสิ่งที่คุณต้องการจัดประเภท และจำเป็นต้องสร้างต้นแบบอย่างรวดเร็วก่อนที่จะรวบรวมข้อมูลการฝึกเพิ่มเติมเพื่อทำให้มีประสิทธิภาพมากขึ้น

เนื่องจากจำเป็นต้องใช้ข้อมูลน้อยลงและความเร็วในการฝึกเครือข่ายขนาดเล็ก ระบบจึงใช้ทรัพยากรน้อยลงเพื่อการเรียนรู้แบบโอน วิธีนี้เหมาะสำหรับสภาพแวดล้อมของเบราว์เซอร์ โดยใช้เวลาเพียง 10 วินาทีในเครื่องรุ่นใหม่ แทนที่จะเป็นหลายชั่วโมง วัน หรือสัปดาห์สำหรับการฝึกโมเดลเต็มรูปแบบ

ใช่แล้ว! ตอนนี้เมื่อรู้แก่นแท้ของ Transfer Learning แล้ว ก็ถึงเวลาสร้าง Teachable Machine ในเวอร์ชันของคุณเอง มาเริ่มกันเลย

5. เตรียมเขียนโค้ด

สิ่งที่คุณต้องมี

มาเขียนโค้ดกัน

เทมเพลตแบบ Boilerplate ที่จะเริ่มต้นนั้นสร้างขึ้นสำหรับ Glitch.com หรือ Codepen.io คุณสามารถโคลนเทมเพลตใดเทมเพลตหนึ่งเป็นสถานะฐานสำหรับห้องทดลองโค้ดนี้ได้ด้วยการคลิกเพียงครั้งเดียว

ใน Glitch ให้คลิกปุ่ม "รีมิกซ์นี้" เพื่อแยกและสร้างชุดไฟล์ใหม่ที่แก้ไขได้

นอกจากนี้ ใน Codepen ให้คลิก" fork" ที่ด้านล่างขวาของหน้าจอ

โครงสร้างง่ายๆ นี้จะมีไฟล์ต่อไปนี้ให้คุณ

  • หน้า HTML (index.html)
  • สไตล์ชีต (style.css)
  • ไฟล์สำหรับเขียนโค้ด JavaScript (script.js)

เพื่อความสะดวก จึงมีการนำเข้าเพิ่มเข้ามาในไฟล์ HTML สำหรับไลบรารี TensorFlow.js ซึ่งมีลักษณะดังนี้

index.html

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>

ทางเลือก: ใช้โปรแกรมแก้ไขเว็บที่ต้องการหรือทำงานในเครื่อง

หากคุณต้องการดาวน์โหลดโค้ดและทำงานภายในเครื่อง หรือในตัวแก้ไขออนไลน์อื่นๆ ก็เพียงแค่สร้างไฟล์ 3 ไฟล์ที่มีชื่อข้างต้นในไดเรกทอรีเดียวกัน แล้วคัดลอกและวางโค้ดจากต้นแบบ Glitch ลงในแต่ละไฟล์

6. เอกสารต้นแบบสำหรับ HTML ของแอป

ฉันต้องไปที่ไหน

ต้นแบบทั้งหมดจำเป็นต้องมีโครงสร้าง HTML พื้นฐานที่คุณสามารถแสดงผลผลลัพธ์ของคุณได้ ตั้งค่าตอนนี้เลย คุณกำลังจะเพิ่ม:

  • ชื่อสำหรับหน้า
  • ข้อความอธิบาย
  • ย่อหน้าสถานะ
  • วิดีโอสำหรับเก็บฟีดเว็บแคมเมื่อพร้อมแล้ว
  • มีปุ่มหลายปุ่มในการเริ่มใช้กล้อง รวบรวมข้อมูล หรือรีเซ็ตประสบการณ์
  • การนำเข้าไฟล์ TensorFlow.js และ JS ที่คุณจะเขียนโค้ดภายหลัง

เปิด index.html แล้ววางทับโค้ดที่มีอยู่พร้อมด้วยรายการต่อไปนี้เพื่อตั้งค่าฟีเจอร์ข้างต้น

index.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Transfer Learning - TensorFlow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="/style.css">
  </head>  
  <body>
    <h1>Make your own "Teachable Machine" using Transfer Learning with MobileNet v3 in TensorFlow.js using saved graph model from TFHub.</h1>
    
    <p id="status">Awaiting TF.js load</p>
    
    <video id="webcam" autoplay muted></video>
    
    <button id="enableCam">Enable Webcam</button>
    <button class="dataCollector" data-1hot="0" data-name="Class 1">Gather Class 1 Data</button>
    <button class="dataCollector" data-1hot="1" data-name="Class 2">Gather Class 2 Data</button>
    <button id="train">Train &amp; Predict!</button>
    <button id="reset">Reset</button>

    <!-- Import TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0/dist/tf.min.js" type="text/javascript"></script>

    <!-- Import the page's JavaScript to do some stuff -->
    <script type="module" src="/script.js"></script>
  </body>
</html>

ดูรายละเอียด

มาแบ่งโค้ด HTML ข้างต้นบางข้อเพื่อเน้นสิ่งสำคัญที่คุณได้เพิ่มเข้าไป

  • คุณเพิ่มแท็ก <h1> สำหรับชื่อหน้า พร้อมด้วยแท็ก <p> ที่มีรหัส "สถานะ" ซึ่งเป็นที่ที่คุณจะพิมพ์ข้อมูลลงไป ในขณะที่ใช้ส่วนต่างๆ ของระบบเพื่อดูเอาต์พุต
  • คุณเพิ่มองค์ประกอบ <video> ที่มีรหัส "เว็บแคม" ซึ่งคุณจะแสดงผลสตรีมเว็บแคมของคุณในภายหลัง
  • คุณเพิ่มองค์ประกอบ <button> 5 รายการแล้ว องค์ประกอบแรกที่มีรหัส "enableCam" เปิดใช้กล้อง ปุ่ม 2 ปุ่มถัดไปมีคลาสเป็น "dataCollector" ซึ่งให้คุณรวบรวมรูปภาพตัวอย่างของวัตถุที่คุณต้องการจดจำ โค้ดที่คุณเขียนภายหลังได้รับการออกแบบเพื่อให้คุณสามารถเพิ่มปุ่มจำนวนเท่าใดก็ได้ และปุ่มเหล่านั้นจะทำงานตามที่ต้องการโดยอัตโนมัติ

โปรดทราบว่าปุ่มเหล่านี้ยังมีแอตทริบิวต์พิเศษที่กำหนดโดยผู้ใช้ที่เรียกว่า data-1hot โดยมีค่าจำนวนเต็มเริ่มต้นที่ 0 สำหรับคลาสแรก นี่คือดัชนีตัวเลขที่คุณจะใช้แสดงข้อมูลของชั้นเรียน ระบบจะใช้ดัชนีเพื่อเข้ารหัสคลาสเอาต์พุตอย่างถูกต้องด้วยการแสดงตัวเลขแทนสตริง เนื่องจากโมเดล ML ทำงานได้กับตัวเลขเท่านั้น

นอกจากนี้ ยังมีแอตทริบิวต์ data-name ซึ่งมีชื่อที่มนุษย์อ่านได้ที่คุณต้องการใช้สำหรับคลาสนี้ ซึ่งจะช่วยให้คุณระบุชื่อที่มีความหมายมากขึ้นสำหรับผู้ใช้แทนค่าดัชนีตัวเลขจากการเข้ารหัส Hot 1 ได้

สุดท้าย คุณมีปุ่มฝึกและรีเซ็ตเพื่อเริ่มต้นกระบวนการฝึกเมื่อรวบรวมข้อมูลแล้ว หรือรีเซ็ตแอปตามลำดับ

  • คุณเพิ่มการนำเข้า <script> อีก 2 รายการแล้ว ลิงก์หนึ่งสำหรับ TensorFlow.js และอีกแท็กสำหรับ Script.js ที่คุณจะกำหนดในไม่ช้า

7. เพิ่มสไตล์

ค่าเริ่มต้นขององค์ประกอบ

เพิ่มสไตล์สำหรับองค์ประกอบ HTML ที่คุณเพิ่งเพิ่มเพื่อให้แน่ใจว่าจะแสดงอย่างถูกต้อง สไตล์ที่เพิ่มไปยังองค์ประกอบการวางตำแหน่งและขนาดอย่างถูกต้องมีดังนี้ ไม่มีอะไรพิเศษเกินไป คุณเพิ่มวิธีนี้ในภายหลังเพื่อสร้าง UX ที่ดียิ่งขึ้นได้ อย่างที่เห็นในวิดีโอแมชชีนที่สอนได้

style.css

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

h1 {
  font-style: italic;
  color: #FF6F00;
}


video {
  clear: both;
  display: block;
  margin: 10px;
  background: #000000;
  width: 640px;
  height: 480px;
}

button {
  padding: 10px;
  float: left;
  margin: 5px 3px 5px 10px;
}

.removed {
  display: none;
}

#status {
  font-size:150%;
}

เยี่ยม! เท่านี้ก็เรียบร้อย หากคุณดูตัวอย่างเอาต์พุตในขณะนี้ โค้ดควรมีลักษณะดังนี้

81909685d7566dcb.png

8. JavaScript: ค่าคงที่หลักและ Listener

ระบุค่าคงที่หลัก

ก่อนอื่นให้ใส่ค่าคงที่หลักที่คุณจะใช้ทั่วทั้งแอป เริ่มต้นด้วยการแทนที่เนื้อหาของ script.js ด้วยค่าคงที่ต่อไปนี้

script.js

const STATUS = document.getElementById('status');
const VIDEO = document.getElementById('webcam');
const ENABLE_CAM_BUTTON = document.getElementById('enableCam');
const RESET_BUTTON = document.getElementById('reset');
const TRAIN_BUTTON = document.getElementById('train');
const MOBILE_NET_INPUT_WIDTH = 224;
const MOBILE_NET_INPUT_HEIGHT = 224;
const STOP_DATA_GATHER = -1;
const CLASS_NAMES = [];

มาดูรายละเอียดกันว่าสิ่งเหล่านี้มีไว้เพื่ออะไร

  • STATUS จะใช้เพียงการอ้างอิงแท็กย่อหน้าที่คุณจะเขียนอัปเดตสถานะ
  • VIDEO มีการอ้างอิงองค์ประกอบวิดีโอ HTML ที่จะแสดงฟีดเว็บแคม
  • ENABLE_CAM_BUTTON, RESET_BUTTON และ TRAIN_BUTTON จะใช้การอ้างอิง DOM ไปยังปุ่มคีย์ทั้งหมดจากหน้า HTML
  • MOBILE_NET_INPUT_WIDTH และ MOBILE_NET_INPUT_HEIGHT กำหนดความกว้างและความสูงของอินพุตที่คาดไว้ของโมเดล MobileNet ตามลำดับ ด้วยการเก็บไว้ในค่าคงที่ใกล้กับด้านบนของไฟล์เช่นนี้ หากคุณตัดสินใจที่จะใช้เวอร์ชันอื่นในภายหลัง ก็จะช่วยให้อัปเดตค่าเพียงครั้งเดียวได้ง่ายขึ้น แทนที่จะต้องแทนที่ด้วยหลายๆ ที่
  • ตั้งค่า STOP_DATA_GATHER เป็น - 1 การดำเนินการนี้จะจัดเก็บค่าสถานะเพื่อให้คุณทราบเมื่อผู้ใช้หยุดคลิกปุ่มเพื่อรวบรวมข้อมูลจากฟีดเว็บแคม การตั้งชื่อให้ตัวเลขนี้มีความหมายมากขึ้นจะทำให้อ่านรหัสได้ง่ายขึ้นในภายหลัง
  • CLASS_NAMES ทำหน้าที่เป็นการค้นหาและเก็บชื่อที่มนุษย์อ่านได้สำหรับการคาดคะเนคลาสที่เป็นไปได้ ระบบจะป้อนข้อมูลอาร์เรย์นี้ในภายหลัง

โอเค เมื่อคุณมีการอ้างอิงไปยังองค์ประกอบหลักแล้ว ก็ถึงเวลาเชื่อมโยง Listener เหตุการณ์กับองค์ประกอบเหล่านั้น

เพิ่ม Listener เหตุการณ์สําคัญ

เริ่มต้นด้วยการเพิ่มเครื่องจัดการเหตุการณ์การคลิกลงในปุ่มคีย์ดังที่แสดงด้านล่างนี้

script.js

ENABLE_CAM_BUTTON.addEventListener('click', enableCam);
TRAIN_BUTTON.addEventListener('click', trainAndPredict);
RESET_BUTTON.addEventListener('click', reset);


function enableCam() {
  // TODO: Fill this out later in the codelab!
}


function trainAndPredict() {
  // TODO: Fill this out later in the codelab!
}


function reset() {
  // TODO: Fill this out later in the codelab!
}

ENABLE_CAM_BUTTON - เรียกใช้ฟังก์ชัน enabledCam เมื่อคลิก

TRAIN_BUTTON - เรียกใช้ trialAndPredict เมื่อคลิก

RESET_BUTTON - การโทรจะรีเซ็ตเมื่อคลิก

สุดท้ายในส่วนนี้ คุณจะเห็นปุ่มทั้งหมดที่มีคลาสเป็น "dataCollector" ด้วย document.querySelectorAll() การดำเนินการนี้จะแสดงอาร์เรย์ขององค์ประกอบที่พบจากเอกสารที่ตรงกับรายการต่อไปนี้

script.js

let dataCollectorButtons = document.querySelectorAll('button.dataCollector');
for (let i = 0; i < dataCollectorButtons.length; i++) {
  dataCollectorButtons[i].addEventListener('mousedown', gatherDataForClass);
  dataCollectorButtons[i].addEventListener('mouseup', gatherDataForClass);
  // Populate the human readable names for classes.
  CLASS_NAMES.push(dataCollectorButtons[i].getAttribute('data-name'));
}


function gatherDataForClass() {
  // TODO: Fill this out later in the codelab!
}

คำอธิบายโค้ด

จากนั้นทำซ้ำผ่านปุ่มที่พบและเชื่อมโยง Listener เหตุการณ์ 2 รายการกับแต่ละปุ่ม อันหนึ่งสำหรับ "mousedown" และอีกอันสำหรับ "mouseup" วิธีนี้จะช่วยให้คุณบันทึกตัวอย่างได้ตราบใดที่กดปุ่ม ซึ่งมีประโยชน์ในการเก็บรวบรวมข้อมูล

ทั้ง 2 เหตุการณ์จะเรียกใช้ฟังก์ชัน gatherDataForClass ที่คุณจะกําหนดในภายหลัง

ในจุดนี้ คุณยังพุชชื่อคลาสที่มนุษย์อ่านเจอได้จากแอตทริบิวต์ data-name ของปุ่ม HTML ไปยังอาร์เรย์ CLASS_NAMES ได้ด้วย

จากนั้นเพิ่มตัวแปรเพื่อจัดเก็บสิ่งสำคัญที่จะใช้ในภายหลัง

script.js

let mobilenet = undefined;
let gatherDataState = STOP_DATA_GATHER;
let videoPlaying = false;
let trainingDataInputs = [];
let trainingDataOutputs = [];
let examplesCount = [];
let predict = false;

ลองมาดูกันที่จุดเหล่านั้น

ประการแรก คุณมีตัวแปร mobilenet เพื่อจัดเก็บรูปแบบ Mobilenet ที่โหลด โดยเริ่มแรกตั้งค่านี้เป็น "ไม่ระบุ"

ถัดไปคือตัวแปรชื่อ gatherDataState หาก "เครื่องมือเก็บข้อมูล" เมื่อกดปุ่ม การเปลี่ยนแปลงนี้จะเป็น Hot ID 1 ของปุ่มนั้นแทน ตามที่กำหนดใน HTML เพื่อที่คุณจะได้ทราบว่ากำลังเก็บข้อมูลในระดับใดอยู่ ในตอนแรก การตั้งค่านี้เป็น STOP_DATA_GATHER เพื่อให้ข้อมูลที่รวบรวมวนซ้ำที่คุณเขียนในภายหลังจะไม่รวบรวมข้อมูลใดๆ เมื่อไม่มีปุ่มกด

videoPlaying ติดตามว่าสตรีมจากเว็บแคมโหลดและเล่นสำเร็จหรือไม่ และพร้อมให้ใช้งานหรือไม่ ในตอนแรก การตั้งค่านี้จะตั้งเป็น false เนื่องจากเว็บแคมจะไม่เปิดจนกว่าคุณจะกด ENABLE_CAM_BUTTON.

ถัดไป ให้กำหนดอาร์เรย์ 2 รายการ คือ trainingDataInputs และ trainingDataOutputs รายการเหล่านี้จะจัดเก็บค่าข้อมูลการฝึกที่รวบรวมไว้ เมื่อคุณคลิก "เครื่องมือเก็บข้อมูล" สำหรับฟีเจอร์อินพุตที่สร้างโดยโมเดลฐานของ MobileNet และคลาสเอาต์พุตสุ่มตัวอย่างตามลำดับ

ระบบจะกำหนดอาร์เรย์สุดท้าย 1 รายการ examplesCount, เพื่อติดตามจำนวนตัวอย่างที่มีอยู่สำหรับแต่ละคลาสเมื่อคุณเริ่มเพิ่มอาร์เรย์

ขั้นสุดท้าย คุณมีตัวแปรชื่อ predict ซึ่งควบคุมลูปการคาดการณ์ของคุณ ค่าเริ่มต้นนี้ตั้งไว้เป็น false ไม่สามารถคาดการณ์ได้จนกว่าจะมีการตั้งค่านี้เป็น true ในภายหลัง

ตอนนี้เราได้กำหนดตัวแปรหลักทั้งหมดแล้ว เรามาโหลดโมเดลฐานของ MobileNet v3 ที่มีการสับเปลี่ยนซึ่งแสดงเวกเตอร์ฟีเจอร์ของรูปภาพแทนการแยกประเภทกัน

9. โหลดโมเดลฐานของ MobileNet

ขั้นแรก ให้กำหนดฟังก์ชันใหม่ที่เรียกว่า loadMobileNetFeatureModel ดังที่แสดงด้านล่าง นี่ต้องเป็นฟังก์ชันแบบไม่พร้อมกันเนื่องจากการโหลดโมเดลเป็นแบบไม่พร้อมกัน

script.js

/**
 * Loads the MobileNet model and warms it up so ready for use.
 **/
async function loadMobileNetFeatureModel() {
  const URL = 
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1';
  
  mobilenet = await tf.loadGraphModel(URL, {fromTFHub: true});
  STATUS.innerText = 'MobileNet v3 loaded successfully!';
  
  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    let answer = mobilenet.predict(tf.zeros([1, MOBILE_NET_INPUT_HEIGHT, MOBILE_NET_INPUT_WIDTH, 3]));
    console.log(answer.shape);
  });
}

// Call the function immediately to start loading.
loadMobileNetFeatureModel();

ในโค้ดนี้ คุณกำหนด URL ซึ่งโมเดลที่จะโหลดนั้นอยู่จากเอกสารประกอบของ TFHub

จากนั้นคุณสามารถโหลดโมเดลโดยใช้ await tf.loadGraphModel() รวมถึงตั้งค่าพร็อพเพอร์ตี้พิเศษ fromTFHub เป็น true ขณะโหลดโมเดลจากเว็บไซต์ Google นี้ นี่เป็นกรณีพิเศษเฉพาะเมื่อใช้โมเดลที่โฮสต์ใน TF Hub ที่ต้องตั้งค่าพร็อพเพอร์ตี้เพิ่มเติมนี้

เมื่อโหลดเสร็จแล้ว คุณสามารถตั้งค่า innerText ขององค์ประกอบ STATUS ด้วยข้อความเพื่อให้มองเห็นได้ว่าข้อความโหลดอย่างถูกต้อง และคุณพร้อมที่จะเริ่มรวบรวมข้อมูลแล้ว

สิ่งเดียวที่ต้องทำในตอนนี้คือเตรียมโมเดลให้พร้อม ด้วยโมเดลขนาดใหญ่เช่นนี้ ในครั้งแรกที่คุณใช้โมเดล อาจใช้เวลาสักครู่เพื่อตั้งค่าทุกอย่าง ดังนั้นการส่งเลข 0 ผ่านโมเดลจึงช่วยหลีกเลี่ยงการรอในอนาคตซึ่งช่วงเวลาที่สำคัญกว่า

คุณสามารถใช้ tf.zeros() ที่รวมไว้ใน tf.tidy() เพื่อให้แน่ใจว่ามีการกำจัด Tensor อย่างถูกต้องด้วยขนาดกลุ่มเป็น 1 รวมถึงความสูงและความกว้างที่ถูกต้องตามที่คุณกำหนดในค่าคงที่ตั้งแต่เริ่มต้น สุดท้าย คุณต้องระบุช่องสี ซึ่งในกรณีนี้คือ 3 ตามที่โมเดลต้องการรูปภาพ RGB

ถัดไป ให้บันทึกรูปร่างผลลัพธ์ของ tensor ที่แสดงผลโดยใช้ answer.shape() เพื่อช่วยให้คุณเข้าใจขนาดของฟีเจอร์รูปภาพที่โมเดลนี้สร้างขึ้น

หลังจากกำหนดฟังก์ชันนี้แล้ว คุณสามารถเรียกใช้ได้ทันทีเพื่อเริ่มการดาวน์โหลดโมเดลในการโหลดหน้าเว็บ

หากดูตัวอย่างสดในขณะนี้ คุณจะเห็นข้อความสถานะเปลี่ยนจาก "กำลังรอโหลด TF.js" หลังจากนั้นสักครู่ ให้กลายเป็น "โหลด MobileNet v3 สำเร็จแล้ว!" ดังที่แสดงด้านล่าง โปรดตรวจสอบว่าวิธีนี้ได้ผลก่อนดำเนินการต่อ

a28b734e190afff.png

คุณสามารถตรวจสอบเอาต์พุตของคอนโซลเพื่อดูขนาดที่พิมพ์ของฟีเจอร์เอาต์พุตที่โมเดลสร้างขึ้นได้ หลังจากเรียกใช้เลขศูนย์ในโมเดล MobileNet คุณจะเห็นรูปร่าง [1, 1024] ที่พิมพ์ออกมา รายการแรกมีขนาดกลุ่มเท่ากับ 1 ซึ่งคุณจะเห็นได้ว่าจริงๆ แล้วแสดงฟีเจอร์ 1, 024 อย่างที่นำไปใช้แยกประเภทออบเจ็กต์ใหม่ได้

10. กำหนดส่วนหัวของโมเดลใหม่

ตอนนี้ได้เวลากำหนดหัวโมเดลของคุณ ซึ่งโดยพื้นฐานแล้วก็คือ Perceptron หลายชั้นที่มีขนาดเล็กมาก

script.js

let model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: CLASS_NAMES.length, activation: 'softmax'}));

model.summary();

// Compile the model with the defined optimizer and specify a loss function to use.
model.compile({
  // Adam changes the learning rate over time which is useful.
  optimizer: 'adam',
  // Use the correct loss function. If 2 classes of data, must use binaryCrossentropy.
  // Else categoricalCrossentropy is used if more than 2 classes.
  loss: (CLASS_NAMES.length === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy', 
  // As this is a classification problem you can record accuracy in the logs too!
  metrics: ['accuracy']  
});

ลองมาดูโค้ดนี้อย่างละเอียดกัน คุณเริ่มต้นโดยการกำหนดโมเดล tf.sอันดับ ที่จะเพิ่มเลเยอร์โมเดลลงไป

ต่อไป ให้เพิ่มเลเยอร์ที่หนาแน่นเป็นเลเยอร์อินพุตในโมเดลนี้ ซึ่งมีรูปแบบอินพุตเป็น 1024 เนื่องจากเอาต์พุตจากฟีเจอร์ MobileNet v3 มีขนาดเท่านี้ คุณค้นพบขั้นตอนนี้ได้ในขั้นตอนก่อนหน้าหลังจากที่ส่งผ่านโมเดล ชั้นนี้มีเซลล์ประสาท 128 เซลล์ที่ใช้ฟังก์ชันการเปิดใช้งาน ReLU

หากคุณยังไม่คุ้นเคยกับฟังก์ชันการเปิดใช้งานและเลเยอร์โมเดล โปรดลองศึกษาหลักสูตรนี้โดยละเอียดเมื่อเริ่มต้นเวิร์กช็อปนี้ เพื่อทำความเข้าใจว่าพร็อพเพอร์ตี้เหล่านี้ทำหน้าที่อะไรในเบื้องหลัง

เลเยอร์ถัดไปที่จะเพิ่มคือเลเยอร์เอาต์พุต จำนวนเซลล์ประสาทควรเท่ากับจำนวนคลาสที่คุณพยายามคาดการณ์ ในการทำเช่นนี้ คุณสามารถใช้ CLASS_NAMES.length เพื่อดูจำนวนชั้นเรียนที่คุณต้องการจัดประเภท ซึ่งเท่ากับจำนวนปุ่มรวบรวมข้อมูลที่พบในอินเทอร์เฟซผู้ใช้ เนื่องจากนี่เป็นปัญหาการจัดประเภท คุณใช้การเปิดใช้งาน softmax ในเลเยอร์เอาต์พุตนี้ ซึ่งต้องใช้เมื่อพยายามสร้างโมเดลเพื่อแก้ปัญหาการจัดประเภทแทนการถดถอย

ตอนนี้ให้พิมพ์ model.summary() เพื่อพิมพ์ภาพรวมของโมเดลที่กำหนดใหม่ไปยังคอนโซล

สุดท้าย ให้คอมไพล์โมเดลเพื่อให้พร้อมสำหรับการฝึก ในที่นี้ เครื่องมือเพิ่มประสิทธิภาพได้รับการตั้งค่าเป็น adam และการสูญเสียจะเป็น binaryCrossentropy หาก CLASS_NAMES.length เท่ากับ 2 หรือจะใช้ categoricalCrossentropy หากมีคลาสที่จะแยกประเภทตั้งแต่ 3 คลาสขึ้นไป นอกจากนี้ยังมีการขอเมตริกความแม่นยำด้วย เพื่อให้สามารถตรวจสอบในบันทึกในภายหลังเพื่อจุดประสงค์ในการแก้ไขข้อบกพร่อง

คุณจะเห็นข้อมูลต่อไปนี้ในคอนโซล

22eaf32286fea4bb.png

โปรดทราบว่าพารามิเตอร์นี้มีพารามิเตอร์ที่ฝึกได้กว่า 130,000 รายการ แต่เนื่องจากนี่เป็นชั้นเซลล์ประสาทปกติที่หนาแน่นง่ายๆ มันจะฝึกได้อย่างรวดเร็ว

สำหรับกิจกรรมที่ต้องทำเมื่อทำโครงงานเสร็จ คุณอาจลองเปลี่ยนจำนวนเซลล์ประสาทในชั้นแรก เพื่อดูว่าจะอยู่ในระดับต่ำแค่ไหนโดยที่ยังมีประสิทธิภาพดีอยู่ แมชชีนเลิร์นนิงมักมีการลองผิดลองถูกในระดับหนึ่งเพื่อหาค่าพารามิเตอร์ที่เหมาะสมที่สุดเพื่อให้คุณได้รับตัวเลือกที่ดีที่สุดระหว่างการใช้งานทรัพยากรกับความเร็ว

11. เปิดใช้เว็บแคม

ได้เวลาที่จะแสดงฟังก์ชัน enableCam() ที่คุณกำหนดไว้ก่อนหน้านี้แล้ว เพิ่มฟังก์ชันใหม่ชื่อ hasGetUserMedia() ดังที่แสดงด้านล่าง แล้วแทนที่เนื้อหาของฟังก์ชัน enableCam() ที่กำหนดไว้ก่อนหน้านี้ด้วยโค้ดที่เกี่ยวข้องด้านล่าง

script.js

function hasGetUserMedia() {
  return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
}

function enableCam() {
  if (hasGetUserMedia()) {
    // getUsermedia parameters.
    const constraints = {
      video: true,
      width: 640, 
      height: 480 
    };

    // Activate the webcam stream.
    navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
      VIDEO.srcObject = stream;
      VIDEO.addEventListener('loadeddata', function() {
        videoPlaying = true;
        ENABLE_CAM_BUTTON.classList.add('removed');
      });
    });
  } else {
    console.warn('getUserMedia() is not supported by your browser');
  }
}

ก่อนอื่น ให้สร้างฟังก์ชันชื่อ hasGetUserMedia() เพื่อตรวจสอบว่าเบราว์เซอร์รองรับ getUserMedia() หรือไม่ โดยตรวจสอบการมีอยู่ของพร็อพเพอร์ตี้ API ที่สำคัญของเบราว์เซอร์

ในฟังก์ชัน enableCam() ให้ใช้ฟังก์ชัน hasGetUserMedia() ที่คุณเพิ่งกำหนดไว้ข้างต้นเพื่อตรวจสอบว่าระบบรองรับหรือไม่ หากไม่เป็นเช่นนั้น ให้พิมพ์คำเตือนไปยังคอนโซล

หากระบบรองรับ ให้กำหนดข้อจำกัดบางประการสำหรับการโทร getUserMedia() เช่น คุณต้องการสตรีมวิดีโอเท่านั้น และต้องการให้ width ของวิดีโอมีขนาด 640 พิกเซลและ height เป็น 480 พิกเซล เหตุผล เอาล่ะ ไม่มีประเด็นอะไรมากที่จะทำให้วิดีโอมีขนาดใหญ่กว่านี้ เนื่องจากจะต้องปรับขนาดเป็น 224 x 224 พิกเซลเพื่อใส่ลงไปในโมเดล MobileNet คุณอาจประหยัดทรัพยากรในการประมวลผลได้บางส่วนด้วยการขอความละเอียดที่น้อยลง กล้องส่วนใหญ่รองรับความละเอียดขนาดนี้

จากนั้น ให้โทรไปที่ navigator.mediaDevices.getUserMedia() พร้อมกับแจ้ง constraints ตามรายละเอียดข้างต้น แล้วรอให้ระบบส่งคืน stream เมื่อมีการส่งคืน stream แล้ว คุณสามารถให้องค์ประกอบ VIDEO เล่น stream ได้โดยตั้งค่าองค์ประกอบนั้นเป็นค่า srcObject

คุณควรเพิ่ม eventListener ในองค์ประกอบ VIDEO ด้วยเพื่อให้ทราบเมื่อ stream โหลดและเล่นสำเร็จ

เมื่อ Steam โหลดแล้ว คุณสามารถตั้ง videoPlaying เป็น "จริง" และนำ ENABLE_CAM_BUTTON ออกเพื่อป้องกันไม่ให้เกิดการคลิกซ้ำโดยตั้งค่าคลาสเป็น "removed"

ในตอนนี้ให้เรียกใช้โค้ด คลิกปุ่ม "เปิดใช้กล้อง" และอนุญาตการเข้าถึงเว็บแคม หากนี่เป็นครั้งแรกของคุณ คุณควรจะเห็นว่าตัวเองแสดงผลในองค์ประกอบวิดีโอบนหน้าเว็บดังที่ปรากฏ:

b378eb1affa9b883.png

เอาล่ะ ตอนนี้ก็ถึงเวลาเพิ่มฟังก์ชันเพื่อจัดการการคลิกปุ่ม dataCollector แล้ว

12. เครื่องจัดการเหตุการณ์ปุ่มการเก็บรวบรวมข้อมูล

ได้เวลากรอกข้อมูลลงในฟังก์ชันว่างปัจจุบันที่เรียกว่า gatherDataForClass(). แล้ว นี่คือสิ่งที่คุณกำหนดให้เป็นฟังก์ชันเครื่องจัดการกิจกรรมสำหรับปุ่ม dataCollector ที่จุดเริ่มต้นของ Codelab

script.js

/**
 * Handle Data Gather for button mouseup/mousedown.
 **/
function gatherDataForClass() {
  let classNumber = parseInt(this.getAttribute('data-1hot'));
  gatherDataState = (gatherDataState === STOP_DATA_GATHER) ? classNumber : STOP_DATA_GATHER;
  dataGatherLoop();
}

ก่อนอื่น ให้ตรวจสอบแอตทริบิวต์ data-1hot บนปุ่มที่มีการคลิกอยู่ในปัจจุบัน โดยเรียก this.getAttribute() ด้วยชื่อของแอตทริบิวต์ ซึ่งในกรณีนี้คือ data-1hot เป็นพารามิเตอร์ เนื่องจากนี่คือสตริง คุณจะใช้ parseInt() เพื่อแคสต์เป็นจำนวนเต็มและกำหนดผลลัพธ์นี้ให้กับตัวแปรชื่อ classNumber. ได้

จากนั้นให้ตั้งค่าตัวแปร gatherDataState ตามที่เหมาะสม หาก gatherDataState ปัจจุบันเท่ากับ STOP_DATA_GATHER (ซึ่งคุณตั้งค่าเป็น -1) หมายความว่าคุณไม่ได้รวบรวมข้อมูลใดๆ และเป็นเหตุการณ์ mousedown ที่เริ่มทำงาน ตั้งค่า gatherDataState เป็นclassNumberที่คุณเพิ่งพบ

มิเช่นนั้น หมายความว่าคุณกำลังรวบรวมข้อมูลอยู่และเหตุการณ์ที่เริ่มทำงานคือเหตุการณ์ mouseup และตอนนี้ต้องการหยุดรวบรวมข้อมูลสำหรับชั้นเรียนนั้น เพียงตั้งค่ากลับไปเป็นสถานะ STOP_DATA_GATHER เพื่อสิ้นสุดลูปการรวบรวมข้อมูลที่คุณจะกำหนดเร็วๆ นี้

สุดท้าย เริ่มการโทรถึง dataGatherLoop(), ซึ่งจะเป็นผู้บันทึกข้อมูลของชั้นเรียน

13. การรวบรวมข้อมูล

จากนั้นกำหนดฟังก์ชัน dataGatherLoop() ฟังก์ชันนี้มีหน้าที่สุ่มตัวอย่างรูปภาพจากวิดีโอเว็บแคม ส่งผ่านโมเดล MobileNet และจับภาพเอาต์พุตของโมเดลนั้น (เวกเตอร์ฟีเจอร์ 1024)

จากนั้นจะจัดเก็บไว้พร้อมกับรหัส gatherDataState ของปุ่มที่กดอยู่เพื่อให้คุณทราบว่าข้อมูลนี้แสดงถึงคลาสใด

ลองมาดูกันทีละข้อ

script.js

function dataGatherLoop() {
  if (videoPlaying && gatherDataState !== STOP_DATA_GATHER) {
    let imageFeatures = tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);
      let normalizedTensorFrame = resizedTensorFrame.div(255);
      return mobilenet.predict(normalizedTensorFrame.expandDims()).squeeze();
    });

    trainingDataInputs.push(imageFeatures);
    trainingDataOutputs.push(gatherDataState);
    
    // Intialize array index element if currently undefined.
    if (examplesCount[gatherDataState] === undefined) {
      examplesCount[gatherDataState] = 0;
    }
    examplesCount[gatherDataState]++;

    STATUS.innerText = '';
    for (let n = 0; n < CLASS_NAMES.length; n++) {
      STATUS.innerText += CLASS_NAMES[n] + ' data count: ' + examplesCount[n] + '. ';
    }
    window.requestAnimationFrame(dataGatherLoop);
  }
}

คุณจะดำเนินการต่อไปของฟังก์ชันนี้ต่อเมื่อ videoPlaying เป็นจริงเท่านั้น ซึ่งหมายความว่าเว็บแคมทำงานอยู่และ gatherDataState ไม่เท่ากับ STOP_DATA_GATHER และปุ่มการรวบรวมข้อมูลของชั้นเรียนกำลังถูกกดอยู่

จากนั้นให้รวมโค้ดไว้ใน tf.tidy() เพื่อกำจัด Tensor ที่สร้างขึ้นในโค้ดที่ตามมา ผลลัพธ์ของการเรียกใช้โค้ด tf.tidy() นี้ได้รับการจัดเก็บไว้ในตัวแปรชื่อ imageFeatures

ตอนนี้คุณสามารถจับเฟรมของเว็บแคม VIDEO โดยใช้ tf.browser.fromPixels() Tensor ผลลัพธ์ที่มีข้อมูลรูปภาพจะได้รับการจัดเก็บไว้ในตัวแปรชื่อ videoFrameAsTensor

จากนั้น ปรับขนาดตัวแปร videoFrameAsTensor ให้มีรูปร่างที่ถูกต้องสำหรับอินพุตของโมเดล MobileNet ใช้การเรียก tf.image.resizeBilinear() กับ Tensor ที่คุณต้องการปรับรูปร่างเป็นพารามิเตอร์แรก จากนั้นเป็นรูปร่างที่กำหนดความสูงและความกว้างใหม่ตามที่กำหนดโดยค่าคงที่ที่คุณสร้างขึ้นก่อนหน้านี้ สุดท้าย ให้ตั้งค่าการจัดแนวมุมเป็น "จริง" โดยการส่งพารามิเตอร์ที่ 3 เพื่อหลีกเลี่ยงปัญหาการปรับแนวเมื่อปรับขนาด ผลลัพธ์ของการปรับขนาดนี้จัดเก็บไว้ในตัวแปรที่ชื่อ resizedTensorFrame

โปรดทราบว่าการปรับขนาดแบบเดิมนี้จะยืดรูปภาพ เนื่องจากภาพเว็บแคมของคุณมีขนาด 640 x 480 พิกเซล และโมเดลต้องการรูปภาพสี่เหลี่ยมจัตุรัสขนาด 224 x 224 พิกเซล

การสาธิตนี้น่าจะดำเนินไปได้ด้วยดี อย่างไรก็ตาม เมื่อเสร็จสิ้น Codelab แล้ว คุณอาจต้องการลองครอบตัดสี่เหลี่ยมจัตุรัสจากรูปภาพนี้แทน เพื่อให้ได้ผลลัพธ์ที่ดียิ่งขึ้นสำหรับระบบการผลิตที่คุณสามารถสร้างในภายหลัง

จากนั้น ปรับข้อมูลภาพให้เป็นมาตรฐาน ข้อมูลรูปภาพจะอยู่ในช่วง 0 ถึง 255 เสมอเมื่อใช้ tf.browser.frompixels() คุณเพียงแบ่งสัดส่วน TensorFrame ด้วย 255 เพื่อให้ค่าทั้งหมดอยู่ระหว่าง 0 ถึง 1 แทน ซึ่งเป็นสิ่งที่โมเดล MobileNet คาดว่าจะเป็นอินพุต

สุดท้าย ในส่วน tf.tidy() ของโค้ด ให้พุช tensor ที่ปรับให้เป็นมาตรฐานนี้ผ่านโมเดลที่โหลดโดยการเรียกใช้ mobilenet.predict() ซึ่งจะส่งเวอร์ชันขยายของ normalizedTensorFrame โดยใช้ expandDims() เพื่อให้เป็นชุด 1 ตามที่โมเดลคาดว่าจะได้รับชุดอินพุตสำหรับการประมวลผล

เมื่อผลลัพธ์กลับมา คุณจะเรียกใช้ squeeze() ได้ทันทีที่แสดงผลลัพธ์เพื่อบีบกลับเป็น Tensor แบบ 1 มิติ ซึ่งคุณจะแสดงผลและกำหนดให้กับตัวแปร imageFeatures ที่รวบรวมผลลัพธ์จาก tf.tidy()

ตอนนี้คุณมี imageFeatures จากโมเดล MobileNet แล้ว คุณสามารถบันทึกโมเดลเหล่านั้นโดยพุชไปยังอาร์เรย์ trainingDataInputs ที่คุณกำหนดไว้ก่อนหน้านี้

นอกจากนี้ คุณยังบันทึกสิ่งที่อินพุตนี้แสดงได้โดยการพุช gatherDataState ปัจจุบันไปยังอาร์เรย์ trainingDataOutputs ได้เช่นกัน

โปรดทราบว่าระบบจะตั้งค่าตัวแปร gatherDataState เป็นรหัสตัวเลขของชั้นเรียนปัจจุบันที่คุณกำลังบันทึกข้อมูลเมื่อมีการคลิกปุ่มในฟังก์ชัน gatherDataForClass() ที่กำหนดไว้ก่อนหน้านี้

ในจุดนี้ คุณยังเพิ่มจำนวนตัวอย่างที่มีสำหรับชั้นเรียนที่ระบุได้อีกด้วย โดยก่อนอื่น ให้ตรวจสอบว่าดัชนีภายในอาร์เรย์ examplesCount เริ่มต้นมาก่อนแล้วหรือไม่ หากไม่ได้กำหนด ให้กำหนดค่าเป็น 0 เพื่อเริ่มต้นตัวนับสำหรับรหัสตัวเลขของชั้นเรียนที่ระบุ จากนั้นคุณสามารถเพิ่ม examplesCount สำหรับ gatherDataState ปัจจุบัน

จากนั้นให้อัปเดตข้อความขององค์ประกอบ STATUS ในหน้าเว็บเพื่อแสดงจำนวนปัจจุบันของแต่ละชั้นเรียนเมื่อระบบเก็บข้อมูลแล้ว โดยให้วนซ้ำอาร์เรย์ CLASS_NAMES แล้วพิมพ์ชื่อที่มนุษย์อ่านได้รวมกับจำนวนข้อมูลที่ดัชนีเดียวกันใน examplesCount

สุดท้าย เรียกใช้ window.requestAnimationFrame() ด้วย dataGatherLoop ที่ส่งผ่านเป็นพารามิเตอร์ เพื่อเรียกใช้ฟังก์ชันนี้ซ้ำแบบวนซ้ำ การดำเนินการนี้จะใช้เฟรมตัวอย่างจากวิดีโอต่อไปจนกว่าจะตรวจพบ mouseup ของปุ่ม และตั้งค่า gatherDataState เป็น STOP_DATA_GATHER, ซึ่งเป็นจุดที่วนซ้ำการรวบรวมข้อมูลจะสิ้นสุดลง

หากคุณเรียกใช้โค้ดตอนนี้ คุณควรสามารถคลิกปุ่มเปิดใช้กล้อง รอให้เว็บแคมโหลด จากนั้นคลิกที่ปุ่มรวบรวมข้อมูลแต่ละปุ่มค้างไว้เพื่อรวบรวมตัวอย่างสำหรับข้อมูลแต่ละคลาส ตรงนี้คุณเห็นผมรวบรวมข้อมูลสำหรับโทรศัพท์มือถือและมือของฉันตามลำดับ

541051644a45131f.gif

คุณควรเห็นข้อความสถานะที่อัปเดตเนื่องจากจัดเก็บ Tensor ทั้งหมดไว้ในหน่วยความจำดังที่แสดงในการจับภาพหน้าจอด้านบน

14. ฝึกและคาดการณ์

ขั้นตอนถัดไปคือการติดตั้งใช้งานโค้ดสำหรับฟังก์ชัน trainAndPredict() ที่ว่างเปล่าในปัจจุบัน ซึ่งเป็นโอกาสในการเรียนรู้การโอน ลองมาดูโค้ดกัน

script.js

async function trainAndPredict() {
  predict = false;
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs);
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32');
  let oneHotOutputs = tf.oneHot(outputsAsTensor, CLASS_NAMES.length);
  let inputsAsTensor = tf.stack(trainingDataInputs);
  
  let results = await model.fit(inputsAsTensor, oneHotOutputs, {shuffle: true, batchSize: 5, epochs: 10, 
      callbacks: {onEpochEnd: logProgress} });
  
  outputsAsTensor.dispose();
  oneHotOutputs.dispose();
  inputsAsTensor.dispose();
  predict = true;
  predictLoop();
}

function logProgress(epoch, logs) {
  console.log('Data for epoch ' + epoch, logs);
}

ขั้นตอนแรก โปรดตรวจสอบว่าคุณหยุดไม่ให้การคาดการณ์ปัจจุบันเกิดขึ้นโดยการตั้งค่า predict เป็น false

ถัดไป ให้สับเปลี่ยนอาร์เรย์อินพุตและเอาต์พุตโดยใช้ tf.util.shuffleCombo() เพื่อให้แน่ใจว่าลำดับจะไม่ทำให้เกิดปัญหาในการฝึก

แปลงอาร์เรย์เอาต์พุต trainingDataOutputs, เป็น tensor1d ของประเภท int32 เพื่อให้พร้อมใช้งานในการเข้ารหัสแบบ Hot Ent32 ซึ่งจัดเก็บไว้ในตัวแปรชื่อ outputsAsTensor

ใช้ฟังก์ชัน tf.oneHot() กับตัวแปร outputsAsTensor นี้ร่วมกับจำนวนคลาสสูงสุดที่จะเข้ารหัส ซึ่งมีเพียง CLASS_NAMES.length เท่านั้น ตอนนี้เอาต์พุตที่เข้ารหัสแบบ Hot 1 รายการของคุณจัดเก็บไว้ใน Tensor ใหม่ที่ชื่อ oneHotOutputs แล้ว

โปรดทราบว่าปัจจุบัน trainingDataInputs เป็นอาร์เรย์ของ Tensor ที่บันทึกไว้ ในการใช้สิ่งเหล่านี้ในการฝึก คุณจะต้องแปลงอาร์เรย์ของ Tensor ให้เป็น Tensor แบบ 2 มิติปกติ

ซึ่งการทำเช่นนั้นจะมีฟังก์ชันที่ยอดเยี่ยมภายในไลบรารี TensorFlow.js ที่ชื่อ tf.stack()

ซึ่งจะนำอาร์เรย์ของ tensor มาซ้อนกันเพื่อสร้าง tensor มิติที่สูงขึ้นเป็นเอาต์พุต ในกรณีนี้ ระบบจะแสดงผล tensor 2D ซึ่งเป็นกลุ่มอินพุตมิติข้อมูล 1 รายการที่มีความยาว 1024 แต่ละรายการมีฟีเจอร์ที่บันทึกไว้ ซึ่งเป็นสิ่งที่คุณต้องการสําหรับการฝึก

ต่อไป await model.fit() เพื่อฝึกหัวโมเดลที่กำหนดเอง ในส่วนนี้ คุณจะส่งตัวแปร inputsAsTensor ไปพร้อมกับ oneHotOutputs เพื่อแสดงถึงข้อมูลการฝึกที่จะใช้เป็นตัวอย่างอินพุตและเอาต์พุตเป้าหมายตามลำดับ ในออบเจ็กต์การกำหนดค่าสำหรับพารามิเตอร์ที่ 3 ให้ตั้งค่า shuffle เป็น true ใช้ batchSize จาก 5 โดยตั้งค่า epochs เป็น 10 จากนั้นระบุ callback สำหรับ onEpochEnd เป็นฟังก์ชัน logProgress ที่คุณจะกำหนดในไม่ช้า

ขั้นตอนสุดท้าย คุณสามารถกำจัด Tensor ที่สร้างขึ้นเมื่อโมเดลได้รับการฝึกแล้ว จากนั้นคุณสามารถตั้งค่า predict กลับไปเป็น true เพื่ออนุญาตให้การคาดคะเนเกิดขึ้นอีกครั้ง จากนั้นเรียกใช้ฟังก์ชัน predictLoop() เพื่อเริ่มคาดการณ์รูปภาพเว็บแคมสด

นอกจากนี้คุณยังสามารถกำหนดฟังก์ชัน logProcess() เพื่อบันทึกสถานะของการฝึก ซึ่งจะใช้ใน model.fit() ด้านบนและแสดงผลไปยังคอนโซลหลังรอบการฝึกแต่ละครั้งได้อีกด้วย

เกือบถูกแล้ว ถึงเวลาเพิ่มฟังก์ชัน predictLoop() เพื่อทำการคาดการณ์แล้ว

ลูปการคาดการณ์หลัก

ที่นี่คุณสามารถใช้ลูปการคาดการณ์หลักที่สุ่มตัวอย่างเฟรมจากเว็บแคม และคาดการณ์สิ่งที่อยู่ในแต่ละเฟรมอย่างต่อเนื่องพร้อมผลลัพธ์แบบเรียลไทม์ในเบราว์เซอร์

มาตรวจสอบโค้ดกัน

script.js

function predictLoop() {
  if (predict) {
    tf.tidy(function() {
      let videoFrameAsTensor = tf.browser.fromPixels(VIDEO).div(255);
      let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor,[MOBILE_NET_INPUT_HEIGHT, 
          MOBILE_NET_INPUT_WIDTH], true);

      let imageFeatures = mobilenet.predict(resizedTensorFrame.expandDims());
      let prediction = model.predict(imageFeatures).squeeze();
      let highestIndex = prediction.argMax().arraySync();
      let predictionArray = prediction.arraySync();

      STATUS.innerText = 'Prediction: ' + CLASS_NAMES[highestIndex] + ' with ' + Math.floor(predictionArray[highestIndex] * 100) + '% confidence';
    });

    window.requestAnimationFrame(predictLoop);
  }
}

ก่อนอื่นให้ตรวจสอบว่า predict เป็นจริง เพื่อให้มีการคาดการณ์หลังจากโมเดลได้รับการฝึกและพร้อมใช้งานเท่านั้น

ถัดไป คุณจะรับฟีเจอร์รูปภาพของรูปภาพปัจจุบันได้แบบเดียวกับที่ทำในฟังก์ชัน dataGatherLoop() คุณจะนำเฟรมหนึ่งมาจากเว็บแคมโดยใช้ tf.browser.from pixels() ปรับให้เป็นมาตรฐาน ปรับขนาดเป็นขนาด 224 x 224 พิกเซล จากนั้นส่งข้อมูลดังกล่าวผ่านโมเดล MobileNet เพื่อให้ได้ฟีเจอร์รูปภาพที่เป็นผลลัพธ์

อย่างไรก็ตาม ตอนนี้คุณสามารถใช้หัวโมเดลที่เพิ่งฝึกเพื่อคาดการณ์จริงๆ ได้โดยการส่ง imageFeatures ผลลัพธ์ที่เพิ่งพบผ่านฟังก์ชัน predict() ของโมเดลที่ผ่านการฝึก จากนั้นคุณจะบีบ Tensor ที่ได้เพื่อทำให้เป็น 1 มิติอีกครั้ง แล้วกำหนดให้กับตัวแปรที่ชื่อ prediction ได้

ด้วย prediction นี้ คุณจะหาดัชนีที่มีค่าสูงสุดโดยใช้ argMax() จากนั้นแปลง tensor ที่ได้นี้เป็นอาร์เรย์โดยใช้ arraySync() เพื่อให้ได้ข้อมูลที่สําคัญใน JavaScript เพื่อค้นหาตําแหน่งขององค์ประกอบที่มีมูลค่าสูงสุด ค่านี้จัดเก็บไว้ในตัวแปรที่ชื่อ highestIndex

คุณยังรับคะแนนความเชื่อมั่นของการคาดการณ์จริงได้ในลักษณะเดียวกันโดยเรียกใช้ arraySync() ใน Tensor ของ prediction โดยตรง

ตอนนี้คุณมีทุกอย่างที่จำเป็นในการอัปเดตข้อความ STATUS ด้วยข้อมูล prediction หากต้องการรับสตริงที่มนุษย์อ่านได้สำหรับคลาส เพียงค้นหา highestIndex ในอาร์เรย์ CLASS_NAMES แล้วใช้ค่าความเชื่อมั่นจาก predictionArray หากต้องการอ่านเป็นเปอร์เซ็นต์ได้ง่ายขึ้น เพียงคูณผลลัพธ์ด้วย 100 และ math.floor()

สุดท้ายนี้ คุณสามารถใช้ window.requestAnimationFrame() เพื่อโทรหา predictionLoop() อีกครั้งเมื่อพร้อม เพื่อรับการแยกประเภทแบบเรียลไทม์ในสตรีมวิดีโอ การดําเนินการนี้จะดำเนินต่อไปจนกว่าจะมีการตั้งค่า predict เป็น false หากคุณเลือกฝึกโมเดลใหม่ด้วยข้อมูลใหม่

มาถึงการไขปริศนาแล้ว การใช้ปุ่มรีเซ็ต

15. ใช้ปุ่มรีเซ็ต

เกือบเสร็จแล้ว! ปริศนาชิ้นสุดท้ายคือการใช้ปุ่มรีเซ็ตเพื่อเริ่มต้นใหม่ โค้ดสำหรับฟังก์ชัน reset() ที่ว่างเปล่าในปัจจุบันของคุณคือ ให้อัปเดตโดยทำดังนี้

script.js

/**
 * Purge data and start over. Note this does not dispose of the loaded 
 * MobileNet model and MLP head tensors as you will need to reuse 
 * them to train a new model.
 **/
function reset() {
  predict = false;
  examplesCount.length = 0;
  for (let i = 0; i < trainingDataInputs.length; i++) {
    trainingDataInputs[i].dispose();
  }
  trainingDataInputs.length = 0;
  trainingDataOutputs.length = 0;
  STATUS.innerText = 'No data collected';
  
  console.log('Tensors in memory: ' + tf.memory().numTensors);
}

ก่อนอื่น ให้หยุดการวนซ้ำการคาดการณ์ที่ทำงานอยู่โดยตั้งค่า predict เป็น false ถัดไป ลบเนื้อหาทั้งหมดในอาร์เรย์ examplesCount โดยตั้งค่าความยาวเป็น 0 ซึ่งเป็นวิธีที่สะดวกในการล้างเนื้อหาทั้งหมดจากอาร์เรย์

ตอนนี้ให้ตรวจสอบ trainingDataInputs ที่บันทึกไว้ทั้งหมดในปัจจุบันและตรวจสอบว่าได้ dispose() ของแต่ละ Tensor ที่มีอยู่ภายในเพื่อเพิ่มหน่วยความจำอีกครั้ง เนื่องจาก Tensor ไม่ได้ทำความสะอาดโดยตัวรวบรวมขยะของ JavaScript

เมื่อเสร็จแล้ว คุณสามารถกำหนดความยาวของอาร์เรย์เป็น 0 ได้อย่างปลอดภัยทั้งในอาร์เรย์ trainingDataInputs และ trainingDataOutputs เพื่อล้างอาร์เรย์

สุดท้ายให้ตั้งค่าข้อความ STATUS ให้เป็นข้อความที่เหมาะสม แล้วพิมพ์ Tensor ที่เหลืออยู่ในความทรงจำเพื่อตรวจสอบความถูกต้อง

โปรดทราบว่าจะมี 200-000 Tensor ที่ยังหลงเหลืออยู่ในหน่วยความจำ เนื่องจากทั้งโมเดล MobileNet และ Perceptron แบบหลายเลเยอร์ที่คุณกำหนดไว้ไม่ได้ถูกกำจัดทิ้ง คุณจะต้องนำมาใช้ใหม่กับข้อมูลการฝึกใหม่หากตัดสินใจฝึกอีกครั้งหลังจากการรีเซ็ตนี้

16. มาลองใช้กันเลย

ได้เวลาทดสอบ Teachable Machine ในเวอร์ชันของคุณเองแล้ว

ไปยังการแสดงตัวอย่างแบบสด เปิดใช้เว็บแคม รวบรวมตัวอย่างอย่างน้อย 30 รายการสำหรับคลาส 1 สำหรับวัตถุบางอย่างในห้อง จากนั้นทำแบบเดียวกันนี้กับคลาส 2 สำหรับวัตถุอื่น คลิกฝึก และตรวจสอบบันทึกของคอนโซลเพื่อดูความคืบหน้า ระบบจะฝึกได้อย่างรวดเร็ว:

bf1ac3cc5b15740.gif

เมื่อฝึกแล้ว ให้ยื่นวัตถุไปที่กล้องเพื่อรับการคาดการณ์แบบเรียลไทม์ซึ่งจะพิมพ์ลงในพื้นที่ข้อความสถานะบนหน้าเว็บใกล้กับด้านบน หากพบปัญหา ให้ตรวจสอบโค้ดการทำงานที่เสร็จสมบูรณ์แล้วเพื่อดูว่าคุณพลาดการคัดลอกอะไรไปหรือไม่

17. ขอแสดงความยินดี

ยินดีด้วย คุณเพิ่งเสร็จสิ้นตัวอย่างการเรียนรู้การโอนครั้งแรกที่ใช้ TensorFlow.js แบบสดในเบราว์เซอร์

ลองใช้ ทดสอบกับวัตถุต่างๆ คุณอาจสังเกตเห็นว่าของบางอย่างจดจำได้ยากกว่ารายการอื่นๆ โดยเฉพาะสิ่งที่คล้ายกับสิ่งอื่น คุณอาจต้องเพิ่มคลาสหรือข้อมูลการฝึกอบรมเพื่อให้แยกความแตกต่างได้

สรุป

คุณได้เรียนรู้สิ่งต่อไปนี้ใน Codelab แล้ว

  1. การเรียนรู้แบบการโอนคืออะไร และข้อดีที่มากกว่าการฝึกโมเดลแบบเต็ม
  2. วิธีรับโมเดลเพื่อนำมาใช้ซ้ำจาก TensorFlow Hub
  3. วิธีตั้งค่าเว็บแอปที่เหมาะสำหรับการโอนการเรียนรู้
  4. วิธีโหลดและใช้รูปแบบฐานเพื่อสร้างฟีเจอร์รูปภาพ
  5. วิธีฝึกหัวการคาดการณ์ใหม่ที่สามารถจดจำออบเจ็กต์ที่กำหนดเองจากภาพเว็บแคม
  6. วิธีใช้แบบจำลองที่ได้เพื่อจัดหมวดหมู่ข้อมูลแบบเรียลไทม์

สิ่งที่ต้องทำต่อไป

ตอนนี้คุณมีฐานการทำงานที่จะเริ่มต้นแล้ว ไอเดียสร้างสรรค์ใดที่คุณคิดได้เพื่อต่อยอดต้นแบบของโมเดลแมชชีนเลิร์นนิงนี้สำหรับกรณีการใช้งานในชีวิตจริงที่คุณอาจกำลังทำอยู่ บางทีคุณอาจปฏิวัติวงการที่คุณทำงานอยู่เพื่อช่วยพนักงานในบริษัทฝึกโมเดลให้แยกประเภทสิ่งต่างๆ ที่สำคัญในการทำงานแต่ละวัน ทุกอย่างเป็นไปได้

หากต้องการก้าวต่อไป ให้ลองเข้าร่วมหลักสูตรนี้ฟรี ซึ่งแสดงวิธีรวม 2 โมเดลที่คุณมีอยู่ใน Codelab นี้เป็นโมเดลเดียวเพื่อประสิทธิภาพ

นอกจากนี้ หากคุณสงสัยเพิ่มเติมเกี่ยวกับทฤษฎีเบื้องหลังแอปพลิเคชันเดิมที่ใช้เครื่องที่สอนได้ โปรดดูบทแนะนำนี้

แชร์สิ่งที่คุณทำกับเรา

คุณสามารถต่อยอดสิ่งที่คุณทำในวันนี้ไปยังกรณีการใช้งานที่สร้างสรรค์อื่นๆ ได้ง่ายๆ เช่นกัน และเราขอแนะนำให้คุณคิดนอกกรอบและแฮ็กไปเรื่อยๆ

อย่าลืมแท็กเราบนโซเชียลมีเดียโดยใช้แฮชแท็ก #MadeWithTFJS เพื่อลุ้นโอกาสให้โปรเจ็กต์ของคุณปรากฏบนบล็อก TensorFlow หรือแม้แต่กิจกรรมในอนาคต เราอยากเห็นสิ่งที่คุณสร้าง

เว็บไซต์ที่น่าสนใจ