Commit f5736a1: feat: support LoRA training and conversion using NNX and Qwix
<!--
Copyright 2023–2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# LoRA Fine-tuning on single-host TPUs

**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption.

Unlike traditional full-parameter fine-tuning, LoRA:

- **Freezes the pre-trained model weights**, preserving the original knowledge.
- **Injects trainable rank-decomposition matrices** into the Transformer layers.

This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making training faster and more memory-efficient.
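The parameter savings are easy to see numerically. The sketch below is illustrative NumPy, not MaxText code, and the layer sizes are hypothetical: a frozen weight `W` is augmented with a trainable low-rank pair `A`, `B`, scaled by `alpha / rank` as in the LoRA paper.

```python
import numpy as np

# Hypothetical sizes for a single projection layer.
d_model, rank, alpha = 1024, 16, 32.0

rng = np.random.default_rng(0)
W = rng.normal(size=(d_model, d_model))      # frozen pre-trained weight
A = rng.normal(size=(d_model, rank)) * 0.01  # trainable down-projection
B = np.zeros((rank, d_model))                # trainable up-projection, zero-initialized


def lora_forward(x):
    # Base path uses the frozen weight; the low-rank path adds the trainable
    # update, scaled by alpha / rank. With B at zero, the output initially
    # equals the frozen model's output.
    return x @ W + (alpha / rank) * (x @ A @ B)


full_params = W.size           # 1,048,576 trainable params in full fine-tuning
lora_params = A.size + B.size  # 32,768 trainable params with LoRA
print(f"trainable fraction with LoRA: {lora_params / full_params:.2%}")
```

Only `A` and `B` receive gradient updates; the ratio shrinks further as `d_model` grows, since LoRA parameters scale linearly in `d_model` while full weights scale quadratically.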
This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText.

We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

**Note:** Since Qwix support was only recently integrated into the **main branch**, you must **clone** the latest source code and install it in **editable mode** to ensure all dependencies are correctly linked.
```sh
# Install Qwix from source
git clone https://github.com/google/qwix.git
cd qwix
uv pip install -e .
```
## Setup environment variables

Set the following environment variables before running LoRA fine-tuning.
```sh
# -- Model configuration --
export PRE_TRAINED_MODEL=<MaxText model> # e.g., 'gemma3-4b'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export STEPS=<number of fine-tuning steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1
export HF_TOKEN=<Hugging Face access token>
export LORA_RANK=<dimension of the low-rank update matrices> # e.g., 16
export LORA_ALPHA=<scaling factor for LoRA weights> # e.g., 32.0
export LEARNING_RATE=<step size for the optimizer> # e.g., 3e-6
export MAX_TARGET_LENGTH=<maximum sequence length for input and output tokens> # e.g., 1024
export WEIGHT_DTYPE=<data type for storing model weights> # e.g., bfloat16
export DTYPE=<data type for numerical computations> # e.g., bfloat16

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., openai/gsm8k
export TRAIN_SPLIT=<data split for training> # e.g., train
export HF_DATA_DIR=<directory or sub-config within the Hugging Face dataset> # e.g., main
export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['question','answer']

# -- LoRA conversion configuration (optional) --
export HF_LORA_ADAPTER_PATH=<hf_repo_id_or_local_path> # e.g., 'username/adapter-name'
```
## Customizing Trainable Layers (Optional)

By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading `src/maxtext/configs/post_train/lora_module_path.yml`.

If you need to fine-tune specific components (e.g., targeting only attention layers to reduce memory usage), you can override these defaults through the following hierarchy:

### Configuration Hierarchy

1. **Command-line argument**: Pass the `lora_module_path` argument directly in your training command. This is the most flexible option for quick experiments.
2. **Task-specific config (`sft.yml`)**: Define the `lora_module_path` parameter in `src/maxtext/configs/post_train/sft.yml` to set a persistent configuration for your SFT runs.
3. **Global defaults**: Automatic detection via the model-to-regex mapping defined in `lora_module_path.yml`.
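To illustrate how a regex selects layers, here is a standalone Python sketch. The module paths and the pattern are hypothetical examples; MaxText's actual paths come from the model-to-regex mapping in `lora_module_path.yml` for your model.

```python
import re

# Hypothetical module paths, loosely modeled on Transformer layer naming.
module_paths = [
    "decoder/layers/self_attention/query",
    "decoder/layers/self_attention/value",
    "decoder/layers/mlp/wi",
]

# An attention-only pattern: match query/key/value projections, skip the MLP.
pattern = re.compile(r".*self_attention/(query|key|value)$")

targeted = [p for p in module_paths if pattern.match(p)]
print(targeted)  # only the two attention projections match
```

A narrower pattern means fewer LoRA matrices are injected, which trades some adaptation capacity for lower memory use.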
## Get your model checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint

Follow the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText format, and verify that the converted checkpoint files are saved correctly. As in Option 1, set the following environment variable and move on.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
## Run a Fresh LoRA Fine-Tuning on a Hugging Face Dataset

Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.

Execute the following command to begin training:
```sh
python3 -m maxtext.trainers.post_train.sft.train_sft \
  run_name="${RUN_NAME?}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
  model_name="${PRE_TRAINED_MODEL?}" \
  load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \
  hf_access_token="${HF_TOKEN?}" \
  hf_path="${DATASET_NAME?}" \
  train_split="${TRAIN_SPLIT?}" \
  hf_data_dir="${HF_DATA_DIR?}" \
  train_data_columns="${TRAIN_DATA_COLUMNS?}" \
  steps="${STEPS?}" \
  per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
  max_target_length="${MAX_TARGET_LENGTH?}" \
  learning_rate="${LEARNING_RATE?}" \
  weight_dtype="${WEIGHT_DTYPE?}" \
  dtype="${DTYPE?}" \
  enable_nnx=True \
  pure_nnx_decoder=True \
  enable_lora=True \
  lora_rank="${LORA_RANK?}" \
  lora_alpha="${LORA_ALPHA?}" \
  scan_layers=True
```
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.

## (Optional) Resume from a previous LoRA checkpoint

If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path.
### Step 1: Convert HF LoRA adapter to MaxText format

If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the provided conversion script:
```sh
python3 -m maxtext.checkpoint_conversion.hf_lora_to_maxtext \
  model_name="${PRE_TRAINED_MODEL?}" \
  hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
  scan_layers=True
```
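To give a sense of what such a conversion involves, the sketch below groups the paired low-rank factors of a Hugging Face PEFT-style adapter by module. The key names follow PEFT's `...lora_A.weight` / `...lora_B.weight` convention, but the dictionary here is a stand-in and the grouping logic is a hypothetical illustration, not the internals of `hf_lora_to_maxtext`.

```python
import re
from collections import defaultdict

# Stand-in for the tensor names inside an adapter_model.safetensors file.
hf_keys = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight",
    "base_model.model.layers.1.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.1.self_attn.q_proj.lora_B.weight",
]

# Group the A/B factors by the module they attach to; a converter walks these
# pairs and writes them into the target checkpoint layout.
pairs = defaultdict(dict)
for key in hf_keys:
    m = re.match(r"(.+)\.lora_([AB])\.weight$", key)
    module, factor = m.group(1), m.group(2)
    pairs[module][factor] = key

for module, factors in pairs.items():
    assert set(factors) == {"A", "B"}, f"incomplete LoRA pair for {module}"
```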
### Step 2: Set the restore path

Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files).

- **load_parameters_path**: Points to the frozen base model weights (the original model).
- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load.

```sh
export LORA_RESTORE_PATH=<gcs_path_to_converted_adapter_items> # e.g., gs://my-bucket/run-1/checkpoints/0/items
```
### Step 3: Run LoRA Fine-Tuning with the Restore Path

With `LORA_RESTORE_PATH` set, run the same training command as before, adding the `lora_restore_path` argument:
```sh
python3 -m maxtext.trainers.post_train.sft.train_sft \
  run_name="${RUN_NAME?}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
  model_name="${PRE_TRAINED_MODEL?}" \
  load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \
  lora_restore_path="${LORA_RESTORE_PATH}" \
  hf_access_token="${HF_TOKEN?}" \
  hf_path="${DATASET_NAME?}" \
  train_split="${TRAIN_SPLIT?}" \
  hf_data_dir="${HF_DATA_DIR?}" \
  train_data_columns="${TRAIN_DATA_COLUMNS?}" \
  steps="${STEPS?}" \
  per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
  max_target_length="${MAX_TARGET_LENGTH?}" \
  learning_rate="${LEARNING_RATE?}" \
  weight_dtype="${WEIGHT_DTYPE?}" \
  dtype="${DTYPE?}" \
  enable_nnx=True \
  pure_nnx_decoder=True \
  enable_lora=True \
  lora_rank="${LORA_RANK?}" \
  lora_alpha="${LORA_ALPHA?}" \
  scan_layers=True
```
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.

## (Optional) Convert Fine-tuned LoRA to Hugging Face Format

After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `maxtext_to_hf_lora` conversion module.
```sh
python3 -m maxtext.checkpoint_conversion.maxtext_to_hf_lora \
  model_name="${PRE_TRAINED_MODEL?}" \
  load_parameters_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/<step_number>/model_params" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter" \
  lora_rank="${LORA_RANK?}" \
  lora_alpha="${LORA_ALPHA?}"
```
- `load_parameters_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/model_params`) that you want to export.
- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved.
- `lora_rank` / `lora_alpha`: Must match the values used during training to ensure `adapter_config.json` is generated correctly.
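The last point can be checked mechanically. The sketch below writes and re-reads a minimal config file, using the PEFT-style field names `r` and `lora_alpha`; the values and the temporary path are illustrative stand-ins for the file the conversion produces next to `adapter_model.safetensors`.

```python
import json
import os
import tempfile

# The rank/alpha used during training ($LORA_RANK / $LORA_ALPHA); example values.
trained_rank, trained_alpha = 16, 32.0

# Stand-in for the exported adapter_config.json.
config_path = os.path.join(tempfile.mkdtemp(), "adapter_config.json")
with open(config_path, "w") as f:
    json.dump(
        {"peft_type": "LORA", "r": trained_rank, "lora_alpha": trained_alpha}, f
    )

# Re-read and confirm the exported values match what training used; a mismatch
# here would make downstream loaders scale the adapter weights incorrectly.
with open(config_path) as f:
    cfg = json.load(f)

assert cfg["r"] == trained_rank
assert cfg["lora_alpha"] == trained_alpha
```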
