Skip to content

Commit ca8af65

Browse files
committed
feat: integrate LoRA support in training pipeline
1 parent 1adb45c commit ca8af65

25 files changed

Lines changed: 1731 additions & 608 deletions

File tree

docs/_static/js/editable_commands.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,23 @@ document.addEventListener('DOMContentLoaded', () => {
1616
"<DATASET_PATH>",
1717
"<GCS_BUCKET>",
1818
"<HF_CKPT_PATH>",
19+
"<HF_LORA_ADAPTER_PATH>",
1920
"<HF_MODEL>",
2021
"<HF_TOKEN>",
2122
"<IMAGE_NAME>",
2223
"<LAZY_LOAD>",
24+
"<LEARNING_RATE>",
25+
"<LORA_ALPHA>",
26+
"<LORA_RANK>",
27+
"<LORA_RESTORE_PATH>",
28+
"<MAX_TARGET_LENGTH>",
2329
"<MODEL_NAME>",
2430
"<NUM_SLICES>",
2531
"<POD_NAME>",
2632
"<PROJECT_ID>",
2733
"<RUN_NAME>",
2834
"<STEPS>",
35+
"<TEMPLATE_PATH>",
2936
"<TPU_TYPE>",
3037
"<TRAIN_SPLIT>",
3138
"<VENV_NAME>",

docs/tutorials/post_training_index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ MaxText was co-designed with key Google led innovations to provide a unified pos
2626
- **SFT (Supervised Fine-Tuning)**
2727
- [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html)
2828
- [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html)
29+
- **LoRA (Low-Rank Adaptation)**
30+
- [LoRA on Single-Host TPUs](posttraining/lora.md)
2931
- **Multimodal SFT**
3032
- [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html)
3133
- **Reinforcement Learning (RL)**
@@ -68,6 +70,7 @@ posttraining/sft_on_multi_host.md
6870
posttraining/rl.md
6971
posttraining/rl_on_multi_host.md
7072
posttraining/knowledge_distillation.md
73+
posttraining/lora.md
7174
posttraining/multimodal.md
7275
posttraining/full_finetuning.md
7376
posttraining/gepa_optimization.md
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
<!--
2+
Copyright 2023–2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
-->
16+
17+
# LoRA Fine-tuning on single-host TPUs
18+
19+
**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption.
20+
21+
Unlike traditional full-parameter fine-tuning, LoRA:
22+
23+
- **Freezes the pre-trained model weights**, preserving the original knowledge.
24+
- **Injects trainable rank decomposition matrices** into the Transformer layers.
25+
26+
This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making the process faster and more memory-efficient.
27+
28+
This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText.
29+
30+
We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks.
31+
32+
In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started!
33+
34+
## Setup environment variables
35+
36+
Login to Hugging Face. Provide your access token when prompted:
37+
38+
```bash
39+
hf auth login
40+
```
41+
42+
Set the following environment variables before running LoRA Fine-tuning.
43+
44+
```sh
45+
# -- Model configuration --
46+
export MODEL_NAME=<MODEL_NAME> # e.g., 'gemma3-4b'
47+
48+
# -- MaxText configuration --
49+
export BASE_OUTPUT_DIRECTORY=<GCS_BUCKET> # e.g., gs://my-bucket/my-output-directory or /path/to/my-output-directory
50+
export RUN_NAME=<RUN_NAME> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
51+
export STEPS=<STEPS> # e.g., 1000
52+
export PER_DEVICE_BATCH_SIZE=<BATCH_SIZE_PER_DEVICE> # e.g., 1
53+
export LORA_RANK=<LORA_RANK> # e.g., 16
54+
export LORA_ALPHA=<LORA_ALPHA> # e.g., 32.0
55+
export LEARNING_RATE=<LEARNING_RATE> # e.g., 3e-6
56+
export MAX_TARGET_LENGTH=<MAX_TARGET_LENGTH> # e.g., 1024
57+
58+
# -- Dataset configuration --
59+
export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k
60+
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train
61+
export HF_DATA_DIR=<DATASET_PATH> # e.g., main
62+
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer']
63+
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., maxtext/examples/chat_templates/math_qa.json
64+
65+
# -- LoRA Conversion configuration (Optional) --
66+
export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name'
67+
```
68+
69+
## Customizing Trainable Layers (Optional)
70+
71+
By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading `src/maxtext/configs/post_train/lora_module_path.yml`.
72+
73+
If you need to fine-tune specific components (e.g., targeting only Attention layers to optimize memory usage), you can override these defaults through the following hierarchy:
74+
75+
### Configuration Hierarchy
76+
77+
1. **Command Line Argument**: Pass the `lora_module_path` argument directly in your training command. This is the most flexible way for experimental iterations.
78+
2. **Task-Specific Config (`sft.yml`)**: Define the `lora_module_path` parameter in `src/maxtext/configs/post_train/sft.yml` to set a persistent configuration for your SFT runs.
79+
3. **Global Defaults**: Automatic detection via the model-to-regex mapping defined in `lora_module_path.yml`.
80+
81+
## Get your model checkpoint
82+
83+
This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
84+
85+
### Option 1: Using an existing MaxText checkpoint
86+
87+
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
88+
89+
```sh
90+
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items
91+
```
92+
93+
### Option 2: Converting a Hugging Face checkpoint
94+
95+
Refer to the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have the correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on.
96+
97+
```sh
98+
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items
99+
```
100+
101+
## Run a Fresh LoRA Fine-Tuning on Hugging Face Dataset
102+
103+
Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
104+
105+
Execute the following command to begin training:
106+
107+
```sh
108+
python3 -m maxtext.trainers.post_train.sft.train_sft \
109+
run_name="${RUN_NAME?}" \
110+
base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
111+
model_name="${MODEL_NAME?}" \
112+
load_parameters_path="${MAXTEXT_CKPT_PATH?}" \
113+
hf_path="${DATASET_NAME?}" \
114+
train_split="${TRAIN_SPLIT?}" \
115+
hf_data_dir="${HF_DATA_DIR?}" \
116+
train_data_columns="${TRAIN_DATA_COLUMNS?}" \
117+
steps="${STEPS?}" \
118+
per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
119+
max_target_length="${MAX_TARGET_LENGTH?}" \
120+
learning_rate="${LEARNING_RATE?}" \
121+
chat_template_path="${CHAT_TEMPLATE_PATH?}" \
122+
enable_nnx=True \
123+
pure_nnx_decoder=True \
124+
lora.enable_lora=True \
125+
lora.lora_rank="${LORA_RANK?}" \
126+
lora.lora_alpha="${LORA_ALPHA?}"
127+
```
128+
129+
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
130+
131+
## (Optional) Resume from a previous LoRA checkpoint
132+
133+
If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path.
134+
135+
### Step 1: Convert HF LoRA adapter to MaxText format
136+
137+
If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the integrated conversion utility:
138+
139+
```sh
140+
python3 -m maxtext.checkpoint_conversion.to_maxtext \
141+
model_name="${MODEL_NAME?}" \
142+
hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \
143+
base_output_directory="${BASE_OUTPUT_DIRECTORY?}/converted_adapter" \
144+
hardware=cpu skip_jax_distributed_system=True
145+
```
146+
147+
### Step 2: Set the restore path
148+
149+
Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files).
150+
151+
- **load_parameters_path**: Points to the frozen base model weights (the original model).
152+
- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load.
153+
154+
```sh
155+
export LORA_RESTORE_PATH=<LORA_RESTORE_PATH> # e.g., gs://my-bucket/run-1/checkpoints/0/items or /path/to/run-1/checkpoints/0/items
156+
```
157+
158+
### Step 3: Run LoRA Fine-Tuning with the Restore Path
159+
160+
Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
161+
162+
Execute the following command to begin training:
163+
164+
```sh
165+
python3 -m maxtext.trainers.post_train.sft.train_sft \
166+
run_name="${RUN_NAME?}" \
167+
base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
168+
model_name="${MODEL_NAME?}" \
169+
load_parameters_path="${MAXTEXT_CKPT_PATH?}" \
170+
lora.lora_restore_path="${LORA_RESTORE_PATH?}" \
171+
hf_path="${DATASET_NAME?}" \
172+
train_split="${TRAIN_SPLIT?}" \
173+
hf_data_dir="${HF_DATA_DIR?}" \
174+
train_data_columns="${TRAIN_DATA_COLUMNS?}" \
175+
steps="${STEPS?}" \
176+
per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
177+
max_target_length="${MAX_TARGET_LENGTH?}" \
178+
learning_rate="${LEARNING_RATE?}" \
179+
chat_template_path="${CHAT_TEMPLATE_PATH?}" \
180+
enable_nnx=True \
181+
pure_nnx_decoder=True \
182+
lora.enable_lora=True \
183+
lora.lora_rank="${LORA_RANK?}" \
184+
lora.lora_alpha="${LORA_ALPHA?}"
185+
```
186+
187+
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
188+
189+
## (Optional) Convert Fine-tuned LoRA to Hugging Face Format
190+
191+
After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `to_huggingface.py` script.
192+
193+
```sh
194+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
195+
model_name="${MODEL_NAME?}" \
196+
lora.lora_restore_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/<STEPS>/model_params" \
197+
base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter"
198+
```
199+
200+
- `lora.lora_restore_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/items`) that you want to export.
201+
- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved.
202+
- `lora.lora_rank` / `lora.lora_alpha`: Must match the values used during the training phase to ensure the `adapter_config.json` is generated correctly.

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ addopts =
1515
--ignore=tests/unit/gemma3_layers_test.py
1616
--ignore=tests/unit/gpt_vs_reference_test.py
1717
--ignore=tests/unit/llama4_layers_test.py
18+
--ignore=tests/unit/hf_checkpoint_conversion_test.py
1819
--ignore=tests/unit/yarn_vs_reference_test.py
1920
--ignore=tests/unit/moba_vs_reference_test.py
2021
--ignore=tests/unit/offline_engine_test.py

src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ openai
2121
openai-harmony
2222
papermill
2323
partial-json-parser
24+
peft
2425
perfetto
2526
prometheus-fastapi-instrumentator
2627
py-cpuinfo

src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# See https://maxtext.readthedocs.io/en/latest/development/update_dependencies.html for details.
33

44
absl-py>=2.4.0
5+
accelerate>=1.13.0
56
aiofiles>=25.1.0
67
aiohappyeyeballs>=2.6.1
78
aiohttp>=3.13.5
@@ -11,6 +12,7 @@ annotated-types>=0.7.0
1112
anthropic>=0.97.0
1213
antlr4-python3-runtime>=4.9.3
1314
anyio>=4.13.0
15+
apache-tvm-ffi>=0.1.10
1416
appnope>=0.1.4 ; sys_platform == 'darwin'
1517
aqtp>=0.9.0
1618
array-record>=0.8.3
@@ -21,11 +23,11 @@ astunparse>=1.6.3
2123
attrs>=25.4.0
2224
auditwheel>=6.6.0
2325
black>=25.12.0
24-
boto3>=1.42.97
25-
botocore>=1.42.97
26+
boto3>=1.43.0
27+
botocore>=1.43.0
2628
build>=1.4.0
2729
cachetools>=7.0.6
28-
cbor2>=5.9.0
30+
cbor2>=6.0.1
2931
certifi>=2026.2.25
3032
cffi>=2.0.0 ; implementation_name == 'pypy' or platform_python_implementation != 'PyPy'
3133
cfgv>=3.5.0
@@ -79,7 +81,7 @@ google-api-python-client>=2.194.0
7981
google-auth>=2.49.2
8082
google-auth-httplib2>=0.3.1
8183
google-auth-oauthlib>=1.3.1
82-
google-cloud-aiplatform>=1.148.1
84+
google-cloud-aiplatform>=1.149.0
8385
google-cloud-appengine-logging>=1.9.0
8486
google-cloud-audit-log>=0.5.0
8587
google-cloud-bigquery>=3.41.0
@@ -91,7 +93,7 @@ google-cloud-resource-manager>=1.17.0
9193
google-cloud-storage>=3.10.1
9294
google-cloud-storage-control>=1.11.0
9395
google-crc32c>=1.8.0
94-
google-genai>=1.73.1
96+
google-genai>=1.74.0
9597
google-metrax>=0.2.3
9698
google-pasta>=0.2.0
9799
google-resumable-media>=2.8.2
@@ -110,7 +112,7 @@ hf-xet>=1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or
110112
httpcore>=1.0.9
111113
httplib2>=0.31.2
112114
httpx>=0.28.1
113-
huggingface-hub>=1.12.0
115+
huggingface-hub>=1.12.2
114116
humanize>=4.15.0
115117
hypothesis>=6.142.1
116118
identify>=2.6.19
@@ -139,15 +141,15 @@ jsonschema-specifications>=2025.9.1
139141
jupyter-client>=8.8.0
140142
jupyter-core>=5.9.1
141143
jupyterlab-widgets>=3.0.16
142-
kagglehub>=1.0.0
143-
kagglesdk>=0.1.21
144+
kagglehub>=1.0.1
145+
kagglesdk>=0.1.22
144146
keras>=3.13.2
145147
kiwisolver>=1.5.0
146148
latex2sympy2-extended>=1.11.0
147149
libclang>=18.1.1
148150
libcst>=1.8.6
149151
libtpu>=0.0.39
150-
llguidance>=1.7.3
152+
llguidance>=1.7.5
151153
llvmlite>=0.47.0
152154
loguru>=0.7.3
153155
lxml>=6.1.0
@@ -160,7 +162,7 @@ matplotlib>=3.10.8
160162
matplotlib-inline>=0.2.1
161163
mccabe>=0.7.0
162164
mdurl>=0.1.2
163-
mistral-common>=1.11.0
165+
mistral-common>=1.11.1
164166
ml-collections>=1.1.0
165167
ml-dtypes>=0.5.4
166168
ml-goodput-measurement>=0.0.16
@@ -199,7 +201,7 @@ nvidia-nvshmem-cu12>=3.4.5 ; platform_machine == 'x86_64' and sys_platform == 'l
199201
nvidia-nvtx-cu12>=12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
200202
oauthlib>=3.3.1
201203
omegaconf>=2.3.0
202-
openai>=2.32.0
204+
openai>=2.33.0
203205
openai-harmony>=0.0.8
204206
opentelemetry-api>=1.41.1
205207
opt-einsum>=3.4.0
@@ -216,6 +218,7 @@ parso>=0.8.6
216218
partial-json-parser>=0.2.1.1.post7
217219
pathspec>=1.1.1
218220
pathwaysutils>=0.1.8
221+
peft>=0.19.1
219222
perfetto>=0.16.0
220223
pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
221224
pillow>=12.1.1
@@ -273,7 +276,7 @@ rpds-py>=0.30.0
273276
runai-model-streamer>=0.15.8
274277
runai-model-streamer-gcs>=0.15.8
275278
runai-model-streamer-s3>=0.15.8
276-
s3transfer>=0.16.1
279+
s3transfer>=0.17.0
277280
safetensors>=0.7.0
278281
scipy>=1.17.1
279282
scipy-stubs>=1.17.1.2
@@ -314,7 +317,7 @@ tornado>=6.5.5
314317
tpu-info>=0.11.0
315318
tqdm>=4.67.3
316319
traitlets>=5.14.3
317-
transformers>=5.6.2
320+
transformers>=5.7.0
318321
treescope>=0.1.10
319322
triton>=3.6.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
320323
typeguard>=2.13.3
@@ -336,7 +339,7 @@ wheel>=0.46.3
336339
widgetsnbextension>=4.0.15
337340
win32-setctime>=1.2.0 ; sys_platform == 'win32'
338341
wrapt>=2.1.2
339-
xgrammar>=0.1.33
342+
xgrammar>=0.1.34
340343
xprof>=2.22.2
341344
xxhash>=3.7.0
342345
yapf>=0.43.0

0 commit comments

Comments
 (0)