Skip to content

Commit 6a64bd0

Browse files
committed
feat: Gemma4 LoRA Extension
1 parent 493fba6 commit 6a64bd0

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxtext/configs/post_train/lora_module_path.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2121
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
2222
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
2323
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
24+
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
2425
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2526
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"
2627

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
264264
def train_model(mt_config, trainer, mesh):
265265
"""Runs the SFT training loop in Tunix."""
266266
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
267+
# Disable NNX graph caching for MoE models (where experts > 1) to allow
268+
# necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan).
269+
enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1
270+
267271
trainer.train(
268272
trainer.data_hooks.train_data_iterator,
269273
trainer.data_hooks.eval_data_iterator,
274+
cache_nnx_graph=enable_nnx_cache,
270275
)
271276
return trainer
272277

0 commit comments

Comments
 (0)