|
| 1 | +<!-- |
| 2 | + Copyright 2023–2026 Google LLC |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | + --> |
| 16 | + |
| 17 | +# LoRA Fine-tuning on single-host TPUs |
| 18 | + |
| 19 | +**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption. |
| 20 | + |
| 21 | +Unlike traditional full-parameter fine-tuning, LoRA: |
| 22 | + |
| 23 | +- **Freezes the pre-trained model weights**, preserving the original knowledge. |
| 24 | +- **Injects trainable rank decomposition matrices** into the Transformer layers. |
| 25 | + |
| 26 | +This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making the process faster and more memory-efficient. |
| 27 | + |
| 28 | +This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText. |
| 29 | + |
| 30 | +We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks. |
| 31 | + |
| 32 | +In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started! |
| 33 | + |
| 34 | +**Note:** Since **qwix** support has been recently integrated into the **main branch**, you must **clone** the latest source code and install it in **editable mode** to ensure all dependencies are correctly linked. |
| 35 | + |
| 36 | +```sh |
| 37 | +# Install Qwix from source |
| 38 | +git clone https://github.com/google/qwix.git |
| 39 | +cd qwix |
| 40 | +uv pip install -e . |
| 41 | +``` |
| 42 | + |
| 43 | +## Setup environment variables |
| 44 | + |
| 45 | +Set the following environment variables before running LoRA Fine-tuning. |
| 46 | + |
| 47 | +```sh |
| 48 | +# -- Model configuration -- |
| 49 | +export PRE_TRAINED_MODEL=<MAXTEXT_MODEL> # e.g., 'gemma3-4b' |
| 50 | + |
| 51 | +# -- MaxText configuration -- |
| 52 | +export BASE_OUTPUT_DIRECTORY=<BASE_OUTPUT_DIRECTORY> # e.g., gs://my-bucket/my-output-directory or /path/to/my-output-directory |
| 53 | +export RUN_NAME=<RUN_NAME> # e.g., $(date +%Y-%m-%d-%H-%M-%S) |
| 54 | +export STEPS=<STEPS> # e.g., 1000 |
| 55 | +export PER_DEVICE_BATCH_SIZE=<BATCH_SIZE_PER_DEVICE> # e.g., 1 |
| 56 | +export HF_TOKEN=<HF_TOKEN> |
| 57 | +export LORA_RANK=<LORA_RANK> # e.g., 16 |
| 58 | +export LORA_ALPHA=<LORA_ALPHA> # e.g., 32.0 |
| 59 | +export LEARNING_RATE=<LEARNING_RATE> # e.g., 3e-6 |
| 60 | +export MAX_TARGET_LENGTH=<MAX_TARGET_LENGTH> # e.g., 1024 |
| 61 | +export WEIGHT_DTYPE=<WEIGHT_DTYPE> # e.g., bfloat16 |
| 62 | +export DTYPE=<DTYPE> # e.g., bfloat16 |
| 63 | + |
| 64 | +# -- Dataset configuration -- |
| 65 | +export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k |
| 66 | +export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train |
| 67 | +export HF_DATA_DIR=<DATASET_PATH> # e.g., main |
| 68 | +export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer'] |
| 69 | + |
| 70 | +# -- LoRA Conversion configuration (Optional) -- |
| 71 | +export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name' |
| 72 | +``` |
| 73 | + |
| 74 | +## Customizing Trainable Layers (Optional) |
| 75 | + |
| 76 | +By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading `src/maxtext/configs/post_train/lora_module_path.yml`. |
| 77 | + |
| 78 | +If you need to fine-tune specific components (e.g., targeting only Attention layers to optimize memory usage), you can override these defaults through the following hierarchy: |
| 79 | + |
| 80 | +### Configuration Hierarchy |
| 81 | + |
| 82 | +1. **Command Line Argument**: Pass the `lora_module_path` argument directly in your training command. This is the most flexible way for experimental iterations. |
| 83 | +2. **Task-Specific Config (`sft.yml`)**: Define the `lora_module_path` parameter in `src/maxtext/configs/post_train/sft.yml` to set a persistent configuration for your SFT runs. |
| 84 | +3. **Global Defaults**: Automatic detection via the model-to-regex mapping defined in `lora_module_path.yml`. |
| 85 | + |
| 86 | +## Get your model checkpoint |
| 87 | + |
| 88 | +This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint. |
| 89 | + |
| 90 | +### Option 1: Using an existing MaxText checkpoint |
| 91 | + |
| 92 | +If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. |
| 93 | + |
| 94 | +```sh |
| 95 | +export PRE_TRAINED_MODEL_CKPT_PATH=<PRE_TRAINED_MODEL_CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items |
| 96 | +``` |
| 97 | + |
| 98 | +### Option 2: Converting a Hugging Face checkpoint |
| 99 | + |
| 100 | +Refer to the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have the correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. |
| 101 | + |
| 102 | +```sh |
| 103 | +export PRE_TRAINED_MODEL_CKPT_PATH=<PRE_TRAINED_MODEL_CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items |
| 104 | +``` |
| 105 | + |
| 106 | +## Run a Fresh LoRA Fine-Tuning on Hugging Face Dataset |
| 107 | + |
| 108 | +Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process. |
| 109 | + |
| 110 | +Execute the following command to begin training: |
| 111 | + |
| 112 | +```sh |
| 113 | +python3 -m maxtext.trainers.post_train.sft.train_sft \ |
| 114 | + run_name="${RUN_NAME?}" \ |
| 115 | + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ |
| 116 | + model_name="${PRE_TRAINED_MODEL?}" \ |
| 117 | + load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \ |
| 118 | + hf_access_token="${HF_TOKEN?}" \ |
| 119 | + hf_path="${DATASET_NAME?}" \ |
| 120 | + train_split="${TRAIN_SPLIT?}" \ |
| 121 | + hf_data_dir="${HF_DATA_DIR?}" \ |
| 122 | + train_data_columns="${TRAIN_DATA_COLUMNS?}" \ |
| 123 | + steps="${STEPS?}" \ |
| 124 | + per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \ |
| 125 | + max_target_length="${MAX_TARGET_LENGTH?}" \ |
| 126 | + learning_rate="${LEARNING_RATE?}" \ |
| 127 | + weight_dtype="${WEIGHT_DTYPE?}" \ |
| 128 | + dtype="${DTYPE?}" \ |
| 129 | + enable_nnx=True \ |
| 130 | + pure_nnx_decoder=True \ |
| 131 | + enable_lora=True \ |
| 132 | + lora_rank="${LORA_RANK?}" \ |
| 133 | + lora_alpha="${LORA_ALPHA?}" \ |
| 134 | + scan_layers=True |
| 135 | +``` |
| 136 | + |
| 137 | +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`. |
| 138 | + |
| 139 | +## (Optional) Resume from a previous LoRA checkpoint |
| 140 | + |
| 141 | +If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path. |
| 142 | + |
| 143 | +### Step 1: Convert HF LoRA adapter to MaxText format |
| 144 | + |
| 145 | +If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the integrated conversion utility: |
| 146 | + |
| 147 | +```sh |
| 148 | +python3 -m maxtext.checkpoint_conversion.to_maxtext \ |
| 149 | + maxtext/src/maxtext/configs/base.yml \ |
| 150 | + model_name="${PRE_TRAINED_MODEL?}" \ |
| 151 | + hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \ |
| 152 | + base_output_directory="${BASE_OUTPUT_DIRECTORY?}/converted_adapter" \ |
| 153 | + hf_access_token="${HF_TOKEN?}" \ |
| 154 | + hardware=cpu skip_jax_distributed_system=True \ |
| 155 | + scan_layers=True |
| 156 | +``` |
| 157 | + |
| 158 | +### Step 2: Set the restore path |
| 159 | + |
| 160 | +Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files). |
| 161 | + |
| 162 | +- **load_parameters_path**: Points to the frozen base model weights (the original model). |
| 163 | +- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load. |
| 164 | + |
| 165 | +```sh |
| 166 | +export LORA_RESTORE_PATH=<LORA_RESTORE_PATH> # e.g., gs://my-bucket/run-1/checkpoints/0/items or /path/to/run-1/checkpoints/0/items |
| 167 | +``` |
| 168 | + |
| 169 | +### Step 3: Run LoRA Fine-Tuning with the Restore Path |
| 170 | + |
| 171 | +Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process. |
| 172 | + |
| 173 | +Execute the following command to begin training: |
| 174 | + |
| 175 | +```sh |
| 176 | +python3 -m maxtext.trainers.post_train.sft.train_sft \ |
| 177 | + run_name="${RUN_NAME?}" \ |
| 178 | + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ |
| 179 | + model_name="${PRE_TRAINED_MODEL?}" \ |
| 180 | + load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \ |
| 181 | + lora_restore_path="${LORA_RESTORE_PATH}" \ |
| 182 | + hf_access_token="${HF_TOKEN?}" \ |
| 183 | + hf_path="${DATASET_NAME?}" \ |
| 184 | + train_split="${TRAIN_SPLIT?}" \ |
| 185 | + hf_data_dir="${HF_DATA_DIR?}" \ |
| 186 | + train_data_columns="${TRAIN_DATA_COLUMNS?}" \ |
| 187 | + steps="${STEPS?}" \ |
| 188 | + per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \ |
| 189 | + max_target_length="${MAX_TARGET_LENGTH?}" \ |
| 190 | + learning_rate="${LEARNING_RATE?}" \ |
| 191 | + weight_dtype="${WEIGHT_DTYPE?}" \ |
| 192 | + dtype="${DTYPE?}" \ |
| 193 | + enable_nnx=True \ |
| 194 | + pure_nnx_decoder=True \ |
| 195 | + enable_lora=True \ |
| 196 | + lora_rank="${LORA_RANK?}" \ |
| 197 | + lora_alpha="${LORA_ALPHA?}" \ |
| 198 | + scan_layers=True |
| 199 | +``` |
| 200 | + |
| 201 | +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`. |
| 202 | + |
| 203 | +## (Optional) Convert Fine-tuned LoRA to Hugging Face Format |
| 204 | + |
| 205 | +After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `to_huggingface.py` script. |
| 206 | + |
| 207 | +```sh |
| 208 | +python3 -m maxtext.checkpoint_conversion.to_huggingface \ |
| 209 | + maxtext/src/maxtext/configs/base.yml \ |
| 210 | + model_name="${PRE_TRAINED_MODEL?}" \ |
| 211 | + lora.lora_restore_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/<STEPS>/model_params" \ |
| 212 | + base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter" \ |
| 213 | + hf_access_token="${HF_TOKEN?}" \ |
| 214 | + lora.lora_rank="${LORA_RANK?}" \ |
| 215 | + lora.lora_alpha="${LORA_ALPHA?}" \ |
| 216 | + scan_layers=True |
| 217 | +``` |
| 218 | + |
| 219 | +- `lora.lora_restore_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/items`) that you want to export. |
| 220 | +- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved. |
| 221 | +- `lora.lora_rank` / `lora.lora_alpha`: Must match the values used during the training phase to ensure the `adapter_config.json` is generated correctly. |
0 commit comments