This tutorial provides step-by-step instructions for setting up the
environment and then training the Llama3.1 8B-IT model on the GSM8K math
reasoning dataset using a single-host TPU VM such as v6e-8 or v5p-8.
We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities (a small numerical sketch of the core quantities follows this list):
- Group Relative Policy Optimization (GRPO): GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
- Group Sequence Policy Optimization (GSPO): GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
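To ground these descriptions, here is a minimal NumPy sketch (an illustration, not the Tunix implementation) of the quantity each algorithm is built around: GRPO's group-relative advantage and GSPO's length-normalized, sequence-level importance ratio.

```python
import numpy as np

# --- GRPO: group-relative advantage ---
# Rewards for a group of responses sampled for the same prompt.
rewards = np.array([1.0, 0.0, 1.0, 0.0])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
print(advantages)  # above-average responses get positive advantage

# --- GSPO: sequence-level importance ratio ---
# Per-token log-probs of one full response under the current policy and
# the rollout (old) policy. GSPO defines the ratio on the whole
# sequence, length-normalized, rather than per token as in PPO/GRPO.
logp_new = np.array([-1.2, -0.8, -2.0])  # log pi_theta(y_t | x, y_<t)
logp_old = np.array([-1.0, -1.0, -1.9])  # log pi_old(y_t | x, y_<t)
seq_ratio = np.exp((logp_new.sum() - logp_old.sum()) / len(logp_new))
print(seq_ratio)  # clipped and used to weight the advantage during updates
```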
For efficient model inference and response generation during this process, we rely on the vLLM library.
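During GRPO/GSPO rollouts, the policy must generate several candidate responses per prompt. The trainer below drives vLLM on TPU for you, but for intuition, a standalone vLLM sampling call looks roughly like this (the model name and sampling settings are illustrative assumptions):

```python
from vllm import LLM, SamplingParams

# Sample a group of 4 candidate responses for one GSM8K-style prompt.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(n=4, temperature=1.0, max_tokens=256)
prompt = "Natalia sold 48 clips in April and half as many in May. How many clips did she sell altogether?"
outputs = llm.generate([prompt], params)
for completion in outputs[0].outputs:
    print(completion.text)
```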
Let's get started!
# Create a virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate

Run the following commands to install all the necessary dependencies.
uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps

This installs MaxText and then, for post-training, primarily the following:
a. Tunix as the LLM post-training library, and
b. vllm-tpu, which comprises vllm and tpu-inference, thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
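To sanity-check the installation, you can list the key packages that should now be present in the virtual environment (the grep pattern below simply follows the list above; exact package names may vary across versions):

```bash
# Confirm the post-training stack resolved into the virtual environment.
uv pip list | grep -Ei "maxtext|tunix|vllm|tpu"
```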
To use a version newer than the latest PyPI release, you can instead install the latest vetted versions of the dependencies from the MaxText repository in the following way:
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext
# 2. Install dependencies in editable mode
uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps

Set up the following environment variables before running GRPO/GSPO:
# -- Model configuration --
export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-8b-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-8b'
export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>
# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export CHIPS_PER_VM=<the number of chips per VM> # depends on hardware, for v5p this is 4, for v6e this is 8

For the value of CHIPS_PER_VM on different TPU hardware, refer to the official documentation:
- TPU v5e (single host, chips_per_vm=8)
- TPU v5p (single host, chips_per_vm=4)
- TPU v6e (single host, chips_per_vm=8)
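Putting it together, a filled-in configuration for a single-host v6e-8 might look like this (values taken from the placeholder examples above; substitute your own bucket and token):

```bash
export HF_MODEL='llama3.1-8b-Instruct'
export MODEL='llama3.1-8b'
export TOKENIZER='meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<your Hugging Face access token>
export BASE_OUTPUT_DIRECTORY=gs://my-bucket/my-output-directory
export RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S)
export CHIPS_PER_VM=8 # v6e has 8 chips per VM
```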
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Otherwise, refer to the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format. Make sure the checkpoint files are correctly converted and saved. Then, as with the previous option, set the same environment variable and move on.

export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
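Either way, it is worth confirming that the checkpoint path actually resolves before launching a long run; assuming a GCS path, a quick listing looks like:

```bash
# The listing should show checkpoint files rather than return an error.
gsutil ls ${MAXTEXT_CKPT_PATH?}
```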
Run the following command for GRPO:
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
hf_access_token=${HF_TOKEN?} \
chips_per_vm=${CHIPS_PER_VM?}
An overview of what this run will do:
- Load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., Llama3.1-8b-Instruct).
- Evaluate the policy model's performance on the GSM8K math reasoning benchmark.
- Train the policy model using GRPO.
- Evaluate the policy model's performance on the GSM8K math reasoning benchmark after the post-training with GRPO.
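For intuition about the reward signal on GSM8K: a correctness reward can be as simple as checking the final number in a response against the reference answer (GSM8K references place the final answer after '####'). The sketch below is illustrative only, not the trainer's actual reward function:

```python
import re

def gsm8k_reward(response: str, reference: str) -> float:
    """Toy rule-based reward: 1.0 if the last number in the model's
    response matches the reference final answer, else 0.0."""
    gold = reference.split("####")[-1].strip().replace(",", "")
    nums = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
    return 1.0 if nums and nums[-1] == gold else 0.0

print(gsm8k_reward("... so Natalia sold 72 clips.", "#### 72"))  # 1.0
```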
Run the following command for GSPO (identical to the GRPO command except for the additional loss_algo=gspo-token flag):
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL?} \
tokenizer_path=${TOKENIZER?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
hf_access_token=${HF_TOKEN?} \
loss_algo=gspo-token \
chips_per_vm=${CHIPS_PER_VM?}
An overview of what this run will do:
- Load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., Llama3.1-8b-Instruct).
- Evaluate the policy model's performance on the GSM8K math reasoning benchmark.
- Train the policy model using GSPO.
- Evaluate the policy model's performance on the GSM8K math reasoning benchmark after the post-training with GSPO.