Reinforcement Learning on Multi-Host TPUs

This tutorial provides step-by-step instructions for setting up the environment and training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using Pathways for orchestration on multi-host TPU-VMs, such as v5p-128.

We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:

  • Group Relative Policy Optimization (GRPO): GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy (schematic formulas for both algorithms follow this list).

  • Group Sequence Policy Optimization (GSPO): GSPO is an RL algorithm that improves training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
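
To make the two objectives concrete, here are the core quantities in schematic form. These follow the original GRPO and GSPO formulations; the exact Tunix implementation may differ in details. For GRPO, given G sampled responses to a prompt with rewards r_1, ..., r_G, the relative advantage of response i is its reward normalized by the group statistics:

    A_i = (r_i - mean(r_1, ..., r_G)) / std(r_1, ..., r_G)

and the policy is updated with a PPO-style clipped objective using these advantages, with no learned value network. For GSPO, the per-token ratio is replaced by a sequence-level, length-normalized importance ratio for a response y_i to prompt x:

    s_i(theta) = ( pi_theta(y_i | x) / pi_theta_old(y_i | x) )^(1 / |y_i|)

and clipping is applied to s_i(theta) for the whole sequence rather than per token.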

For efficient model inference and response generation during this process, we rely on the vLLM library.

Prerequisites

Before starting, ensure you have:

  • Access to a Google Cloud Project with TPU quotas.
  • A Hugging Face account with an access token for downloading models.
  • Permissions for Google Artifact Registry (Artifact Registry Writer role).
  • XPK installed (follow official documentation).
  • A Pathways-ready GKE cluster (see create GKE cluster). A quick access check follows this list.
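
Before proceeding, you can confirm command-line access to your cluster. The commands below are standard gcloud/kubectl invocations; replace the placeholders with your cluster's name, region, and project:

# Fetch kubeconfig credentials for the cluster
gcloud container clusters get-credentials <cluster name> --region=<region> --project=<GCP project ID>

# Verify kubectl can reach the cluster
kubectl get nodes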

Setup Environment Variables

Set up the following environment variables. Replace placeholders with your actual values.

# -- Model configuration --
export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-70b-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-70b'
export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-70B-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export WORKLOAD=<Name for this run> # e.g., llama-3-70b-grpo
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY?}/${WORKLOAD?}/0/items

# -- Workload configuration --
export TPU_TYPE=<TPU Type> # e.g., 'v5p-128'
export TPU_CLUSTER=<cluster name>
export PROJECT_ID=<GCP project ID>
export CLOUD_IMAGE_NAME=<your artifact registry image> # Name for the Docker image to be built
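
For reference, here is the same configuration filled in with the example values used in the comments above. The cluster, project, and image names are illustrative placeholders; substitute your own:

export HF_MODEL='llama3.1-70b-Instruct'
export MODEL='llama3.1-70b'
export TOKENIZER='meta-llama/Llama-3.1-70B-Instruct'
export HF_TOKEN=<Hugging Face access token>

export BASE_OUTPUT_DIRECTORY=gs://my-bucket/my-output-directory
export WORKLOAD=llama-3-70b-grpo
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY?}/${WORKLOAD?}/0/items

export TPU_TYPE=v5p-128
export TPU_CLUSTER=my-pathways-cluster       # illustrative cluster name
export PROJECT_ID=my-gcp-project             # illustrative project ID
export CLOUD_IMAGE_NAME=maxtext-post-training  # illustrative image name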

Get Your Model Checkpoint

Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
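
Optionally, sanity-check that the checkpoint path exists and is readable before moving on (gsutil ls is one standard way to do this; gcloud storage ls works as well):

gsutil ls ${MAXTEXT_CKPT_PATH?}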

Option 2: Converting from a Hugging Face checkpoint

Refer to the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format, and make sure the converted checkpoint files are saved correctly. As in Option 1, set the following environment variable and move on.

export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Build and upload MaxText Docker image with post-training dependencies

Before building the Docker image, authenticate to Google Artifact Registry for permission to push your images and other access.

# Authenticate your user account for gcloud CLI access
gcloud auth login

# Configure application default credentials for Docker and other tools
gcloud auth application-default login

# Configure Docker credentials and test your access
gcloud auth configure-docker
docker run hello-world

Option 1: From PyPI releases (Recommended)

Get the latest stable release of MaxText from PyPI. This will automatically pull compatible versions of post-training dependencies, such as Tunix, vLLM, and tpu-inference.

git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# Check out the latest stable release; find the current version at https://pypi.org/project/maxtext/
export MAXTEXT_VERSION=0.2.0
git checkout maxtext-v${MAXTEXT_VERSION?}

Run the following script to create a Docker image with stable releases of MaxText and its post-training dependencies. The build process takes approximately 10-15 minutes.

bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training

For experimental features (such as improved pathwaysutils resharding API), use:

bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training-experimental
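
Either command builds the image locally. Before uploading, you can confirm the image exists with a local listing (the exact image name assigned by the build script may vary by MaxText version):

docker images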

Option 2: From Github

To use a version newer than the latest PyPI release, you can instead build the Docker image from the repository head, with the latest vetted versions of MaxText and its post-training dependencies:

git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training

Upload the Docker Image

Note: You will need the Artifact Registry Writer role to push Docker images to your project's Artifact Registry. Contact your project administrator if you don't have this permission.

bash src/dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME?}
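
To confirm the push succeeded, list the image's tags in the registry. This assumes the gcr.io path that the submission commands below reference:

gcloud container images list-tags gcr.io/${PROJECT_ID?}/${CLOUD_IMAGE_NAME?}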

Submit your RL workload via Pathways

See the Troubleshooting section for concise instructions on how to retry or resume a failed workload.

Ensure you have a Pathways-ready GKE cluster (as mentioned in Prerequisites) and submit the train_rl.py script via XPK.

Note: XPK v0.14.0+ automatically discovers your cluster's location from GCP. You don't need to specify --zone in the commands below. If using an older XPK version, add --zone=<zone> to the workload commands.

Submit GRPO workload

xpk workload create-pathways --workload ${WORKLOAD?} \
--docker-image gcr.io/${PROJECT_ID?}/${CLOUD_IMAGE_NAME?} --cluster ${TPU_CLUSTER?} \
--tpu-type=${TPU_TYPE?} --num-slices=1 \
--project=${PROJECT_ID?} --priority=high \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=${MODEL?} \
  tokenizer_path=${TOKENIZER?} \
  load_parameters_path=${MAXTEXT_CKPT_PATH?} \
  run_name=${WORKLOAD?} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  hf_access_token=${HF_TOKEN?}"
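
After submission, you can confirm that the workload was accepted and check its state. xpk workload list is part of the standard XPK CLI; flags may vary slightly across XPK versions:

xpk workload list --cluster ${TPU_CLUSTER?} --project=${PROJECT_ID?}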

Submit GSPO workload

The command is identical to the GRPO submission except for the additional loss_algo=gspo-token flag.

xpk workload create-pathways --workload ${WORKLOAD?} \
--docker-image gcr.io/${PROJECT_ID?}/${CLOUD_IMAGE_NAME?} --cluster ${TPU_CLUSTER?} \
--tpu-type=${TPU_TYPE?} --num-slices=1 \
--project=${PROJECT_ID?} --priority=high \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
  model_name=${MODEL?} \
  tokenizer_path=${TOKENIZER?} \
  load_parameters_path=${MAXTEXT_CKPT_PATH?} \
  run_name=${WORKLOAD?} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  hf_access_token=${HF_TOKEN?} \
  loss_algo=gspo-token"

Managing Workloads

  • Monitor workload status: Check the Pathways job status with kubectl get pathwaysjob and the pod status with kubectl get pods. A log-inspection example follows this list.
  • Delete a workload: To remove a failed or unwanted Pathways job, use XPK:
    xpk workload delete \
        --workload ${WORKLOAD?} \
        --cluster ${TPU_CLUSTER?} \
        --project ${PROJECT_ID?}
    If the job still lingers, use kubectl get pods to obtain the pod name and then run kubectl delete pod <pod-name>.
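
When a workload misbehaves, inspecting pod logs is usually the fastest way to find the underlying error. The commands below are standard kubectl usage; replace <pod-name> with a name from the pod listing:

# List the pods backing your workload
kubectl get pods

# Stream logs from a specific pod
kubectl logs -f <pod-name>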

Troubleshooting

  • Authentication Issues: Ensure your HF_TOKEN environment variable is set correctly and has access to the required models.
  • Resource Quotas: Verify you have sufficient TPU quotas in your GCP project.
  • Docker Build Failures: Check that all dependencies are correctly installed and authentication is configured.
  • Workload Failures: Review the logs for specific error messages and ensure all environment variables are properly set.
  • Workload retry / resume:
    • Retry (fresh run): Use a unique workload name to avoid overwriting outputs:
      export WORKLOAD=${WORKLOAD}-retry1
      export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items
      Then submit the XPK workload again. If a "workload already exists" error occurs, pick a new name or list existing jobs with kubectl get pathwaysjob.
    • Resume from checkpoint: Keep the same WORKLOAD and point load_parameters_path at the saved checkpoint:
      export load_parameters_path=${MAXTEXT_CKPT_PATH}/checkpoint-0000
      Then submit the workload again.
    • Tip: Verify that the checkpoint exists in GCS and is readable before resuming (e.g., with gsutil ls as shown earlier).

For more detailed troubleshooting, refer to the MaxText documentation and XPK documentation.