TensorFlow.js 전이 학습 이미지 분류기

이 Codelab에서는 강력하면서도 유연한 자바스크립트 머신러닝 라이브러리인 TensorFlow.js를 사용해 브라우저에서 바로 학습시킬 간단한 커스텀 이미지 분류기인 'Teachable Machine'을 빌드하는 방법을 알아봅니다. 먼저 브라우저에서 이미지 분류에 널리 사용되는 선행 학습된 모델(MobileNet)을 로드하고 실행합니다. 그런 다음 '전이 학습'이라는 기술을 사용합니다. 이 기술은 선행 학습된 MobileNet 모델로 학습을 부트스트랩하고 맞춤설정하여 애플리케이션에 맞게 학습할 수 있도록 합니다.

이 Codelab에서는 Teachable Machine 애플리케이션 배경 이론에 대해서는 다루지 않습니다. 관련 내용은 이 튜토리얼을 참조하세요.

학습 내용

  • 선행 학습된 MobileNet 모델을 로드하고 새 데이터에 대한 예측을 수행하는 방법
  • 웹캠을 통해 예측을 수행하는 방법
  • MobileNet의 중간 활성화를 사용하여 웹캠으로 바로 정의한 새 클래스 집합에서 전이 학습을 수행하는 방법

그럼 시작해 보겠습니다.

이 Codelab을 완료하려면 다음이 필요합니다.

  1. 최신 버전의 Chrome 또는 다른 최신 브라우저
  2. 머신에서 로컬로 실행되거나 Codepen 또는 Glitch 등을 통해 웹에서 실행되는 텍스트 편집기
  3. HTML, CSS, 자바스크립트, Chrome DevTools(또는 선호하는 브라우저의 개발 도구)에 대한 지식
  4. 신경망에 대한 대략적인 개념 이해. 소개 또는 복습이 필요하다면 3blue1brown의 동영상 또는 아시 크리슈난의 자바스크립트 딥 러닝 동영상을 확인하세요.

편집기에서 index.html을 열고 다음 내용을 추가합니다.

<html>
  <head>
    <!-- Load the latest version of TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
  </head>
  <body>
    <div id="console"></div>
    <!-- Add an image that we will use to test -->
    <img id="img" crossorigin src="https://i.imgur.com/JlUvsxa.jpg" width="227" height="227"/>
    <!-- Load index.js after the content of the page -->
    <script src="index.js"></script>
  </body>
</html>

다음으로 코드 편집기에서 index.js 파일을 열거나 만든 후 다음 코드를 추가합니다.

let net;

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Make a prediction through the model on our image.
  const imgEl = document.getElementById('img');
  const result = await net.classify(imgEl);
  console.log(result);
}

app();

웹페이지를 실행하려면 웹브라우저에서 index.html을 열기만 하면 됩니다. Cloud Console을 사용하는 경우에는 미리보기 페이지를 새로고침해 보세요.

개 사진이 표시되고 개발자 도구의 자바스크립트 콘솔에는 MobileNet의 상위 예측이 표시됩니다. 모델을 다운로드하는 데는 다소 시간이 걸릴 수 있으니 잠시 기다려 주세요.

이미지가 올바르게 분류되었나요?

휴대전화에서도 작동한다는 점도 알아두시기 바랍니다.

이제 실시간으로 보다 상호작용이 가능하도록 만들어 보겠습니다. 웹캠을 통해 들어오는 이미지에 대한 예측을 하도록 웹캠을 설정해 보겠습니다.

먼저 웹캠 동영상 요소를 설정합니다. index.html 파일을 열고 <body> 섹션 안에 다음 줄을 추가한 다음 개 이미지를 로드하는 <img> 태그를 삭제합니다.

<video autoplay playsinline muted id="webcam" width="224" height="224"></video>

index.js 파일을 열고 파일의 맨 위에 webcamElement를 추가합니다.

const webcamElement = document.getElementById('webcam');

이제 이전에 추가한 app() 함수에서 이미지를 통한 예측을 삭제하고 대신 웹캠 요소를 통해 예측을 수행하는 무한 루프를 만들 수 있습니다.

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);
  while (true) {
    const img = await webcam.capture();
    const result = await net.classify(img);

    document.getElementById('console').innerText = `
      prediction: ${result[0].className}\n
      probability: ${result[0].probability}
    `;
    // Dispose the tensor to release the memory.
    img.dispose();

    // Give some breathing room by waiting for the next animation frame to
    // fire.
    await tf.nextFrame();
  }
}

웹페이지에서 콘솔을 열면 웹캠에서 수집된 모든 프레임에 대한 확률과 함께 MobileNet 예측이 표시됩니다.

ImageNet 데이터 세트는 일반적으로 웹캠에 표시되는 이미지처럼 보이지 않을 수 있으므로 무의미한 결과로 보일 수 있습니다. 노트북 카메라 앞에서 개 사진이 있는 휴대전화를 들고 테스트해 보시기 바랍니다.

이제 보다 유용하게 만들어 볼까요? 웹캠에서 바로 커스텀 3클래스 객체 분류기를 만들어 보겠습니다. MobileNet을 통해 분류는 하겠지만 이번에는 특정 웹캠 이미지에 대해 모델의 내부 표현(활성화)을 이용해 분류합니다.

'K-Nearest Neighbors Classifier'라는 모듈을 사용해 웹캠 이미지(실제로는 MobileNet 활성화)를 다양한 카테고리(또는 '클래스')에 효과적으로 배치할 수 있으며, 사용자가 예측을 요청하는 경우 예측을 수행하는 대상과 가장 유사한 활성화를 가진 클래스를 간단하게 선택합니다.

index.html에서 <head> 태그의 가져오기 마지막에 KNN Classifier 가져오기를 추가합니다(MobileNet이 계속 필요하므로 가져오기를 삭제하지 마세요).

...
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
...

동영상 요소 아래 index.html의 각 버튼에 3개의 버튼을 추가합니다. 이 3개 버튼은 모델에 학습 이미지를 추가하는 데 사용됩니다.

...
<button id="class-a">Add A</button>
<button id="class-b">Add B</button>
<button id="class-c">Add C</button>
...

index.js 상단에서 분류기를 만듭니다.

const classifier = knnClassifier.create();

다음과 같이 앱 함수를 업데이트합니다.

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);

  // Reads an image from the webcam and associates it with a specific class
  // index.
  const addExample = async classId => {
    // Capture an image from the web camera.
    const img = await webcam.capture();

    // Get the intermediate activation of MobileNet 'conv_preds' and pass that
    // to the KNN classifier.
    const activation = net.infer(img, true);

    // Pass the intermediate activation to the classifier.
    classifier.addExample(activation, classId);

    // Dispose the tensor to release the memory.
    img.dispose();
  };

  // When clicking a button, add an example for that class.
  document.getElementById('class-a').addEventListener('click', () => addExample(0));
  document.getElementById('class-b').addEventListener('click', () => addExample(1));
  document.getElementById('class-c').addEventListener('click', () => addExample(2));

  while (true) {
    if (classifier.getNumClasses() > 0) {
      const img = await webcam.capture();

      // Get the activation from mobilenet from the webcam.
      const activation = net.infer(img, 'conv_preds');
      // Get the most likely class and confidence from the classifier module.
      const result = await classifier.predictClass(activation);

      const classes = ['A', 'B', 'C'];
      document.getElementById('console').innerText = `
        prediction: ${classes[result.label]}\n
        probability: ${result.confidences[result.label]}
      `;

      // Dispose the tensor to release the memory.
      img.dispose();
    }

    await tf.nextFrame();
  }
}

이제 index.html 페이지를 로드할 때 일반 사물 또는 얼굴/신체 동작을 사용하여 세 가지 클래스 각각에 대해 이미지를 캡처할 수 있습니다. '추가' 버튼 중 하나를 클릭할 때마다 이미지 하나가 학습 예시로 추가됩니다. 이 작업을 수행하는 동안 모델은 들어오는 웹캠 이미지를 계속해서 예측하고 그 결과를 실시간으로 표시합니다.

이제 동작이 없는 다른 클래스를 추가해 보세요.

이 Codelab에서는 TensorFlow.js를 사용해 간단한 머신러닝 웹 애플리케이션을 구현했습니다. 웹캠에서 이미지 분류를 위해 선행 학습된 MobileNet 모델을 로드하고 사용했으며, 다음에는 3가지 커스텀 카테고리로 이미지를 분류하도록 모델을 맞춤설정했습니다.

더 많은 예시 및 코드가 포함된 데모가 제공되는 js.tensorflow.org를 방문해 애플리케이션에서 TensorFlow.js를 사용하는 방법에 대해 알아보세요.