Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/posttraining/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train
export HF_DATA_DIR=<DATASET_PATH> # e.g., main
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer']
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., maxtext/examples/chat_templates/math_qa.json
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., src/maxtext/examples/chat_templates/math_qa.json (use gemma4_math_qa.json for Gemma 4 models)

# -- LoRA Conversion configuration (Optional) --
export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name'
Expand Down
305 changes: 305 additions & 0 deletions docs/tutorials/posttraining/lora_on_multi_host.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/maxtext/configs/post_train/lora_module_path.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"

Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/post_train/sft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ lora:
lora_rank: 0
lora_alpha: 0.0
lora_module_path: ""
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size; leave lora_weight_qtype null for standard (non-quantized) LoRA.
lora_weight_qtype: null
lora_tile_size: null
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
lora_restore_path: ""

Expand Down
12 changes: 11 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,9 +1236,19 @@ class LoRA(BaseModel):
lora_module_path: str = Field(
"",
description=(
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
"Regex identifying target NNX modules for LoRA. "
"Example for standard models: 'decoder/layers/.*(self_attention/(query|out)|mlp/(wi_0|wo))'. "
"Example for MoE: 'decoder/scanned_blocks/layers.*/.*(MoeBlock_0|shared_experts)/(wi_0|wo)'."
),
)
lora_weight_qtype: str | None = Field(
None,
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
)
lora_tile_size: NonNegativeInt | None = Field(
None,
description=("Tile size for block-wise quantization. Typically 32 or 64."),
)
lora_restore_path: PathStr = Field(
"",
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),
Expand Down
Loading
Loading