Skip to content

Extend LoRA for Gemma4#3969

Open
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma4-lora
Open

Extend LoRA for Gemma4#3969
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma4-lora

Conversation

@RexBearIU
Copy link
Copy Markdown
Collaborator

@RexBearIU RexBearIU commented May 22, 2026

Description

This PR extends the recent LoRA support to accurately target and process Gemma 4 architectures (including MoE).

Gemma 4 introduces complex nested structures (like scanned_blocks and layers_remainder) and unique chat template behaviors (such as the <|channel>thought block) that are incompatible with standard LoRA targeting and data
processing. Furthermore, MoE models require dynamic metadata synchronization during forward passes which is broken by aggressive NNX graph caching.

This PR addresses these challenges by:

  • Adding accurate regex mapping for Gemma 4 standard and MoE LoRA targets in lora_module_path.yml.
  • Dynamically disabling NNX graph caching in train_sft.py specifically for MoE models (where experts > 1) to allow necessary metadata synchronization.

Tests

  • Added unit tests for the Gemma 4 tokenizer bypass in tests/post_training/unit/sft_data_processing_test.py (test_tokenizer_gemma4_thought_channel_bypass).
  • Verified caching behavior changes by running Gemma-4 MoE LoRA tuning on TPU.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 2bc8632 to ab61640 Compare May 22, 2026 07:38
def test_tokenizer_wo_generation_prompt(self):
verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

def test_tokenizer_gemma4_thought_channel_bypass(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test expects to not fail with TemplateError or ValueError. Can you add an assertion for this so that it is readable what this test actually verifies?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated verify_chat_template_generation_prompt_logic to return True on success, and wrapped all three test cases in explicit self.assertTrue() assertions.

This cleanly verifies that validation succeeds and keeps all tests uniform. All tests pass successfully!"

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch 2 times, most recently from 61626bd to ef50ff7 Compare May 28, 2026 08:41
actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]

if actual_prefix_in_full_turn != assistant_prefix:
# Allow the generation prompt to include a thought channel block (e.g., for Gemma 4).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic looks like a hacky approach to support Gemma4. I am working on a generalized logic to support any model that requires specific prefix shifting. I will send out the PR soon.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it was definitely a workaround. Thanks for taking the lead on a generalized solution! I'll track #4010 and we can use that approach instead.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rebased this PR on top of #4010

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from ef50ff7 to 5fd616b Compare May 29, 2026 07:48
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 5fd616b to 6a64bd0 Compare June 1, 2026 06:59
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Jun 2, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request successfully extends LoRA support for Gemma 4 architectures and addresses a critical issue with NNX graph caching in MoE models. However, there is a significant discrepancy between the PR description and the actual changes, as several mentioned files and unit tests are missing from the diff.

🔍 General Feedback

  • Missing Implementation: The PR description mentions a "thought channel bypass" in input_pipeline_utils.py and new unit tests in tests/post_training/unit/sft_data_processing_test.py, but these files are not included in the PR. Please ensure all intended changes are staged and pushed.
  • Consistency across Trainers: The dynamic disabling of NNX graph caching is a great addition for MoE stability; consider applying this same logic to DPO, RL, and Distillation trainers to ensure consistent behavior across the post-training suite.
  • LoRA Targeting: The regex for Gemma 4 LoRA targeting is comprehensive but should be monitored to ensure it doesn't become overly broad as the architecture evolves.

Comment thread src/maxtext/trainers/post_train/sft/train_sft.py Outdated
Comment thread src/maxtext/configs/post_train/lora_module_path.yml
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern also adapts the MoE router: probably by accident. The MoeBlock_0 term matches everything inside the routed-MoE block, including the router/gate that decides which expert each token goes to. You normally don't want LoRA on the router. The actual expert weights (wi_0, wi_1, wo) are already matched by the other terms in the pattern, so you can just delete MoeBlock_0|. Experts and shared experts still get LoRA; the router no longer does. (Checked against the parameter tree.)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the excellent catch! Yes, including MoeBlock_0 was a mistake that dragged in the routing gate.

I have removed MoeBlock_0| from the gemma4 pattern in lora_module_path.yml. This successfully excludes the router/gate from LoRA adaptation while still perfectly capturing the expert MLP weights (wi_0, wi_1, wo)
and the shared experts.

with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
# Disable NNX graph caching for MoE models (where experts > 1) to allow
# necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan).
enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For MaxText flags just directly use dot notation:
mt_config.num_experts

Also why do you need to disable it for MoE, what happens if you don't?

Copy link
Copy Markdown
Collaborator Author

@RexBearIU RexBearIU Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Simplified Flag Access: Updated to use direct dot notation (mt_config.num_experts) as suggested.

    1. Why we disable caching for MoE:
      If we run training with cache_nnx_graph=True (which uses nnx.cached_partial under the hood to cache graph topology), MoE training fails during step execution with the following traceback:
    File "/home/jackyf_google_com/5-4-qlora/lib/python3.12/site-packages/flax/nnx/graph.py", line 1716, in unflatten
      raise ValueError(
    ValueError: The graph structure of a node added to cached_partial was mutated inside the transformation, this is not allowed.
    Node: Transformer( ... )

nnx.cached_partial expects a static topological structure ( graphdef ) across JIT boundaries. However, Mixture-of-Experts (MoE) models perform dynamic token routing and scan-indexing during the forward pass. These operations introduce dynamic tracer metadata that mutates the internal structure of the Transformer node during transformation. This violates the static structural invariants of cached_partial , triggering the traceback. Disabling the graph cache allows NNX to split/merge definitions dynamically at each step, preventing this error.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any tests for Gemma4 Lora?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We ran the end-to-end LoRA training loop for Gemma 4 successfully without any issues.4 successfully without any issues. log

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch 2 times, most recently from 4c2d52f to 491b37c Compare June 3, 2026 15:43
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 491b37c to b52137c Compare June 3, 2026 16:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants