Skip to content

Commit 491b37c

Browse files
committed
feat: Gemma4 LoRA Extension
1 parent 493fba6 commit 491b37c

3 files changed

Lines changed: 11 additions & 2 deletions

File tree

src/maxtext/configs/post_train/lora_module_path.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2121
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
2222
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
2323
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
24+
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
2425
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2526
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"
2627

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
264264
def train_model(mt_config, trainer, mesh):
265265
"""Runs the SFT training loop in Tunix."""
266266
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
267+
# NNX graph caching is turned off for MoE configs (mt_config.num_experts > 1) so that
268+
# dynamic metadata can still be synchronized during forward passes (e.g., inside jax.lax.scan).
269+
enable_nnx_cache = mt_config.num_experts <= 1  # True only when num_experts is 0 or 1
270+

267271
trainer.train(
268272
trainer.data_hooks.train_data_iterator,
269273
trainer.data_hooks.eval_data_iterator,
274+
cache_nnx_graph=enable_nnx_cache,  # NOTE(review): assumes trainer.train() accepts this kwarg -- confirm against the pinned Tunix version
270275
)
271276
return trainer
272277

src/maxtext/utils/lora_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import jax
2828
import jax.numpy as jnp
2929
from orbax import checkpoint as ocp
30-
import qwix
3130

3231
from maxtext.common import checkpointing
3332
from maxtext.configs import pyconfig
@@ -408,8 +407,10 @@ def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:
408407
return final_path
409408

410409

411-
def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvider:
410+
def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> "qwix.LoraProvider":
412411
"""Builds a Qwix LoRA provider from MaxText LoRA settings."""
412+
import qwix # pylint: disable=import-outside-toplevel
413+
413414
lora_module_path = _get_lora_module_path(mt_config)
414415
lora_kwargs = {
415416
"module_path": lora_module_path,
@@ -495,6 +496,8 @@ def apply_lora_to_model(
495496
model_rngs = getattr(model.decoder, "rngs", None)
496497
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs()
497498

499+
import qwix # pylint: disable=import-outside-toplevel
500+
498501
lora_model = qwix.apply_lora_to_model(
499502
model,
500503
lora_provider,

0 commit comments

Comments
 (0)