Skip to content

Commit 2bc8632

Browse files
committed
feat: Gemma4 LoRA Extension
1 parent 76500f3 commit 2bc8632

4 files changed

Lines changed: 26 additions & 0 deletions

File tree

src/maxtext/configs/post_train/lora_module_path.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2121
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
2222
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
2323
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
24+
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
2425
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2526
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"
2627

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,16 @@ def verify_chat_template_generation_prompt_logic(tokenizer_model):
267267
actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
268268

269269
if actual_prefix_in_full_turn != assistant_prefix:
270+
# Allow the generation prompt to include a thought channel block (e.g., for Gemma 4).
271+
thought_channel = "<|channel>thought\n<channel|>"
272+
thought_ids = extract_token_ids(tokenizer_model.encode(thought_channel, add_special_tokens=False))
273+
if len(assistant_prefix) >= len(thought_ids) and assistant_prefix[-len(thought_ids) :] == thought_ids:
274+
true_prefix_ids = assistant_prefix[: -len(thought_ids)]
275+
actual_prefix = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(true_prefix_ids)]
276+
if actual_prefix == true_prefix_ids:
277+
max_logging.info("Chat template generation prompt mismatch resolved via thought channel bypass.")
278+
return
279+
270280
expected_str = tokenizer_model.decode(assistant_prefix)
271281
actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
272282
raise ValueError(
@@ -298,6 +308,12 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
298308
prompt_completion_ids = extract_token_ids(prompt_completion_tokens)
299309
prompt_ids = extract_token_ids(prompt_tokens)
300310

311+
# Bypass for Gemma 4's thought channel block which is included in generation prompt but not in normal assistant turns
312+
thought_channel = "<|channel>thought\n<channel|>"
313+
thought_ids = extract_token_ids(tokenizer_model.encode(thought_channel, add_special_tokens=False))
314+
if len(prompt_ids) >= len(thought_ids) and prompt_ids[-len(thought_ids) :] == thought_ids:
315+
prompt_ids = prompt_ids[: -len(thought_ids)]
316+
301317
completion_tokens = prompt_completion_ids[len(prompt_ids) :]
302318
completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
303319
return completion_in_chat_template

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
264264
def train_model(mt_config, trainer, mesh):
265265
"""Runs the SFT training loop in Tunix."""
266266
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
267+
# Disable NNX graph caching for MoE models (where experts > 1) to allow
268+
# necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan).
269+
enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1
270+
267271
trainer.train(
268272
trainer.data_hooks.train_data_iterator,
269273
trainer.data_hooks.eval_data_iterator,
274+
cache_nnx_graph=enable_nnx_cache,
270275
)
271276
return trainer
272277

tests/post_training/unit/sft_data_processing_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,17 @@ def setUp(self):
512512
super().setUp()
513513
self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
514514
self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)
515+
self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26b-a4b-it")
515516

516517
def test_tokenizer_w_generation_prompt(self):
517518
verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
518519

519520
def test_tokenizer_wo_generation_prompt(self):
520521
verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)
521522

523+
def test_tokenizer_gemma4_thought_channel_bypass(self):
524+
verify_chat_template_generation_prompt_logic(self.gemma4_tokenizer)
525+
522526
def test_failure_path_with_modified_template(self):
523527
"""Verifies the function correctly raises a ValueError on a bad template."""
524528
# Replace the role within the existing add_generation_prompt block with a deliberately faulty one.

0 commit comments

Comments
 (0)