Extend LoRA for Gemma4#3969
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
2bc8632 to
ab61640
Compare
| def test_tokenizer_wo_generation_prompt(self): | ||
| verify_chat_template_generation_prompt_logic(self.llama2_tokenizer) | ||
|
|
||
| def test_tokenizer_gemma4_thought_channel_bypass(self): |
There was a problem hiding this comment.
This test expects to not fail with TemplateError or ValueError. Can you add an assertion for this so that it is readable what this test actually verifies?
There was a problem hiding this comment.
I updated verify_chat_template_generation_prompt_logic to return True on success, and wrapped all three test cases in explicit self.assertTrue() assertions.
This cleanly verifies that validation succeeds and keeps all tests uniform. All tests pass successfully!"
61626bd to
ef50ff7
Compare
| actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)] | ||
|
|
||
| if actual_prefix_in_full_turn != assistant_prefix: | ||
| # Allow the generation prompt to include a thought channel block (e.g., for Gemma 4). |
There was a problem hiding this comment.
This logic looks like a hacky approach to support Gemma4. I am working on a generalized logic to support any model that requires specific prefix shifting. I will send out the PR soon.
There was a problem hiding this comment.
Agreed, it was definitely a workaround. Thanks for taking the lead on a generalized solution! I'll track #4010 and we can use that approach instead.
ef50ff7 to
5fd616b
Compare
5fd616b to
6a64bd0
Compare
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request successfully extends LoRA support for Gemma 4 architectures and addresses a critical issue with NNX graph caching in MoE models. However, there is a significant discrepancy between the PR description and the actual changes, as several mentioned files and unit tests are missing from the diff.
🔍 General Feedback
- Missing Implementation: The PR description mentions a "thought channel bypass" in
input_pipeline_utils.pyand new unit tests intests/post_training/unit/sft_data_processing_test.py, but these files are not included in the PR. Please ensure all intended changes are staged and pushed. - Consistency across Trainers: The dynamic disabling of NNX graph caching is a great addition for MoE stability; consider applying this same logic to DPO, RL, and Distillation trainers to ensure consistent behavior across the post-training suite.
- LoRA Targeting: The regex for Gemma 4 LoRA targeting is comprehensive but should be monitored to ensure it doesn't become overly broad as the architecture evolves.
| deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" | ||
| gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" | ||
| gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" | ||
| gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))" |
There was a problem hiding this comment.
The pattern also adapts the MoE router: probably by accident. The MoeBlock_0 term matches everything inside the routed-MoE block, including the router/gate that decides which expert each token goes to. You normally don't want LoRA on the router. The actual expert weights (wi_0, wi_1, wo) are already matched by the other terms in the pattern, so you can just delete MoeBlock_0|. Experts and shared experts still get LoRA; the router no longer does. (Checked against the parameter tree.)
There was a problem hiding this comment.
Thank you for the excellent catch! Yes, including MoeBlock_0 was a mistake that dragged in the routing gate.
I have removed MoeBlock_0| from the gemma4 pattern in lora_module_path.yml. This successfully excludes the router/gate from LoRA adaptation while still perfectly capturing the expert MLP weights (wi_0, wi_1, wo)
and the shared experts.
| with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): | ||
| # Disable NNX graph caching for MoE models (where experts > 1) to allow | ||
| # necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan). | ||
| enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1 |
There was a problem hiding this comment.
For MaxText flags just directly use dot notation:
mt_config.num_experts
Also why do you need to disable it for MoE, what happens if you don't?
There was a problem hiding this comment.
-
Simplified Flag Access: Updated to use direct dot notation (
mt_config.num_experts) as suggested.- Why we disable caching for MoE:
If we run training withcache_nnx_graph=True(which usesnnx.cached_partialunder the hood to cache graph topology), MoE training fails during step execution with the following traceback:
File "/home/jackyf_google_com/5-4-qlora/lib/python3.12/site-packages/flax/nnx/graph.py", line 1716, in unflatten raise ValueError( ValueError: The graph structure of a node added to cached_partial was mutated inside the transformation, this is not allowed. Node: Transformer( ... )
- Why we disable caching for MoE:
nnx.cached_partial expects a static topological structure ( graphdef ) across JIT boundaries. However, Mixture-of-Experts (MoE) models perform dynamic token routing and scan-indexing during the forward pass. These operations introduce dynamic tracer metadata that mutates the internal structure of the Transformer node during transformation. This violates the static structural invariants of cached_partial , triggering the traceback. Disabling the graph cache allows NNX to split/merge definitions dynamically at each step, preventing this error.
There was a problem hiding this comment.
Are there any tests for Gemma4 Lora?
There was a problem hiding this comment.
We ran the end-to-end LoRA training loop for Gemma 4 successfully without any issues.4 successfully without any issues. log
4c2d52f to
491b37c
Compare
491b37c to
b52137c
Compare
Description
This PR extends the recent LoRA support to accurately target and process Gemma 4 architectures (including MoE).
Gemma 4 introduces complex nested structures (like scanned_blocks and layers_remainder) and unique chat template behaviors (such as the <|channel>thought block) that are incompatible with standard LoRA targeting and data
processing. Furthermore, MoE models require dynamic metadata synchronization during forward passes which is broken by aggressive NNX graph caching.
This PR addresses these challenges by:
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.