How to fine-tune an LLM using Cloud Run Jobs

About this codelab

Last updated Jun 3, 2025
Written by a Googler

1. Introduction

Overview

In this codelab, you will use Cloud Run jobs to fine-tune a Gemma model, then serve the result on Cloud Run using vLLM.

For the purposes of this codelab, you will use a text-to-SQL dataset, intended to make the LLM reply with a SQL query when asked a question in natural language.

What you'll learn

  • How to fine-tune a model using a Cloud Run job with a GPU
  • How to serve a model using Cloud Run with vLLM
  • How to use Direct VPC egress for a GPU job for faster upload and serving of the model

2. Before you begin

Enable APIs

Before you can start using this codelab, enable the following APIs by running:

gcloud services enable run.googleapis.com \
    compute.googleapis.com \
    cloudbuild.googleapis.com \
    secretmanager.googleapis.com \
    artifactregistry.googleapis.com
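
Optionally, you can confirm the APIs are active before moving on. The check below is a quick sketch that filters gcloud's list of enabled services; adjust the pattern if you only want to check a specific API.

# Optional: verify that the required APIs now show up as enabled.
gcloud services list --enabled | grep -E 'run|compute|cloudbuild|secretmanager|artifactregistry'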

GPU Quota

Request a quota increase for a supported region. The quota is nvidia_l4_gpu_allocation_no_zonal_redundancy, under the Cloud Run Admin API.

Note: If you are using a new project, it may take a few minutes between enabling the API and the quotas appearing on the Quotas page.

Hugging Face

This codelab uses a model hosted on Hugging Face. To access this model, create a Hugging Face user access token with "Read" permission. You will reference this later as YOUR_HF_TOKEN.

You will also need to agree to the usage terms to use the model: https://huggingface.co/google/gemma-2b
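
If you want to confirm the token works before storing it in Secret Manager, the check below is a quick sketch that calls the Hugging Face whoami endpoint; it assumes your token is exported as HF_TOKEN (you set this variable in the next section).

# Optional sanity check: a valid token returns your account details.
curl -s -H "Authorization: Bearer $HF_TOKEN" https://huggingface.co/api/whoami-v2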

3. Setup and Requirements

Set up the following resources:

  • IAM service account and associated IAM permissions,
  • Secret Manager secret to store your Hugging Face token,
  • Cloud Storage bucket to store your fine-tuned model, and
  • Artifact Registry repository to store the image you'll build to fine-tune your model.
  1. Set environment variables for this codelab. We pre-populated a number of variables for you. Specify your project ID, region, and Hugging Face token.
    export PROJECT_ID=<YOUR_PROJECT_ID>
    export REGION=<YOUR_REGION>
    export HF_TOKEN=<YOUR_HF_TOKEN>

    export AR_REPO=codelab-finetuning-jobs
    export IMAGE_NAME=finetune-to-gcs
    export JOB_NAME=finetuning-to-gcs-job
    export BUCKET_NAME=$PROJECT_ID-codelab-finetuning-jobs
    export SECRET_ID=HF_TOKEN
    export SERVICE_ACCOUNT="finetune-job-sa"
    export SERVICE_ACCOUNT_ADDRESS=$SERVICE_ACCOUNT@$PROJECT_ID.iam.gserviceaccount.com
  2. Create the service account by running this command:
    gcloud iam service-accounts create $SERVICE_ACCOUNT \
        --display-name="Service account for fine-tuning codelab"
  3. Use Secret Manager to store your Hugging Face access token:
    gcloud secrets create $SECRET_ID \
        --replication-policy="automatic"

    printf $HF_TOKEN | gcloud secrets versions add $SECRET_ID --data-file=-
  4. Grant your service account the role of Secret Manager Secret Accessor:
    gcloud secrets add-iam-policy-binding $SECRET_ID \
        --member serviceAccount:$SERVICE_ACCOUNT_ADDRESS \
        --role='roles/secretmanager.secretAccessor'
  5. Create a bucket that will host your fine-tuned model:
    gcloud storage buckets create -l $REGION gs://$BUCKET_NAME
  6. Grant your service account access to the bucket:
    gcloud storage buckets add-iam-policy-binding gs://$BUCKET_NAME \
        --member=serviceAccount:$SERVICE_ACCOUNT_ADDRESS \
        --role=roles/storage.objectAdmin
  7. Create an Artifact Registry repository to store the container image:
    gcloud artifacts repositories create $AR_REPO \
        --repository-format=docker \
        --location=$REGION \
        --description="codelab for finetuning using CR jobs" \
        --project=$PROJECT_ID
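
Optionally, verify that the setup completed before building the image. The describe commands below are a quick sketch; each should return details without an error if the secret, bucket, and repository exist.

gcloud secrets describe $SECRET_ID
gcloud storage buckets describe gs://$BUCKET_NAME --format="value(name)"
gcloud artifacts repositories describe $AR_REPO --location=$REGION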

4. Create the Cloud Run job image

In the next step, you'll create the code that does the following:

  • Imports the Gemma model from Hugging Face
  • Performs fine-tuning on the model with the dataset from Hugging Face. The job uses a single L4 GPU for fine-tuning.
  • Uploads the fine-tuned model called new_model to your Cloud Storage bucket
  1. Create a directory for your fine-tuning job code:
    mkdir codelab-finetuning-job
    cd codelab-finetuning-job
  2. Create a file called finetune.py:
    # Copyright 2024 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #      http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    import os
    import torch
    from datasets import load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainingArguments,
    )
    from peft import LoraConfig, PeftModel
    from trl import SFTTrainer

    # Cloud Storage bucket to upload the model
    bucket_name = os.getenv("BUCKET_NAME", "YOUR_BUCKET_NAME")

    # The model that you want to train from the Hugging Face hub
    model_name = os.getenv("MODEL_NAME", "google/gemma-2b")

    # The instruction dataset to use
    dataset_name = "b-mc2/sql-create-context"

    # Fine-tuned model name
    new_model = os.getenv("NEW_MODEL", "gemma-2b-sql")

    ################################################################################
    # QLoRA parameters
    ################################################################################

    # LoRA attention dimension
    lora_r = int(os.getenv("LORA_R", "4"))

    # Alpha parameter for LoRA scaling
    lora_alpha = int(os.getenv("LORA_ALPHA", "8"))

    # Dropout probability for LoRA layers
    lora_dropout = 0.1

    ################################################################################
    # bitsandbytes parameters
    ################################################################################

    # Activate 4-bit precision base model loading
    use_4bit = True

    # Compute dtype for 4-bit base models
    bnb_4bit_compute_dtype = "float16"

    # Quantization type (fp4 or nf4)
    bnb_4bit_quant_type = "nf4"

    # Activate nested quantization for 4-bit base models (double quantization)
    use_nested_quant = False

    ################################################################################
    # TrainingArguments parameters
    ################################################################################

    # Output directory where the model predictions and checkpoints will be stored
    output_dir = "./results"

    # Number of training epochs
    num_train_epochs = 1

    # Enable fp16/bf16 training (set bf16 to True with an A100)
    fp16 = True
    bf16 = False

    # Batch size per GPU for training
    per_device_train_batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "1"))

    # Batch size per GPU for evaluation
    per_device_eval_batch_size = int(os.getenv("EVAL_BATCH_SIZE", "2"))

    # Number of update steps to accumulate the gradients for
    gradient_accumulation_steps = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "1"))

    # Enable gradient checkpointing
    gradient_checkpointing = True

    # Maximum gradient norm (gradient clipping)
    max_grad_norm = 0.3

    # Initial learning rate (AdamW optimizer)
    learning_rate = 2e-4

    # Weight decay to apply to all layers except bias/LayerNorm weights
    weight_decay = 0.001

    # Optimizer to use
    optim = "paged_adamw_32bit"

    # Learning rate schedule
    lr_scheduler_type = "cosine"

    # Number of training steps (overrides num_train_epochs)
    max_steps = -1

    # Ratio of steps for a linear warmup (from 0 to learning rate)
    warmup_ratio = 0.03

    # Group sequences into batches with the same length
    # Saves memory and speeds up training considerably
    group_by_length = True

    # Save checkpoint every X updates steps
    save_steps = 0

    # Log every X updates steps
    logging_steps = int(os.getenv("LOGGING_STEPS", "50"))

    ################################################################################
    # SFT parameters
    ################################################################################

    # Maximum sequence length to use
    max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", "512"))

    # Pack multiple short examples in the same input sequence to increase efficiency
    packing = False

    # Load the entire model on GPU 0
    device_map = {'':torch.cuda.current_device()}

    # Number of dataset examples to use (-1 to use the full dataset)
    limit = int(os.getenv("DATASET_LIMIT", "5000"))

    dataset = load_dataset(dataset_name, split="train")
    if limit != -1:
        dataset = dataset.shuffle(seed=42).select(range(limit))


    def transform(data):
        question = data['question']
        context = data['context']
        answer = data['answer']
        template = "Question: {question}\nContext: {context}\nAnswer: {answer}"
        return {'text': template.format(question=question, context=context, answer=answer)}


    transformed = dataset.map(transform)

    # Load tokenizer and model with QLoRA configuration
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check GPU compatibility with bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )

    # Set training parameters
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=transformed,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
    )

    trainer.train()

    trainer.model.save_pretrained(new_model)

    # Reload model in FP16 and merge it with LoRA weights
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    model = PeftModel.from_pretrained(base_model, new_model)
    model = model.merge_and_unload()

    # Save the merged model to the Cloud Storage bucket
    # (mounted at /finetune/new_model via Cloud Storage FUSE in the job definition)

    file_path_to_save_the_model = '/finetune/new_model'
    model.save_pretrained(file_path_to_save_the_model)
    tokenizer.save_pretrained(file_path_to_save_the_model)
  3. Create a requirements.txt file:
    accelerate==0.34.2
    bitsandbytes==0.45.5
    datasets==2.19.1
    transformers==4.51.3
    peft==0.11.1
    trl==0.8.6
    torch==2.3.0
  4. Create a Dockerfile:
    FROM nvidia/cuda:12.6.2-runtime-ubuntu22.04

    RUN apt-get update && \
        apt-get -y --no-install-recommends install python3-dev gcc python3-pip git && \
        rm -rf /var/lib/apt/lists/*

    COPY requirements.txt /requirements.txt

    RUN pip3 install -r requirements.txt --no-cache-dir

    COPY finetune.py /finetune.py

    ENV PYTHONUNBUFFERED 1

    CMD python3 /finetune.py --device cuda
  5. Build the container image and push it to your Artifact Registry repository:
    gcloud builds submit \
      --tag $REGION-docker.pkg.dev/$PROJECT_ID/$AR_REPO/$IMAGE_NAME \
      --region $REGION
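
Once the build finishes, you can optionally confirm that the image is in the repository. The listing command below is a sketch; it assumes the build above completed successfully.

gcloud artifacts docker images list $REGION-docker.pkg.dev/$PROJECT_ID/$AR_REPO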

5. Deploy and execute the job

In this step, you'll create the YAML configuration for your job with direct VPC egress for faster uploads to Google Cloud Storage.

Note that this file contains variables that you will update in a subsequent step.

  1. Create a file called finetune-job.yaml.tmpl:
    apiVersion: run.googleapis.com/v1
    kind: Job
    metadata:
      name: $JOB_NAME
      labels:
        cloud.googleapis.com/location: $REGION
      annotations:
        run.googleapis.com/launch-stage: ALPHA
    spec:
      template:
        metadata:
          annotations:
            run.googleapis.com/execution-environment: gen2
            run.googleapis.com/network-interfaces: '[{"network":"default","subnetwork":"default"}]'
        spec:
          parallelism: 1
          taskCount: 1
          template:
            spec:
              serviceAccountName: $SERVICE_ACCOUNT_ADDRESS
              containers:
              - name: $IMAGE_NAME
                image: $REGION-docker.pkg.dev/$PROJECT_ID/$AR_REPO/$IMAGE_NAME
                env:
                - name: MODEL_NAME
                  value: "google/gemma-2b"
                - name: NEW_MODEL
                  value: "gemma-2b-sql-finetuned"
                - name: BUCKET_NAME
                  value: "$BUCKET_NAME"
                - name: LORA_R
                  value: "8"
                - name: LORA_ALPHA
                  value: "16"
                - name: GRADIENT_ACCUMULATION_STEPS
                  value: "2"
                - name: DATASET_LIMIT
                  value: "1000"
                - name: LOGGING_STEPS
                  value: "5"
                - name: HF_TOKEN
                  valueFrom:
                    secretKeyRef:
                      key: 'latest'
                      name: HF_TOKEN
                resources:
                  limits:
                    cpu: 8000m
                    nvidia.com/gpu: '1'
                    memory: 32Gi
                volumeMounts:
                - mountPath: /finetune/new_model
                  name: finetuned_model
              volumes:
              - name: finetuned_model
                csi:
                  driver: gcsfuse.run.googleapis.com
                  readOnly: false
                  volumeAttributes:
                    bucketName: $BUCKET_NAME
              maxRetries: 3
              timeoutSeconds: '3600'
              nodeSelector:
                run.googleapis.com/accelerator: nvidia-l4
  2. Replace the variables in the YAML with your environment variables by running the following command:
    envsubst < finetune-job.yaml.tmpl > finetune-job.yaml
  3. Create the Cloud Run Job:
    gcloud alpha run jobs replace finetune-job.yaml
  4. Execute the job:
    gcloud alpha run jobs execute $JOB_NAME --region $REGION --async

The job will take around 10 minutes to complete. You can check on the status using the link provided in the output of the last command.
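
You can also check on progress from the command line and, once the execution succeeds, confirm that the fine-tuned model files landed in your bucket. The commands below are a sketch of that check.

# Check the status of the job's executions.
gcloud run jobs executions list --job $JOB_NAME --region $REGION

# After the execution completes, the merged model files should appear in the bucket.
gcloud storage ls gs://$BUCKET_NAME/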

6. Use a Cloud Run service to serve your finetuned model with vLLM

In this step, you will deploy a Cloud Run service. This configuration uses Direct VPC egress to access the Cloud Storage bucket over the private network for faster downloads.

Note that this file contains variables that you will update in a subsequent step.

  1. Create a service.yaml.tmpl file:
    apiVersion: serving.knative.dev/v1
    kind: Service
    metadata:
      name: serve-gemma-sql
      labels:
        cloud.googleapis.com/location: $REGION
      annotations:
        run.googleapis.com/launch-stage: BETA
        run.googleapis.com/ingress: all
        run.googleapis.com/ingress-status: all
    spec:
      template:
        metadata:
          labels:
          annotations:
            autoscaling.knative.dev/maxScale: '1'
            run.googleapis.com/cpu-throttling: 'false'
            run.googleapis.com/gpu-zonal-redundancy-disabled: 'true'
            run.googleapis.com/network-interfaces: '[{"network":"default","subnetwork":"default"}]'
        spec:
          containers:
          - name: serve-finetuned
            image: us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250505_0916_RC00
            ports:
            - name: http1
              containerPort: 8000
            resources:
              limits:
                cpu: 8000m
                nvidia.com/gpu: '1'
                memory: 32Gi
            volumeMounts:
            - name: fuse
              mountPath: /finetune/new_model
            command: ["python3", "-m", "vllm.entrypoints.api_server"]
            args:
            - --model=/finetune/new_model
            - --tensor-parallel-size=1
            env:
            - name: MODEL_ID
              value: 'new_model'
            - name: HF_HUB_OFFLINE
              value: '1'
          volumes:
          - name: fuse
            csi:
              driver: gcsfuse.run.googleapis.com
              volumeAttributes:
                bucketName: $BUCKET_NAME
          nodeSelector:
            run.googleapis.com/accelerator: nvidia-l4
  2. Replace the variables in the YAML with your environment variables by running the following command:
    envsubst < service.yaml.tmpl > service.yaml
  3. Deploy your Cloud Run Service:
    gcloud alpha run services replace service.yaml

7. Test your fine-tuned model

In this step, you will prompt your model to test the fine-tuning.

  1. Get the service URL for your Cloud Run service:
    SERVICE_URL=$(gcloud run services describe serve-gemma-sql --platform managed --region $REGION --format 'value(status.url)')
  2. Create your prompt for your model.
    USER_PROMPT="Question: What are the first name and last name of all candidates? Context: CREATE TABLE candidates (candidate_id VARCHAR); CREATE TABLE people (first_name VARCHAR, last_name VARCHAR, person_id VARCHAR)"
  3. Call your service using curl to prompt your model:
    curl -X POST $SERVICE_URL/generate \
      -H "Content-Type: application/json" \
      -H "Authorization: bearer $(gcloud auth print-identity-token)" \
      -d @- <<EOF
    {
        "prompt": "${USER_PROMPT}"
    }
    EOF

You should see a response similar to the following:

{"predictions":["Prompt:\nQuestion: What are the first name and last name of all candidates? Context: CREATE TABLE candidates (candidate_id VARCHAR); CREATE TABLE people (first_name VARCHAR, last_name VARCHAR, person_id VARCHAR)\nOutput:\n CREATE TABLE people_to_candidates (candidate_id VARCHAR, person_id VARCHAR) CREATE TABLE people_to_people (person_id VARCHAR, person_id VARCHAR) CREATE TABLE people_to_people_to_candidates (person_id VARCHAR, candidate_id"]}

8. Congratulations!

Congratulations on completing the codelab!

We recommend reviewing the Cloud Run documentation.

What we've covered

  • How to fine-tune a model using a Cloud Run job with a GPU
  • How to serve a model using Cloud Run with vLLM
  • How to use Direct VPC egress for a GPU job for faster upload and serving of the model

9. Clean up

To avoid incurring charges, for example if the Cloud Run service is inadvertently invoked more times than your monthly Cloud Run invocation allotment in the free tier, you can delete the Cloud Run service you created in Step 6.

To delete the Cloud Run service, go to the Cloud Run Cloud Console at https://console.cloud.google.com/run and delete the serve-gemma-sql service.
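
Alternatively, you can delete the service from the command line. The command below is a sketch; it assumes the same service name and region used earlier in this codelab.

gcloud run services delete serve-gemma-sql --region $REGION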

To delete the entire project, go to Manage Resources, select the project you created in Step 2, and choose Delete. If you delete the project, you'll need to change projects in your Cloud SDK. You can view the list of all available projects by running gcloud projects list.