diff --git a/src/maxtext/configs/post_train/lora_module_path.yml b/src/maxtext/configs/post_train/lora_module_path.yml index 11f81d52c5..30006faf79 100644 --- a/src/maxtext/configs/post_train/lora_module_path.yml +++ b/src/maxtext/configs/post_train/lora_module_path.yml @@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" +gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))" olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))" diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 75c7989d9f..65682f130e 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -264,9 +264,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def train_model(mt_config, trainer, mesh): """Runs the SFT training loop in Tunix.""" with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + # Disable NNX graph caching for MoE models (where experts > 1) to allow + # necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan). + enable_nnx_cache = mt_config.num_experts <= 1 + trainer.train( trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator, + cache_nnx_graph=enable_nnx_cache, ) return trainer