Prototype to Production: Distributed training on Vertex AI

1. Overview

In this lab, you'll use Vertex AI Training to run a distributed TensorFlow training job.

This lab is part of the Prototype to Production video series. Be sure to complete the previous labs before trying out this one. You can watch the accompanying video series to learn more.

What you'll learn

You'll learn how to:

  • Run distributed training on a single machine with multiple GPUs
  • Run distributed training across multiple machines

The total cost to run this lab on Google Cloud is about $2.

2. Intro to Vertex AI

This lab uses the newest AI product offering available on Google Cloud. Vertex AI integrates the ML offerings across Google Cloud into a seamless development experience. Previously, models trained with AutoML and custom models were accessible via separate services. The new offering combines both into a single API, along with other new products. You can also migrate existing projects to Vertex AI.

Vertex AI includes many different products to support end-to-end ML workflows. This lab focuses on the products highlighted below: Training and Workbench.

Vertex product overview

3. Distributed training overview

If you have a single GPU, TensorFlow will use this accelerator to speed up model training with no extra work on your part. However, if you want to get an additional boost from using multiple GPUs, then you'll need to use tf.distribute, which is TensorFlow's module for running a computation across multiple devices.

The first section of this lab uses tf.distribute.MirroredStrategy, which you can add to your training application with only a few code changes. This strategy creates a copy of the model on each GPU on your machine and applies gradient updates synchronously. Each GPU computes the forward and backward passes through the model on a different slice of the input data. The gradients computed on each of these slices are then aggregated across all of the GPUs and averaged in a process known as all-reduce, and every copy of the model is updated using these averaged gradients.
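To make the synchronous all-reduce step concrete, here is a framework-free Python sketch on a toy linear model. It is only an illustration of the idea, not how TensorFlow implements it (TensorFlow exchanges gradients between devices with collective communication, not Python lists), and all function names are hypothetical. It shows that when the batch is split into equal slices, averaging the per-replica gradients gives the same update as computing the gradient on the whole batch.

```python
# Toy simulation of synchronous data parallelism (the idea behind
# MirroredStrategy). Each "replica" shares the same parameters, computes a
# gradient on its own slice of the batch, and the gradients are averaged
# (all-reduce) before one shared update is applied.

def mse_gradient(w, b, xs, ys):
    """Gradient of mean squared error for the model y_hat = w * x + b."""
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db


def all_reduce_mean(grads):
    """Average a list of (dw, db) tuples, one tuple per replica."""
    k = len(grads)
    return (sum(g[0] for g in grads) / k, sum(g[1] for g in grads) / k)


def split_batch(xs, ys, num_replicas):
    """Split a global batch into equal, contiguous per-replica slices.

    Assumes len(xs) is divisible by num_replicas.
    """
    per_replica = len(xs) // num_replicas
    return [
        (xs[i * per_replica:(i + 1) * per_replica],
         ys[i * per_replica:(i + 1) * per_replica])
        for i in range(num_replicas)
    ]


def sync_step(w, b, xs, ys, num_replicas, lr=0.01):
    """One synchronous training step across num_replicas replicas."""
    shards = split_batch(xs, ys, num_replicas)
    # Forward and backward pass on each replica's own slice of the data.
    replica_grads = [mse_gradient(w, b, sx, sy) for sx, sy in shards]
    # All-reduce: every replica ends up with the same averaged gradient.
    dw, db = all_reduce_mean(replica_grads)
    # Every replica applies the identical update, so the copies stay in sync.
    return w - lr * dw, b - lr * db


xs = [float(i) for i in range(8)]
ys = [2 * x + 1 for x in xs]
print(sync_step(0.0, 0.0, xs, ys, num_replicas=2))
print(sync_step(0.0, 0.0, xs, ys, num_replicas=1))
```

Both calls print the same parameters (up to floating-point rounding), which is why synchronous data parallelism behaves like training on the full batch on one device.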

The optional section at the end of the lab uses tf.distribute.MultiWorkerMirroredStrategy, which is similar to MirroredStrategy except that it works across multiple machines, each of which might also have multiple GPUs. Like MirroredStrategy, MultiWorkerMirroredStrategy is a synchronous data parallelism strategy that you can use with only a few code changes. The main difference when moving from synchronous data parallelism on one machine to many is that the gradients at the end of each step now need to be synchronized across all GPUs in a machine and across all machines in the cluster.
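The two levels of synchronization can be pictured with a small, hypothetical Python sketch: first average the gradients within each machine, then average those results across machines. This is a conceptual model only, assuming scalar gradients and equal-sized data slices; it is not the collective algorithm MultiWorkerMirroredStrategy actually uses. When every machine has the same number of GPUs, the two-level average equals a flat average over all GPUs in the cluster.

```python
# Conceptual sketch of two-level gradient synchronization for multi-worker
# data parallelism. Gradients are plain floats to keep the example small.

def mean(values):
    return sum(values) / len(values)


def hierarchical_all_reduce(grads_per_machine):
    """grads_per_machine[m][g] is the gradient from GPU g on machine m.

    Step 1 averages within each machine (intra-machine reduce).
    Step 2 averages those per-machine results across machines.
    """
    per_machine = [mean(gpu_grads) for gpu_grads in grads_per_machine]
    return mean(per_machine)


def flat_all_reduce(grads_per_machine):
    """Average every GPU's gradient directly, ignoring machine boundaries."""
    return mean([g for gpu_grads in grads_per_machine for g in gpu_grads])


# Two machines with two GPUs each.
grads = [[1.0, 3.0], [5.0, 7.0]]
print(hierarchical_all_reduce(grads), flat_all_reduce(grads))
```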

You don't need to know the details to complete this lab, but if you want to learn more about how distributed training works in TensorFlow, check out the accompanying video.

4. Set up your environment

Complete the steps in the Training custom models with Vertex AI lab to set up your environment.

5. Single machine, multi GPU training

You'll submit your distributed training job to Vertex AI by putting your training application code in a Docker container and pushing this container to Google Artifact Registry. Using this approach, you can train a model built with any framework.

To start, from the Launcher menu of the Workbench notebook that you created in the previous labs, open a terminal window.

Open terminal in notebook

Step 1: Write training code

Create a new directory called flowers-multi-gpu and cd into it:

mkdir flowers-multi-gpu
cd flowers-multi-gpu

Run the following to create a directory for the training code and a Python file where you'll add the code below.

mkdir trainer
touch trainer/task.py

You should now have the following in your flowers-multi-gpu/ directory:

+ trainer/
    + task.py

Next, open the task.py file you just created and copy the code below.

You'll need to replace {your-gcs-bucket} in BUCKET_ROOT with the Cloud Storage bucket where you stored the flowers dataset in Lab 1.

import tensorflow as tf
import numpy as np
import os

## Replace {your-gcs-bucket} !!
BUCKET_ROOT='/gcs/{your-gcs-bucket}'

# Define variables
NUM_CLASSES = 5
EPOCHS=10
BATCH_SIZE = 32

IMG_HEIGHT = 180
IMG_WIDTH = 180

DATA_DIR = f'{BUCKET_ROOT}/flower_photos'

def create_datasets(data_dir, batch_size):
  '''Creates train and validation datasets.'''

  train_dataset = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size)

  validation_dataset = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size)

  train_dataset = train_dataset.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
  validation_dataset = validation_dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

  return train_dataset, validation_dataset


def create_model():
  '''Creates model.'''

  model = tf.keras.Sequential([
    tf.keras.layers.Resizing(IMG_HEIGHT, IMG_WIDTH),
    tf.keras.layers.Rescaling(1./255, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
  ])
  return model

def main():  

  # Create distribution strategy
  strategy = tf.distribute.MirroredStrategy()

  # Get data
  GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
  # Batch with the global batch size; the strategy splits each batch across the GPUs
  train_dataset, validation_dataset = create_datasets(DATA_DIR, GLOBAL_BATCH_SIZE)

  # Wrap model creation and compilation within scope of strategy
  with strategy.scope():
    model = create_model()
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])

  history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS
  )

  model.save(f'{BUCKET_ROOT}/model_output')


if __name__ == "__main__":
    main()

Before you build the container, let's take a deeper look at the code. There are a few components that are specific to using distributed training.

  • In the main() function, the MirroredStrategy object is created. Next, you wrap the creation of your model variables within the strategy's scope. This step tells TensorFlow which variables should be mirrored across the GPUs.
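  • The batch size is scaled up by num_replicas_in_sync, the number of model replicas (here, one per GPU). Scaling the batch size is a best practice when using synchronous data parallelism strategies in TensorFlow: each batch from the dataset is split evenly across the replicas, so scaling it keeps every GPU processing BATCH_SIZE examples per step.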
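The batch-size arithmetic is simple enough to check by hand. This hypothetical Python sketch computes the global batch size and the resulting number of steps per epoch; the 1,000-example dataset size is an arbitrary illustration, not the size of the flowers dataset.

```python
import math


def global_batch_size(per_replica_batch_size, num_replicas_in_sync):
    """Global batch = per-replica batch * number of replicas."""
    return per_replica_batch_size * num_replicas_in_sync


def steps_per_epoch(num_examples, batch_size):
    """Number of batches in one epoch (the last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


# Hypothetical dataset of 1,000 training examples, BATCH_SIZE = 32.
for gpus in (1, 2):
    gbs = global_batch_size(32, gpus)
    print(gpus, "GPU(s): global batch", gbs, "steps/epoch", steps_per_epoch(1000, gbs))
```

With 2 GPUs, each GPU still sees 32 examples per step, but the epoch takes about half as many steps as with 1 GPU.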

Step 2: Create a Dockerfile

To containerize your code, you'll need to create a Dockerfile, which contains the commands needed to build the image. The base image already includes TensorFlow and the other libraries the training code needs, so the Dockerfile only copies in the training code and sets up its entry point.

From your terminal, create an empty Dockerfile in the root of your flowers-multi-gpu directory:

touch Dockerfile

You should now have the following in your flowers-multi-gpu/ directory:

+ Dockerfile
+ trainer/
    + task.py

Open the Dockerfile and copy the following into it:

FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-8

WORKDIR /

# Copies the trainer code to the docker image.
COPY trainer /trainer

# Sets up the entry point to invoke the trainer.
ENTRYPOINT ["python", "-m", "trainer.task"]

Step 3: Build the container

From your Terminal, run the following to define an env variable for your project, making sure to replace your-cloud-project with the ID of your project:

PROJECT_ID='your-cloud-project'

Define a variable for your Artifact Registry repository. This lab reuses the repo you created in the first lab.

REPO_NAME='flower-app'

Define a variable with the URI of your container image in Artifact Registry:

IMAGE_URI=us-central1-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/flower_image_distributed:single_machine
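The same image URI appears again when you define the training job in the notebook, so it can help to see its structure: LOCATION-docker.pkg.dev/PROJECT_ID/REPOSITORY/IMAGE:TAG. The helper below is a hypothetical convenience function (not part of any Google Cloud library), and "my-project" is a placeholder project ID.

```python
def artifact_registry_image_uri(project_id, repo_name, image_name, tag,
                                location="us-central1"):
    """Build an Artifact Registry Docker image URI.

    Format: LOCATION-docker.pkg.dev/PROJECT_ID/REPOSITORY/IMAGE:TAG
    """
    return f"{location}-docker.pkg.dev/{project_id}/{repo_name}/{image_name}:{tag}"


print(artifact_registry_image_uri("my-project", "flower-app",
                                  "flower_image_distributed", "single_machine"))
```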

Configure Docker to authenticate to Artifact Registry:

gcloud auth configure-docker \
    us-central1-docker.pkg.dev

Then, build the container by running the following from the root of your flowers-multi-gpu directory:

docker build ./ -t $IMAGE_URI

Lastly, push it to Artifact Registry:

docker push $IMAGE_URI

With the container pushed to Artifact Registry, you're now ready to kick off the training job.

Step 4: Run the job with the SDK

In this section, you'll see how to configure and launch the distributed training job by using the Vertex AI Python SDK.

From the Launcher, create a TensorFlow 2 notebook.

Import the Vertex AI SDK.

from google.cloud import aiplatform

Then, define a CustomContainerTrainingJob.

You'll need to replace {PROJECT_ID} in container_uri and {YOUR_BUCKET} in staging_bucket.

job = aiplatform.CustomContainerTrainingJob(display_name='flowers-multi-gpu',
                                            container_uri='us-central1-docker.pkg.dev/{PROJECT_ID}/flower-app/flower_image_distributed:single_machine',
                                            staging_bucket='gs://{YOUR_BUCKET}')

Once the job is defined, you can run it. Set the number of accelerators to 2: a job that uses only 1 GPU isn't considered distributed training, because distributed training on a single machine means using 2 or more accelerators.

job.run(replica_count=1,
        machine_type='n1-standard-4',
        accelerator_type='NVIDIA_TESLA_V100',
        accelerator_count=2)

In the console, you'll be able to see the progress of your job.

6. [Optional] Multi-worker training

Now that you've tried distributed training on a single machine with multiple GPUs, you can take your distributed training skills to the next level by training across multiple machines. To keep costs lower, we won't add any GPUs to those machines, but you could experiment by adding GPUs if you'd like to.

Open a new terminal window in your notebook instance:

Open terminal in notebook

Step 1: Write training code

Create a new directory called flowers-multi-machine and cd into it:

mkdir flowers-multi-machine
cd flowers-multi-machine

Run the following to create a directory for the training code and a Python file where you'll add the code below.

mkdir trainer
touch trainer/task.py

You should now have the following in your flowers-multi-machine/ directory:

+ trainer/
    + task.py

Next, open the task.py file you just created and copy the code below.

You'll need to replace {your-gcs-bucket} in BUCKET_ROOT with the Cloud Storage bucket where you stored the flowers dataset in Lab 1.

import tensorflow as tf
import numpy as np
import os

## Replace {your-gcs-bucket} !!
BUCKET_ROOT='/gcs/{your-gcs-bucket}'

# Define variables
NUM_CLASSES = 5
EPOCHS=10
BATCH_SIZE = 32

IMG_HEIGHT = 180
IMG_WIDTH = 180

DATA_DIR = f'{BUCKET_ROOT}/flower_photos'
SAVE_MODEL_DIR = f'{BUCKET_ROOT}/multi-machine-output'

def create_datasets(data_dir, batch_size):
  '''Creates train and validation datasets.'''

  train_dataset = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size)

  validation_dataset = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size)

  train_dataset = train_dataset.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
  validation_dataset = validation_dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

  return train_dataset, validation_dataset


def create_model():
  '''Creates model.'''

  model = tf.keras.Sequential([
    tf.keras.layers.Resizing(IMG_HEIGHT, IMG_WIDTH),
    tf.keras.layers.Rescaling(1./255, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
  ])
  return model

def _is_chief(task_type, task_id):
  '''Helper function. Determines if machine is chief.'''

  return task_type == 'chief'


def _get_temp_dir(dirpath, task_id):
  '''Helper function. Gets temporary directory for saving model.'''

  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir


def write_filepath(filepath, task_type, task_id):
  '''Helper function. Gets filepath to save model.'''

  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

def main():
  # Create distribution strategy
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  # Get data
  GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
  # Batch with the global batch size; the strategy splits each batch across all replicas
  train_dataset, validation_dataset = create_datasets(DATA_DIR, GLOBAL_BATCH_SIZE)

  # Wrap variable creation within strategy scope
  with strategy.scope():
    model = create_model()
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])

  history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS
  )

  # Determine type and task of the machine from
  # the strategy cluster resolver
  task_type, task_id = (strategy.cluster_resolver.task_type,
                        strategy.cluster_resolver.task_id)

  # Based on the type and task, write to the desired model path
  write_model_path = write_filepath(SAVE_MODEL_DIR, task_type, task_id)
  model.save(write_model_path)

if __name__ == "__main__":
    main()

Before you build the container, let's take a deeper look at the code. There are a few components in the code that are necessary for your training application to work with MultiWorkerMirroredStrategy.

  • In the main() function, the MultiWorkerMirroredStrategy object is created. Next, you wrap the creation of your model variables within the strategy's scope. This crucial step tells TensorFlow which variables should be mirrored across the replicas.
  • The batch size is scaled up by num_replicas_in_sync, which for MultiWorkerMirroredStrategy counts the replicas on every machine in the cluster. Scaling the batch size is a best practice when using synchronous data parallelism strategies in TensorFlow.
  • Saving your model is slightly more complicated in the multi-worker case because the destination needs to be different for each of the workers. The chief worker saves to the desired model directory, while the other workers save the model to temporary directories. These temporary directories must be unique to prevent multiple workers from writing to the same location. Saving can contain collective operations, which means that all workers must save, not just the chief. The functions _is_chief(), _get_temp_dir(), and write_filepath(), as well as the main() function, all include boilerplate code that helps save the model.
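On Vertex AI, the cluster resolver gets task_type and task_id from the TF_CONFIG environment variable that Vertex AI sets on each replica. The sketch below assumes the standard TF_CONFIG layout ({"cluster": ..., "task": {"type": ..., "index": ...}}) and mirrors the lab's write_filepath() logic with plain string paths, without creating any directories, so you can see which path each machine writes to. The function names, host names, and bucket path are hypothetical.

```python
import json
import posixpath


def task_from_tf_config(tf_config_json):
    """Return (task_type, task_id) from a TF_CONFIG JSON string."""
    task = json.loads(tf_config_json)["task"]
    return task["type"], task["index"]


def save_path_for(filepath, task_type, task_id):
    """Same path logic as write_filepath(), minus directory creation."""
    dirpath, base = posixpath.split(filepath)
    if task_type != "chief":
        # Non-chief workers write to a unique temporary directory.
        dirpath = posixpath.join(dirpath, "workertemp_" + str(task_id))
    return posixpath.join(dirpath, base)


# Hypothetical TF_CONFIG for the first machine in Worker pool 1.
example_config = json.dumps({
    "cluster": {"chief": ["host0:2222"], "worker": ["host1:2222"]},
    "task": {"type": "worker", "index": 0},
})
task_type, task_id = task_from_tf_config(example_config)
print(save_path_for("/gcs/my-bucket/multi-machine-output", task_type, task_id))
print(save_path_for("/gcs/my-bucket/multi-machine-output", "chief", 0))
```

Because worker indices are unique within the cluster, each non-chief worker gets its own workertemp_<index> directory, while only the chief writes to the final output path.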

Step 2: Create a Dockerfile

To containerize your code, you'll need to create a Dockerfile, which contains the commands needed to build the image. The base image already includes TensorFlow and the other libraries the training code needs, so the Dockerfile only copies in the training code and sets up its entry point.

From your terminal, create an empty Dockerfile in the root of your flowers-multi-machine directory:

touch Dockerfile

You should now have the following in your flowers-multi-machine/ directory:

+ Dockerfile
+ trainer/
    + task.py

Open the Dockerfile and copy the following into it:

FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-8

WORKDIR /

# Copies the trainer code to the docker image.
COPY trainer /trainer

# Sets up the entry point to invoke the trainer.
ENTRYPOINT ["python", "-m", "trainer.task"]

Step 3: Build the container

From your Terminal, run the following to define an env variable for your project, making sure to replace your-cloud-project with the ID of your project:

PROJECT_ID='your-cloud-project'

Define a variable for your Artifact Registry repository. This lab reuses the repo you created in the first lab.

REPO_NAME='flower-app'

Define a variable with the URI of your container image in Google Artifact Registry:

IMAGE_URI=us-central1-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/flower_image_distributed:multi_machine

Configure Docker to authenticate to Artifact Registry:

gcloud auth configure-docker \
    us-central1-docker.pkg.dev

Then, build the container by running the following from the root of your flowers-multi-machine directory:

docker build ./ -t $IMAGE_URI

Lastly, push it to Artifact Registry:

docker push $IMAGE_URI

With the container pushed to Artifact Registry, you're now ready to kick off the training job.

Step 4: Run the job with the SDK

In this section, you'll see how to configure and launch the distributed training job by using the Vertex AI Python SDK.

From the Launcher, create a TensorFlow 2 notebook.

Import the Vertex AI SDK.

from google.cloud import aiplatform

Then, define the worker_pool_specs.

Vertex AI provides 4 worker pools to cover the different types of machine tasks.

Worker pool 0 configures the primary replica, also called the chief, scheduler, or "master". In MultiWorkerMirroredStrategy, all machines are designated as workers, which are the physical machines on which the replicated computation is executed. In addition to each machine being a worker, one worker must take on some extra work, such as saving checkpoints and writing summary files to TensorBoard. This machine is known as the chief. There is only ever one chief worker, so the replica count for Worker pool 0 is always 1.

Worker pool 1 is where you configure the additional workers for your cluster.

The first dictionary in the worker_pool_specs list represents Worker pool 0, and the second dictionary represents Worker pool 1. In this sample, the two configs are identical. To train across 3 machines instead, you would add workers to Worker pool 1 by setting its replica_count to 2. To add GPUs, add the arguments accelerator_type and accelerator_count to the machine_spec for both worker pools. Note that if you want to use GPUs with MultiWorkerMirroredStrategy, each machine in the cluster must have an identical number of GPUs; otherwise, the job will fail.

You'll need to replace {PROJECT_ID} in the image_uri.

# The spec of the worker pools including machine type and Docker image
# Be sure to replace PROJECT_ID in the "image_uri" with your project.

worker_pool_specs = [
    # Worker pool 0: the chief (always exactly one replica)
    {
        "replica_count": 1,
        "machine_spec": {
            "machine_type": "n1-standard-4",
        },
        "container_spec": {"image_uri": "us-central1-docker.pkg.dev/{PROJECT_ID}/flower-app/flower_image_distributed:multi_machine"},
    },
    # Worker pool 1: the additional workers
    {
        "replica_count": 1,
        "machine_spec": {
            "machine_type": "n1-standard-4",
        },
        "container_spec": {"image_uri": "us-central1-docker.pkg.dev/{PROJECT_ID}/flower-app/flower_image_distributed:multi_machine"},
    },
]
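If you want to experiment with larger clusters, writing the worker_pool_specs list above by hand gets repetitive. The hypothetical helper below (not part of the Vertex AI SDK) builds the same structure for any number of machines: one chief in Worker pool 0 and the remaining machines in Worker pool 1. It gives both pools the same machine_spec, so if you pass accelerator_type and accelerator_count, every machine gets the same number of GPUs.

```python
def build_worker_pool_specs(image_uri, num_machines,
                            machine_type="n1-standard-4",
                            accelerator_type=None, accelerator_count=0):
    """Return a worker_pool_specs list: one chief plus (num_machines - 1) workers."""
    if num_machines < 1:
        raise ValueError("num_machines must be at least 1")

    machine_spec = {"machine_type": machine_type}
    if accelerator_type is not None and accelerator_count > 0:
        # Identical GPU settings for every pool keep GPU counts equal per machine.
        machine_spec["accelerator_type"] = accelerator_type
        machine_spec["accelerator_count"] = accelerator_count

    def pool(replica_count):
        return {
            "replica_count": replica_count,
            "machine_spec": dict(machine_spec),
            "container_spec": {"image_uri": image_uri},
        }

    specs = [pool(1)]                            # Worker pool 0: the chief
    if num_machines > 1:
        specs.append(pool(num_machines - 1))     # Worker pool 1: other workers
    return specs


uri = "us-central1-docker.pkg.dev/{PROJECT_ID}/flower-app/flower_image_distributed:multi_machine"
# Three CPU-only machines: Worker pool 1 gets replica_count=2.
print(build_worker_pool_specs(uri, num_machines=3))
```

Calling it with num_machines=2 produces the same two-pool configuration used in this lab.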

Next, create and run a CustomJob, replacing {YOUR_BUCKET} in staging_bucket with a bucket in your project for staging.

my_custom_job = aiplatform.CustomJob(display_name='flowers-multi-worker',
                                     worker_pool_specs=worker_pool_specs,
                                     staging_bucket='gs://{YOUR_BUCKET}')

my_custom_job.run()

In the console, you'll be able to see the progress of your job.

🎉 Congratulations! 🎉

You've learned how to use Vertex AI to:

  • Run distributed training jobs with TensorFlow

To learn more about different parts of Vertex, check out the documentation.

7. Cleanup

Because we configured the notebook to time out after 60 idle minutes, we don't need to worry about shutting the instance down. If you would like to manually shut down the instance, click the Stop button on the Vertex AI Workbench section of the console. If you'd like to delete the notebook entirely, click the Delete button.

Stop instance

To delete the Storage Bucket, using the Navigation menu in your Cloud Console, browse to Storage, select your bucket, and click Delete:

Delete storage