
Commit c32eb92

Merge pull request #2890 from AI-Hypercomputer:jackyf/docs/rl_multi
PiperOrigin-RevId: 853581001
2 parents 2af3ead + 7a2b4b6 commit c32eb92

1 file changed: docs/tutorials/posttraining/rl_on_multi_host.md (+109, -39)
limitations under the License.
-->

# Reinforcement Learning on Multi-Host TPUs

This tutorial provides step-by-step instructions for setting up the environment and training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using [Pathways for orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) on multi-host TPU-VMs, such as `v5p-128`.

We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:

- GRPO (Group Relative Policy Optimization)
- GSPO (Group Sequence Policy Optimization)

For efficient model inference and response generation during this process, we rely on the vLLM library.

## Table of Contents

- [Prerequisites](#prerequisites)
- [Setup Environment Variables](#setup-environment-variables)
- [Get Your Model Checkpoint](#get-your-model-checkpoint)
- [Build and Upload MaxText Docker Image](#build-and-upload-maxtext-docker-image-with-post-training-dependencies)
- [Submit your RL workload via Pathways](#submit-your-rl-workload-via-pathways)
- [Managing Workloads](#managing-workloads)
- [Troubleshooting](#troubleshooting)

## Prerequisites

Before starting, ensure you have:

- Access to a Google Cloud Project with TPU quotas.
- A Hugging Face account with an access token for downloading models.
- Permissions for Google Artifact Registry (the Artifact Registry Writer role).
- XPK installed (follow the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md)).
- A Pathways-ready GKE cluster (see [create a GKE cluster](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster)).
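
Before proceeding, a quick sanity check of your environment can save time later. This is an optional, illustrative sketch assuming the Google Cloud SDK and kubectl are already on your PATH:

```bash
# Confirm you are authenticated and pointed at the right project.
gcloud auth list
gcloud config get-value project

# Confirm kubectl can reach a cluster context.
kubectl config current-context
```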

## Setup Environment Variables

Set up the following environment variables. Replace placeholders with your actual values.

```bash
# -- Model configuration --
export MODEL=<MaxText model name> # e.g., llama3.1-70b
export TOKENIZER=<tokenizer path>
export HF_MODEL=<Hugging Face model ID>
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export WORKLOAD=<Name for this run> # e.g., llama-3-70b-grpo
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items

# -- Workload configuration --
export TPU_TYPE=<TPU Type> # e.g., 'v5p-128'
export TPU_CLUSTER=<cluster name>
export PROJECT_ID=<GCP project ID>
export CLOUD_IMAGE_NAME=<your artifact registry image> # Name for the Docker image to be built
```
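
With the variables set, a quick check that they expand correctly and that the output location is reachable can catch typos early. This is an optional sketch, assuming `gcloud storage` access to the bucket:

```bash
# Echo the key variables to confirm they are set as intended.
echo "Workload: ${WORKLOAD}  TPU: ${TPU_TYPE}  Cluster: ${TPU_CLUSTER}"

# Verify the output bucket is reachable with your current credentials.
gcloud storage ls "${BASE_OUTPUT_DIRECTORY}" \
  || echo "Cannot list ${BASE_OUTPUT_DIRECTORY}; check the bucket name and permissions."
```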

## Get Your Model Checkpoint

### Option 1: Using an existing MaxText checkpoint

If you already have a checkpoint in MaxText format, point the environment variable at it:

```bash
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint

You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.

First, ensure you have the necessary dependencies installed (PyTorch for the conversion script). Then, run the conversion script on a CPU machine. For large models, use the `--lazy_load_tensors` flag to reduce memory usage during conversion.

For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.

```bash
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
  model_name=${HF_MODEL} \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD} \
  scan_layers=true checkpoint_storage_use_ocdbt=false checkpoint_storage_use_zarr3=false \
  skip_jax_distributed_system=true --lazy_load_tensors=true
```
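
Once the conversion finishes, it is worth confirming that the checkpoint landed at the path the later workload commands expect. A quick check, assuming read access to the bucket:

```bash
# The RL workload reads from ${MAXTEXT_CKPT_PATH}; confirm it exists.
gcloud storage ls "${MAXTEXT_CKPT_PATH}"
```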

## Build and Upload MaxText Docker Image with Post-Training Dependencies

First, make sure Docker is installed and authenticated against your project's registry:

```bash
gcloud auth configure-docker
docker run hello-world
```

### Option 1: Install stable releases of post-training dependencies

> **Caution:** RL in MaxText is currently broken with stable releases of post-training dependencies. We are working on fixing this and recommend following [Option 2: Install from Git repositories of post-training dependencies](#option-2-install-from-git-repositories-of-post-training-dependencies) in the meantime.

Run the following script to create a Docker image with stable releases of MaxText, [Tunix](https://github.com/google/tunix), [vLLM](https://github.com/vllm-project/vllm), and [tpu-inference](https://github.com/vllm-project/tpu-inference) dependencies. This installs `vllm-tpu`, which provides TPU inference for vLLM with unified JAX and PyTorch support. The build process takes approximately 10-15 minutes.

```bash
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training
```

112-
You can also use `bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API.
121+
For experimental features (such as improved pathwaysutils resharding API), use:
113122

114-
### Option 2: Install from Git repositories of post-training dependencies
115-
You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm) and then use the following command to build a docker image using them:
123+
```bash
124+
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training-experimental
116125
```
126+
127+
### Option 2: Install from Git repositories of post-training dependencies
128+
129+
You can also locally clone the [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), and [vllm](https://github.com/vllm-project/vllm.git) repositories and then build the docker image with these local sources.
130+
131+
**Note:** Clone these repositories as siblings of the `maxtext` directory (e.g., in the same parent directory). After cloning, run the build from inside the `maxtext` repository so it picks up the local sources:
132+
133+
```bash
117134
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training POST_TRAINING_SOURCE=local
118135
```
119136

120-
### Upload the dependency docker image along with MaxText code
121-
> **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, contact your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access".
122-
```
137+
### Upload the Docker Image
138+
139+
> **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry. Contact your project administrator if you don't have this permission.
140+
141+
```bash
123142
bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
124143
```
125144
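
After the upload, you can confirm the image is visible to the cluster. A quick check, assuming the image is pushed under `gcr.io/${PROJECT_ID}` as in the submission commands below:

```bash
# List images in the project registry and inspect the uploaded one.
gcloud container images list --repository=gcr.io/${PROJECT_ID}
gcloud container images describe gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}
```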

## Submit your RL workload via Pathways

Ensure you have a Pathways-ready GKE cluster (see [Prerequisites](#prerequisites)), then submit the `train_rl.py` script via XPK. If a run fails, see the [Troubleshooting](#troubleshooting) section for concise instructions on how to retry or resume the workload.

> **Note:** XPK v0.14.0+ automatically discovers your cluster's location from GCP, so you don't need to specify `--zone` in the commands below. If you are using an older XPK version, add `--zone=<zone>` to the workload commands.

### Submit GRPO workload

```bash
xpk workload create-pathways --workload $WORKLOAD \
  --docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \
  --tpu-type=$TPU_TYPE --num-slices=1 \
  --project=$PROJECT_ID --priority=high \
  --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
  python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
  model_name=${MODEL} \
  tokenizer_path=${TOKENIZER} \
  load_parameters_path=${MAXTEXT_CKPT_PATH} \
  run_name=${WORKLOAD} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  hf_access_token=${HF_TOKEN}"
```

### Submit GSPO workload

```bash
xpk workload create-pathways --workload $WORKLOAD \
  --docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \
  --tpu-type=$TPU_TYPE --num-slices=1 \
  --project=$PROJECT_ID --priority=high \
  --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
  python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
  model_name=${MODEL} \
  tokenizer_path=${TOKENIZER} \
  load_parameters_path=${MAXTEXT_CKPT_PATH} \
  run_name=${WORKLOAD} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  hf_access_token=${HF_TOKEN} \
  loss_algo=gspo-token"
```
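
After submitting either workload, a quick way to confirm it was accepted (assuming XPK is configured as above) is to list the workloads on the cluster:

```bash
# The new $WORKLOAD entry should appear in the list.
xpk workload list --cluster $TPU_CLUSTER --project $PROJECT_ID
```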

## Managing Workloads

- **Monitor workload status**: Check Pathways job status:

  ```bash
  kubectl get pathwaysjob
  ```

  Check pod status:

  ```bash
  kubectl get pods
  ```

- **Delete a workload**: To remove a failed or unwanted Pathways job, use XPK:

  ```bash
  xpk workload delete \
    --workload $WORKLOAD \
    --cluster $TPU_CLUSTER \
    --project $PROJECT_ID
  ```

  If the job still lingers, use `kubectl get pods` to obtain the name of the pod, then run:

  ```bash
  kubectl delete pod <pod-name>
  ```
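
To follow training progress while a job is running, you can tail the logs of one of its pods. This is a generic sketch; the exact pod names depend on how Pathways deploys the job:

```bash
# Find the pod for your workload, then stream its logs.
kubectl get pods | grep "${WORKLOAD}"
kubectl logs -f <pod-name>
```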

## Troubleshooting

- **Authentication Issues**: Ensure your `HF_TOKEN` environment variable is set correctly and has access to the required models.
- **Resource Quotas**: Verify you have sufficient TPU quotas in your GCP project.
- **Docker Build Failures**: Check that all dependencies are correctly installed and authentication is configured.
- **Workload Failures**: Review the logs for specific error messages and ensure all environment variables are properly set.
- **Workload retry / resume**:
  - **Retry (fresh run)**: Use a unique workload name to avoid overwriting outputs:

    ```bash
    export WORKLOAD=${WORKLOAD}-retry1
    export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items
    ```

    Then submit the XPK workload. If a "workload already exists" error occurs, pick a new name or list the existing jobs:

    ```bash
    kubectl get pathwaysjob
    ```

  - **Resume from checkpoint**: Keep the same `WORKLOAD` and set the checkpoint path:

    ```bash
    export load_parameters_path=${MAXTEXT_CKPT_PATH}/checkpoint-0000
    ```

    Then submit the workload again.
  - **Tip**: Verify the checkpoint exists in GCS with read access before resuming.

For more detailed troubleshooting, refer to the [MaxText documentation](https://maxtext.readthedocs.io) and the [XPK documentation](https://github.com/AI-Hypercomputer/xpk).
