Supervised fine-tuning (SFT) adapts a pre-trained large language model to specific tasks by further training it on a labeled dataset.
This tutorial provides step-by-step instructions for setting up a multi-host TPU environment, such as v6e-256, and then fine-tuning a model on a Hugging Face dataset using SFT.
We use Tunix, a JAX-based library designed for post-training tasks, to perform SFT.
Let's get started!
This section guides you through cloning the MaxText repository, building a MaxText Docker image with its dependencies, and uploading that image to your project's Artifact Registry.
git clone https://github.com/google/maxtext.git
cd maxtext

Before building the Docker image, authenticate to Google Artifact Registry so that you have permission to push images and other required access.
# Authenticate your user account for gcloud CLI access
gcloud auth login
# Configure application default credentials for Docker and other tools
gcloud auth application-default login
# Configure Docker credentials and test your access
gcloud auth configure-docker
docker run hello-world

Then run the following command to create a local Docker image named maxtext_base_image. The build process takes approximately 10 to 15 minutes.
bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training

Note: You will need the Artifact Registry Writer role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, ask your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access".
export DOCKER_IMAGE_NAME=<Docker Image Name>
bash src/dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${DOCKER_IMAGE_NAME?}

The docker_upload_runner.sh script uploads your Docker image to Artifact Registry.
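After the upload, you can sanity-check that the image is where the later steps expect it. This is an optional sketch, not part of the MaxText scripts: it assumes PROJECT holds your Google Cloud project ID (exported later in this tutorial) and that gcloud is authenticated; the gcr.io path mirrors the DOCKER_IMAGE variable used below.

```shell
# Optional sanity check (sketch): confirm where the pushed image should live.
# Assumes PROJECT and DOCKER_IMAGE_NAME are set; gcloud must be authenticated.
IMAGE_URI="gcr.io/${PROJECT}/${DOCKER_IMAGE_NAME}"
echo "Expecting image at: ${IMAGE_URI}"
# Uncomment to query the registry for the image's tags:
# gcloud container images list-tags "${IMAGE_URI}" --limit=1
```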
Install XPK by following the instructions in the official documentation.
Use a Pathways-ready GKE cluster as described here.
# -- Google Cloud Configuration --
export PROJECT=<Google Cloud Project ID>
export CLUSTER_NAME=<Name of GKE Cluster>
export ZONE=<GKE Cluster Zone>
# -- Workload Configuration --
export WORKLOAD_NAME=<Name of Workload> # e.g., sft-$(date +%s)
export TPU_TYPE=<TPU Type> # e.g., v6e-256
export TPU_SLICE=<number of slices>
export DOCKER_IMAGE="gcr.io/${PROJECT?}/${DOCKER_IMAGE_NAME?}"
# -- MaxText Configuration --
export OUTPUT_PATH=<GCS Path for Output/Logs> # e.g., gs://my-bucket/my-output-directory
export STEPS=<Fine-Tuning Steps> # e.g., 1000
export HF_TOKEN=<Hugging Face Access Token>
# -- Model Configuration --
export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3
# -- Dataset configuration --
export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<Data Split for Train> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<Data Columns to Train on> # e.g., ['messages']

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Note: Make sure that MODEL_CHECKPOINT_PATH points to checkpoints created with the correct storage flags:
- For SFT with McJAX: checkpoint_storage_use_zarr3=True and checkpoint_storage_use_ocdbt=True.
- For SFT with Pathways: checkpoint_storage_use_zarr3=False and checkpoint_storage_use_ocdbt=False.
Follow the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format. Make sure the checkpoint files are converted and saved correctly. As in Option 1, set the following environment variable and move on.
export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-checkpoint-directory/0/items

This section provides the command to run SFT on a GKE cluster.
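Before submitting the workload, it can help to fail fast if any of the environment variables exported earlier is unset or empty. The helper below is a minimal sketch defined here for this tutorial; check_vars is not part of MaxText or XPK.

```shell
# Pre-flight check (sketch): report any required variable that is unset or
# empty. check_vars is a local helper, not a MaxText or XPK command.
check_vars() {
  missing=0
  for name in "$@"; do
    # Portable indirect lookup: read the value of the variable named $name
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "Missing required variable: $name" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Usage (uncomment to run before xpk workload create):
# check_vars PROJECT CLUSTER_NAME ZONE WORKLOAD_NAME TPU_TYPE TPU_SLICE \
#            DOCKER_IMAGE OUTPUT_PATH STEPS HF_TOKEN MODEL_NAME \
#            TOKENIZER_PATH MODEL_CHECKPOINT_PATH DATASET_NAME \
#            TRAIN_SPLIT TRAIN_DATA_COLUMNS || exit 1
```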
xpk workload create \
--cluster=${CLUSTER_NAME?} \
--project=${PROJECT?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"

Once the fine-tuning is completed, you can access your model checkpoints at $OUTPUT_PATH/$WORKLOAD_NAME/checkpoints.
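The checkpoint location can be derived from the variables you exported earlier. A minimal sketch, assuming gsutil is installed and has access to your bucket (the listing command is left commented out); the $OUTPUT_PATH/$WORKLOAD_NAME/checkpoints layout comes from the tutorial text above.

```shell
# Sketch: derive and inspect the checkpoint directory for this run.
CHECKPOINT_DIR="${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints"
echo "Checkpoints are written to: ${CHECKPOINT_DIR}"
# Uncomment to list them (requires gsutil and access to the bucket):
# gsutil ls "${CHECKPOINT_DIR}"
```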
xpk workload create-pathways \
--cluster=${CLUSTER_NAME?} \
--project=${PROJECT?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"

Once the fine-tuning is completed, you can access your model checkpoints at $OUTPUT_PATH/$WORKLOAD_NAME/checkpoints.
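While either workload runs, you can monitor it and clean it up with XPK. This is a hedged sketch: the xpk workload list and xpk workload delete subcommands are assumed to be available in your XPK installation, and the commands are echoed here rather than executed.

```shell
# Sketch: monitoring and cleanup commands for the workload submitted above.
# `xpk workload list` and `xpk workload delete` are assumed from your XPK
# installation; run them manually once the job finishes.
LIST_CMD="xpk workload list --cluster=${CLUSTER_NAME} --project=${PROJECT} --zone=${ZONE}"
DELETE_CMD="xpk workload delete --workload=${WORKLOAD_NAME} --cluster=${CLUSTER_NAME} --project=${PROJECT} --zone=${ZONE}"
echo "Check status: ${LIST_CMD}"
echo "Clean up:     ${DELETE_CMD}"
```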