Low-Rank Adaptation (LoRA) is a Parameter-Efficient Fine-Tuning (PEFT) technique for adapting large language models while minimizing compute and memory consumption.
Unlike traditional full-parameter fine-tuning, LoRA:
- Freezes the pre-trained model weights, preserving the original knowledge.
- Injects trainable rank decomposition matrices into the Transformer layers.
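The two points above can be made concrete with a minimal JAX sketch of one LoRA-adapted linear layer. It assumes the common formulation y = x @ W + (alpha / rank) * x @ A @ B, with B zero-initialized so training starts from the unmodified model. The function and variable names are illustrative only; this is not MaxText's or Tunix's implementation.

```python
import jax
import jax.numpy as jnp


def lora_linear(x, w_frozen, lora_a, lora_b, rank, alpha):
    """Frozen projection plus a scaled low-rank update: x @ W + (alpha / rank) * x @ A @ B."""
    return x @ w_frozen + (alpha / rank) * ((x @ lora_a) @ lora_b)


d_in, d_out, rank, alpha = 1024, 1024, 16, 32.0
key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = 0.02 * jax.random.normal(key_w, (d_in, d_out))      # pre-trained weight, never updated
lora_a = 0.01 * jax.random.normal(key_a, (d_in, rank))  # trainable down-projection
lora_b = jnp.zeros((rank, d_out))                       # trainable up-projection, zero-initialized
x = jax.random.normal(key_x, (4, d_in))


def loss(trainable, x, w_frozen):
    a, b = trainable
    return jnp.mean(lora_linear(x, w_frozen, a, b, rank, alpha) ** 2)


# Differentiate only with respect to the adapter matrices; w is passed in as a constant.
grad_a, grad_b = jax.grad(loss)((lora_a, lora_b), x, w)

y = lora_linear(x, w, lora_a, lora_b, rank, alpha)
print(bool(jnp.allclose(y, x @ w)))         # True: zero-initialized B makes the adapter a no-op at step 0
print(d_in * d_out, rank * (d_in + d_out))  # 1048576 32768 (full weight vs. adapter parameters)
```

With rank 16, this layer trains about 3% of the parameters a full update of W would touch. This is why LoRA's optimizer state and gradients are so much smaller than full fine-tuning.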
This tutorial provides step-by-step instructions for setting up a multi-host TPU environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText. The examples use a multi-host TPU slice such as a v6e-256.
We use Tunix, a JAX-based library, to power these post-training tasks.
Let's get started!
Before starting, ensure you have:
- Access to a Google Cloud Project with TPU quotas.
- A Hugging Face account with an access token for downloading models.
- Permissions for Google Artifact Registry (Artifact Registry Writer role).
- The XPK prerequisites installed (follow the official XPK documentation).
- A Pathways-ready GKE cluster (see Create GKE cluster).
- Docker installed and configured for sudoless use (follow the steps to configure sudoless Docker).
For instructions on building and uploading the MaxText Docker image with post-training dependencies, please refer to the official documentation.
Use a Pathways-ready GKE cluster, as described in the Create GKE cluster guide.
Set up the following environment variables to configure your training run. Replace placeholders with your actual values.
# -- Model configuration --
# The MaxText model name. See `src/maxtext/configs/types.py` for `ModelName` for a
# full list of supported models.
export MODEL=<MODEL_NAME> # e.g., 'gemma4-26b'
# Your Hugging Face access token. Required to download gated models like Gemma.
# You can generate one at https://huggingface.co/settings/tokens.
export HF_TOKEN=<HF_TOKEN>
# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints. Ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser) or via
# `gcloud storage buckets list --format="table(name, location)"`.
export BASE_OUTPUT_DIRECTORY=<GCS_BUCKET> # e.g., gs://my-bucket/maxtext-runs
# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<RUN_NAME>
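# Optional (an illustrative pattern, not a MaxText requirement): derive a
# DNS-label-safe name from the model, user, and timestamp by lowercasing it and
# replacing underscores and periods with hyphens. DNS labels are limited to 63
# characters, so keep the components short:
# export RUN_NAME="$(echo "${MODEL?}-${USER?}-$(date +%Y%m%d-%H%M%S)" | tr '[:upper:]' '[:lower:]' | tr '_.' '--')"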
# -- Workload configuration --
# Your GCP project ID. Find it on the [Cloud Console Dashboard](https://console.cloud.google.com/home/dashboard).
# If you've already set it in your local config, you can retrieve it via:
# gcloud config get-value project
export PROJECT_ID=<PROJECT_ID>
# The GCP location (listed as "Location" in the UI) and name of your
# TPU-enabled GKE cluster. Both can be found on the
# [Cloud Console](https://console.cloud.google.com/kubernetes/list).
export ZONE=<ZONE> # e.g., 'us-central1'
export GKE_CLUSTER=<CLUSTER_NAME>
# For a full list of MaxText-supported TPU types, see: `src/maxtext/utils/accelerator_to_spec_map.py`. To see the TPU type
# of your cluster:
# 1. Connect to the cluster (required for kubectl commands later):
# gcloud container clusters get-credentials ${GKE_CLUSTER?} --location ${ZONE?} --project ${PROJECT_ID?}
# 2. Find your TPU type (e.g., 'v6e-256') by checking the accelerator labels on your nodes:
# kubectl get nodes -l cloud.google.com/gke-tpu-accelerator -o jsonpath='{.items[*].metadata.labels.cloud\.google\.com/gke-tpu-accelerator}' | tr ' ' '\n' | sort -u
export TPU_TYPE=<TPU_TYPE>
export NUM_SLICES=<NUM_SLICES>
# The Docker image you pushed in the prerequisite step
export CLOUD_IMAGE_NAME=<IMAGE_NAME>
export DOCKER_IMAGE="gcr.io/${PROJECT_ID?}/${CLOUD_IMAGE_NAME?}"
# -- Fine-Tuning configuration --
export STEPS=<STEPS> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<BATCH_SIZE_PER_DEVICE> # e.g., 1
export LORA_RANK=<LORA_RANK> # e.g., 16
export LORA_ALPHA=<LORA_ALPHA> # e.g., 32.0
export LEARNING_RATE=<LEARNING_RATE> # e.g., 3e-6
export MAX_TARGET_LENGTH=<MAX_TARGET_LENGTH> # e.g., 1024
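# Worked example. This assumes the global batch is PER_DEVICE_BATCH_SIZE x the
# total number of TPU chips across all slices, and that there is no gradient
# accumulation. With PER_DEVICE_BATCH_SIZE=1 on a single v6e-256 slice, each step
# processes 1 x 256 = 256 sequences, i.e. at most 256 x 1024 = 262,144 tokens
# when MAX_TARGET_LENGTH=1024.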
# -- Dataset configuration --
export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train
export HF_DATA_DIR=<DATASET_PATH> # e.g., main
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer']
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., maxtext/examples/chat_templates/math_qa.json
# -- LoRA Conversion configuration (Optional) --
export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name'
By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading src/maxtext/configs/post_train/lora_module_path.yml.
If you need to fine-tune specific components (e.g., targeting only attention layers to reduce memory usage), you can override these defaults through the following hierarchy, from highest to lowest precedence:
- Command-line argument: Pass the lora_module_path argument directly in your training command.
- Task-specific config (sft.yml): Define the lora_module_path parameter in src/maxtext/configs/post_train/sft.yml.
- Global defaults: Automatic detection via the model-to-regex mapping defined in lora_module_path.yml.
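The override hierarchy behaves like a first-match lookup. The sketch below assumes two things: the three sources are consulted from highest to lowest precedence, and a pattern selects a module when it matches the module's full path with re.fullmatch. The helper names, module paths, and regexes are all hypothetical stand-ins, not MaxText's matching code or the actual contents of lora_module_path.yml.

```python
import re


def resolve_lora_module_path(cli_value, sft_yml_value, model_defaults, model_name):
    """Return the highest-precedence LoRA target regex (hypothetical helper)."""
    # Assumed precedence: command line > sft.yml > per-model default mapping.
    for candidate in (cli_value, sft_yml_value, model_defaults.get(model_name)):
        if candidate:
            return candidate
    raise ValueError(f"No lora_module_path configured for {model_name!r}")


# Hypothetical stand-in for the model-to-regex mapping in lora_module_path.yml.
MODEL_DEFAULTS = {"example-model": r".*/(self_attention|mlp)/.*"}

# Hypothetical module paths; real MaxText parameter paths may be named differently.
MODULE_PATHS = [
    "decoder/layers_0/self_attention/query",
    "decoder/layers_0/self_attention/value",
    "decoder/layers_0/mlp/wi_0",
    "decoder/embedding",
]


def lora_targets(pattern):
    """Module paths that the given regex would select for LoRA injection."""
    return [path for path in MODULE_PATHS if re.fullmatch(pattern, path)]


default_pattern = resolve_lora_module_path(None, None, MODEL_DEFAULTS, "example-model")
attention_pattern = resolve_lora_module_path(
    r".*/self_attention/(query|value)", None, MODEL_DEFAULTS, "example-model"
)
print(lora_targets(default_pattern))    # attention and MLP projections
print(lora_targets(attention_pattern))  # attention query/value only
```

Narrowing the pattern to attention projections shrinks the number of adapter matrices, which is the memory trade-off described above.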
This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
Note: Make sure that the checkpoint at MAXTEXT_CKPT_PATH was created with storage flags that match your runtime:
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))
To convert a Hugging Face checkpoint instead, follow the steps in the Hugging Face to MaxText guide, and make sure the converted checkpoint files are saved correctly. As with Option 1, set the following environment variable and move on.
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory/0/items
This section provides the commands to run LoRA fine-tuning on a GKE cluster. For McJAX, run:
xpk workload create \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${RUN_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${NUM_SLICES?} \
--command="python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} hf_data_dir=${HF_DATA_DIR?} train_data_columns=${TRAIN_DATA_COLUMNS?} steps=${STEPS?} per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} max_target_length=${MAX_TARGET_LENGTH?} learning_rate=${LEARNING_RATE?} chat_template_path=${CHAT_TEMPLATE_PATH?} enable_nnx=True pure_nnx_decoder=True lora.enable_lora=True lora.lora_rank=${LORA_RANK?} lora.lora_alpha=${LORA_ALPHA?}"
Once the fine-tuning is completed, you can access your model checkpoints at ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints.
For Pathways, run:
export USE_PATHWAYS=1
xpk workload create-pathways \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${RUN_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${NUM_SLICES?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} hf_data_dir=${HF_DATA_DIR?} train_data_columns=${TRAIN_DATA_COLUMNS?} steps=${STEPS?} per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} max_target_length=${MAX_TARGET_LENGTH?} learning_rate=${LEARNING_RATE?} chat_template_path=${CHAT_TEMPLATE_PATH?} enable_nnx=True pure_nnx_decoder=True lora.enable_lora=True lora.lora_rank=${LORA_RANK?} lora.lora_alpha=${LORA_ALPHA?} checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
Once the fine-tuning is completed, you can access your model checkpoints at ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints.
If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path.
If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the integrated conversion utility:
xpk workload create \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload="${RUN_NAME?}-to-maxtext" \
--tpu-type=${TPU_TYPE?} \
--num-slices=${NUM_SLICES?} \
--command="python3 -m maxtext.checkpoint_conversion.to_maxtext model_name=${MODEL?} hf_lora_adapter_path=${HF_LORA_ADAPTER_PATH?} base_output_directory=${BASE_OUTPUT_DIRECTORY?}/converted_adapter hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True"
Point LORA_RESTORE_PATH to the converted MaxText adapter directory (the directory containing the 0/items or Orbax files).
- load_parameters_path: Points to the frozen base model weights (the original model).
- lora.lora_restore_path: Points to the previous LoRA adapter weights you want to load.
export LORA_RESTORE_PATH=<LORA_RESTORE_PATH> # e.g., gs://my-bucket/run-1/checkpoints/0/items or /path/to/run-1/checkpoints/0/items
Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
To begin training on McJAX, execute the following command:
xpk workload create \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${RUN_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${NUM_SLICES?} \
--command="python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} hf_data_dir=${HF_DATA_DIR?} train_data_columns=${TRAIN_DATA_COLUMNS?} steps=${STEPS?} per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} max_target_length=${MAX_TARGET_LENGTH?} lora.lora_restore_path=${LORA_RESTORE_PATH?} learning_rate=${LEARNING_RATE?} chat_template_path=${CHAT_TEMPLATE_PATH?} enable_nnx=True pure_nnx_decoder=True lora.enable_lora=True lora.lora_rank=${LORA_RANK?} lora.lora_alpha=${LORA_ALPHA?}"
To run on Pathways instead:
export USE_PATHWAYS=1
xpk workload create-pathways \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload=${RUN_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${NUM_SLICES?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} hf_data_dir=${HF_DATA_DIR?} train_data_columns=${TRAIN_DATA_COLUMNS?} steps=${STEPS?} per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} max_target_length=${MAX_TARGET_LENGTH?} lora.lora_restore_path=${LORA_RESTORE_PATH?} learning_rate=${LEARNING_RATE?} chat_template_path=${CHAT_TEMPLATE_PATH?} enable_nnx=True pure_nnx_decoder=True lora.enable_lora=True lora.lora_rank=${LORA_RANK?} lora.lora_alpha=${LORA_ALPHA?} checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
Your fine-tuned model checkpoints will be saved to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints.
After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the to_huggingface.py script.
xpk workload create \
--cluster=${GKE_CLUSTER?} \
--project=${PROJECT_ID?} \
--zone=${ZONE?} \
--docker-image=${DOCKER_IMAGE?} \
--workload="${RUN_NAME?}-to-hf" \
--tpu-type=${TPU_TYPE?} \
--num-slices=1 \
--command="python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL?} \
lora.lora_restore_path=${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/<STEPS>/model_params \
base_output_directory=${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter \
lora.lora_rank=${LORA_RANK?} \
lora.lora_alpha=${LORA_ALPHA?} \
hf_access_token=${HF_TOKEN?}"
- lora.lora_restore_path: Point this to the specific checkpoint directory (e.g., .../checkpoints/1000/items) that you want to export.
- base_output_directory: The local or GCS directory where the Hugging Face adapter_model.safetensors and adapter_config.json will be saved.
- lora.lora_rank / lora.lora_alpha: Must match the values used during training so that adapter_config.json is generated correctly.
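Because a mismatched rank or alpha silently produces a wrongly scaled adapter, it can be worth checking the exported config before sharing it. This sketch assumes the export follows the standard Hugging Face PEFT layout, where adapter_config.json stores the rank under "r" and the scaling numerator under "lora_alpha". The check_adapter_config helper is hypothetical, and the usage writes a stand-in file rather than reading a real export.

```python
import json
import tempfile
from pathlib import Path


def check_adapter_config(adapter_dir, expected_rank, expected_alpha):
    """Raise if adapter_config.json disagrees with the training run (hypothetical helper)."""
    config = json.loads((Path(adapter_dir) / "adapter_config.json").read_text())
    # PEFT-style configs store the LoRA rank as "r" and the scaling numerator as "lora_alpha".
    mismatches = {
        key: (config.get(key), expected)
        for key, expected in (("r", expected_rank), ("lora_alpha", expected_alpha))
        if config.get(key) != expected
    }
    if mismatches:
        raise ValueError(f"adapter_config.json mismatch (found, expected): {mismatches}")
    return config


# Usage with a stand-in file; point adapter_dir at a local copy of the exported adapter instead.
with tempfile.TemporaryDirectory() as adapter_dir:
    (Path(adapter_dir) / "adapter_config.json").write_text(
        json.dumps({"peft_type": "LORA", "r": 16, "lora_alpha": 32.0})
    )
    config = check_adapter_config(adapter_dir, expected_rank=16, expected_alpha=32.0)
    print(config["r"], config["lora_alpha"])  # 16 32.0
```

Once the check passes, the directory can typically be loaded on top of the base model with the PEFT library's PeftModel.from_pretrained(base_model, adapter_dir).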
When running LoRA fine-tuning in a multi-host environment (e.g., a TPU pod with 64 hosts managing 256 TPU chips, under either Pathways or McJAX), special care must be taken when resharding arrays.
In a single-host environment, the host has a global view of all devices, so a standard jax.device_put can easily distribute slices of data to all local TPUs. However, in a multi-host setup:
- Addressability: A host only has a local view of its directly attached devices and cannot push data directly to TPUs managed by other hosts.
- Memory Constraints: If every host tries to load the entire weight matrix into RAM just to extract its local piece, the host CPUs will run out of memory (OOM).
To solve this, MaxText uses jax.make_array_from_callback for a "safe reshard." Rather than pushing data from one host to every device, it inverts the flow: it constructs a global jax.Array in which each host locally executes a callback (lambda idx: val[idx]) that returns only the slice of data its attached TPUs need. This bypasses the cross-host limitations of device_put and avoids host OOMs, since each host indexes only the slices it requires (provided val can be sliced without first materializing the full array in host memory).
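The pattern can be sketched in a single process with the public JAX API. The one-axis mesh, the row sharding, and the load_slice callback name are assumptions for illustration, not MaxText's resharding code. On a real multi-host slice, every host runs the same call, and each host's callback is invoked only for the shards its own devices hold.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# jax.devices() lists every device in the job; on a multi-host slice a process
# can only place data on its own (addressable) devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
# Shard the first dimension across the "data" mesh axis; replicate the second.
sharding = NamedSharding(mesh, PartitionSpec("data", None))

# Host-side source array with two rows per device. In a real restore this would
# ideally be a lazily loaded array so that val[idx] reads only that slice.
val = np.arange(2 * len(devices) * 4, dtype=np.float32).reshape(2 * len(devices), 4)

requested_indices = []


def load_slice(idx):
    """Return the piece of val a local device needs; idx is a tuple of slices."""
    requested_indices.append(idx)
    return val[idx]


global_array = jax.make_array_from_callback(val.shape, sharding, load_slice)

# Each process inspects only the shards that live on its own devices.
for shard in global_array.addressable_shards:
    print(jax.process_index(), shard.device, shard.index, shard.data.shape)
```

Note that no call here moves data to a device owned by another host. Each process fills only its own shards, and JAX assembles them into one logically global array.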