
Commit ef326c8

yueshen2016claude and co-author authored
Add Gemma4 MoE quantization support (NVIDIA#1219)
## Summary

- Register `Gemma4TextExperts` with the `_QuantQwen35MoeExperts` plugin to unfuse fused 3D expert tensors into per-expert `nn.Linear` layers for quantization (sketched below the commit metadata)
- Add structural `is_moe()` detection for modules with `router` + `experts` attributes (Gemma4 has no dedicated `SparseMoeBlock` class — the decoder layer directly owns `router` and `experts`)
- Add `Gemma4TextDecoderLayer` to `get_expert_linear_names()`, returning `["gate_proj", "down_proj", "up_proj"]`
- Add the `"*.experts.*"` pattern to `NVFP4_MLP_ONLY_CFG` and `NVFP4_EXPERTS_ONLY_CFG` to match Gemma4's expert path (`model.layers.X.experts.*`, not nested under `mlp`)

**Context:** Gemma4 MoE models (e.g. `google/gemma-4-26B-A4B-it`) store expert weights as fused 3D `nn.Parameter` tensors (`gate_up_proj`, `down_proj`) instead of an `nn.ModuleList` of `nn.Linear`. Since ModelOpt's quantizer only discovers `nn.Linear` modules, it silently skips the expert weights — the bulk of the model remains unquantized.

**Companion vLLM PR:** vllm-project/vllm#39406 (robust quantized MoE weight loading for Gemma4)

## Test plan

- [x] `hf_ptq.py --pyt_ckpt_path google/gemma-4-26B-A4B-it --qformat nvfp4_mlp_only` — 35k+ quantizers inserted, 17 GB output (vs. 49 GB BF16)
- [x] `vllm serve <path> --quantization modelopt` — loads and serves successfully
- [x] Text generation: correct ("The capital of France is **Paris**.")
- [x] Vision: correct (describes image content accurately)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Support quantizing models with separate base/full components (handles heads present only on the full model)
  * Enhanced Mixture-of-Experts detection and explicit support for Gemma4 expert layer layouts
  * Extended NVFP4 selective quantization presets and recipes to include expert-layer patterns and enable FP8 for expert modules

* **Bug Fixes**
  * Improved loss/logit handling and clearer errors for unsupported quantization methods

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: James Shen <yueshen@nvidia.com>
Signed-off-by: Yue Shen <yueshen@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d1ed76d commit ef326c8

5 files changed

Lines changed: 191 additions & 11 deletions
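
The unfusing step described in the summary happens in the `_QuantQwen35MoeExperts` plugin registration, which does not appear in the diffs shown on this page. The following is only a rough, hypothetical sketch of the idea (tensor layouts, shapes, and helper names are assumptions, not the actual ModelOpt plugin code): the fused 3D expert parameters are sliced per expert and re-exposed as ordinary `nn.Linear` modules so a linear-only quantizer can discover them.

```python
import torch
import torch.nn as nn


class FusedExperts(nn.Module):
    """Stand-in for Gemma4-style fused experts: weights stored as 3D nn.Parameter tensors."""

    def __init__(self, num_experts=4, hidden=8, intermediate=16):
        super().__init__()
        # Assumed layout: (num_experts, hidden, 2 * intermediate), gate_proj and up_proj fused.
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden, 2 * intermediate))
        # Assumed layout: (num_experts, intermediate, hidden).
        self.down_proj = nn.Parameter(torch.randn(num_experts, intermediate, hidden))


def unfuse_to_linears(fused: FusedExperts) -> nn.ModuleList:
    """Slice the fused 3D tensors into per-expert gate/up/down nn.Linear modules."""
    num_experts, hidden, two_intermediate = fused.gate_up_proj.shape
    intermediate = two_intermediate // 2
    experts = nn.ModuleList()
    for e in range(num_experts):
        expert = nn.Module()
        expert.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        expert.up_proj = nn.Linear(hidden, intermediate, bias=False)
        expert.down_proj = nn.Linear(intermediate, hidden, bias=False)
        # nn.Linear stores weight as (out_features, in_features), hence the transpose.
        expert.gate_proj.weight.data.copy_(fused.gate_up_proj[e, :, :intermediate].T)
        expert.up_proj.weight.data.copy_(fused.gate_up_proj[e, :, intermediate:].T)
        expert.down_proj.weight.data.copy_(fused.down_proj[e].T)
        experts.append(expert)
    return experts  # per-expert nn.Linear layers are now visible to a linear-only quantizer


experts = unfuse_to_linears(FusedExperts())
print(sum(isinstance(m, nn.Linear) for m in experts.modules()))  # 12 quantizable linears
```

A real implementation also has to preserve the routing forward pass and map the quantized weights back for export, which this sketch omits.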


examples/llm_ptq/hf_ptq.py

Lines changed: 43 additions & 8 deletions
@@ -294,6 +294,7 @@ def auto_quantize(
     auto_quantize_method="gradient",
     auto_quantize_score_size=128,
     auto_quantize_checkpoint=None,
+    full_model: torch.nn.Module | None = None,
 ):
     """Auto search quantization of multiple formats."""

@@ -332,19 +333,49 @@ def auto_quantize(
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"

-    def loss_func(output, data):
-        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-        # which contains the loss attribute.
-        return output.loss
+    # When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
+    # use full_model's lm_head to compute logits/loss from hidden states.
+    is_base_model = (
+        full_model is not None
+        and language_model is not full_model
+        and not hasattr(language_model, "lm_head")
+        and hasattr(full_model, "lm_head")
+    )
+
+    if is_base_model:
+        assert full_model is not None
+        lm_head = full_model.lm_head
+
+        def loss_func(output, data):
+            logits = lm_head(output.last_hidden_state)
+            labels = data["labels"]
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            return torch.nn.functional.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+            )
+
+    else:
+
+        def loss_func(output, data):
+            return output.loss

     if auto_quantize_method == "gradient":
-        # For gradient-based method, return full output with loss
+
         def forward_step(model, batch):
-            return model(**batch)
+            inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
+            return model(**inputs)
+
     elif auto_quantize_method == "kl_div":
-        # For KL divergence method, return only logits
+
         def forward_step(model, batch):
-            return model(**batch).logits
+            inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
+            output = model(**inputs)
+            if is_base_model:
+                assert full_model is not None
+                return full_model.lm_head(output.last_hidden_state)
+            return output.logits
+
     else:
         raise ValueError(
             f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"

@@ -1024,6 +1055,10 @@ def quantize_main(
             args,
             language_model,
             calib_dataloader,
+            auto_quantize_method=args.auto_quantize_method,
+            auto_quantize_score_size=args.auto_quantize_score_size,
+            auto_quantize_checkpoint=args.auto_quantize_checkpoint,
+            full_model=full_model,
         )

     else:

modelopt/torch/export/layer_utils.py

Lines changed: 13 additions & 1 deletion
@@ -108,6 +108,8 @@ def get_experts_list(
         linear_names = ["gate_proj", "down_proj", "up_proj"]
     elif "nemotronhforcausallm" in model_type:
         linear_names = ["up_proj", "down_proj"]
+    elif "gemma4" in model_type:
+        linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
         raise NotImplementedError(f" {model_type} not supported")

@@ -315,7 +317,14 @@ def is_moe(module: nn.Module) -> bool:
     if name.endswith("sparsemoeblock") or "moelayer" in name:
         return True
     # Explicit matches for non-standard naming
-    return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"])
+    if any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"]):
+        return True
+    # Structural detection: modules with router + experts (e.g. Gemma4TextDecoderLayer)
+    return (
+        hasattr(module, "router")
+        and hasattr(module, "experts")
+        and isinstance(module.experts, nn.Module)
+    )


 def is_quantlinear(module: nn.Module) -> bool:

@@ -1007,6 +1016,9 @@ def module_match_name_list(module, name_list):
     elif module_match_name_list(module, ["NemotronHMOE"]):
         # NemotronHMOE experts (NemotronHMLP) use up_proj and down_proj only (no gate).
         return ["up_proj", "down_proj"]
+    elif module_match_name_list(module, ["Gemma4TextDecoderLayer"]):
+        # Gemma4 MoE experts are unfused into per-expert nn.Linear layers
+        return ["gate_proj", "down_proj", "up_proj"]
     else:
         # assuming w1, w2, w3 by default
         return ["w1", "w2", "w3"]

modelopt/torch/quantization/config.py

Lines changed: 4 additions & 2 deletions
@@ -799,8 +799,10 @@ def _nvfp4_selective_quant_cfg(
 NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
     ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True
 )
-NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
-NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
+NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(
+    ["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"]
+)
+NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])
 NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])

 # DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
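
To see why the extra `"*.experts.*"` entry matters: Gemma4's experts hang directly off the decoder layer (e.g. `model.layers.0.experts.*`) rather than under an `mlp` submodule, so the existing `*mlp.experts*` glob never matches them. Below is a small standalone illustration using Python's `fnmatch`; the module names are hypothetical, and ModelOpt's actual pattern matcher may behave differently.

```python
from fnmatch import fnmatch

# Hypothetical quantizer names for illustration only.
qwen_style = "model.layers.0.mlp.experts.3.gate_proj.weight_quantizer"
gemma4_style = "model.layers.0.experts.gate_proj.weight_quantizer"

old_patterns = ["*mlp.experts*", "*block_sparse_moe*"]
new_patterns = [*old_patterns, "*.experts.*"]


def matches(name: str, patterns: list[str]) -> bool:
    return any(fnmatch(name, p) for p in patterns)


assert matches(qwen_style, old_patterns)        # already covered by *mlp.experts*
assert not matches(gemma4_style, old_patterns)  # silently skipped before this change
assert matches(gemma4_style, new_patterns)      # now picked up by *.experts.*
```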

modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml

Lines changed: 16 additions & 0 deletions
@@ -53,6 +53,22 @@ quantize:
           type: dynamic
           scale_bits: e4m3
         num_bits: e2m1
+    - quantizer_name: '*.experts.*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*.experts.*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
     - quantizer_name: '*[kv]_bmm_quantizer'
       enable: true
       cfg:
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for modelopt.torch.export.layer_utils — MoE detection and expert naming."""
+
+import pytest
+import torch.nn as nn
+
+from modelopt.torch.export.layer_utils import get_expert_linear_names, is_moe
+
+# ---------------------------------------------------------------------------
+# is_moe tests
+# ---------------------------------------------------------------------------
+
+
+class _FakeSparseMoeBlock(nn.Module):
+    """Name ends with 'sparsemoeblock' — detected by naming convention."""
+
+
+class _FakeMoeLayer(nn.Module):
+    """Name contains 'moelayer' — detected by naming convention."""
+
+
+class _FakeArcticMoe(nn.Module):
+    """Name contains 'arcticmoe' — detected by explicit match."""
+
+
+class _StructuralMoeModule(nn.Module):
+    """Has router + experts attributes — detected by structural check."""
+
+    def __init__(self):
+        super().__init__()
+        self.router = nn.Linear(8, 4)
+        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
+
+
+class _NotMoeModule(nn.Module):
+    """Plain module — should NOT be classified as MoE."""
+
+    def __init__(self):
+        super().__init__()
+        self.fc = nn.Linear(8, 8)
+
+
+class _PartialStructuralModule(nn.Module):
+    """Has router but no experts — should NOT be classified as MoE."""
+
+    def __init__(self):
+        super().__init__()
+        self.router = nn.Linear(8, 4)
+
+
+@pytest.mark.parametrize(
+    "module_cls",
+    [_FakeSparseMoeBlock, _FakeMoeLayer, _FakeArcticMoe],
+)
+def test_is_moe_name_based(module_cls):
+    assert is_moe(module_cls())
+
+
+def test_is_moe_structural():
+    assert is_moe(_StructuralMoeModule())
+
+
+def test_is_moe_negative():
+    assert not is_moe(_NotMoeModule())
+
+
+def test_is_moe_partial_structural():
+    assert not is_moe(_PartialStructuralModule())
+
+
+# ---------------------------------------------------------------------------
+# get_expert_linear_names tests
+# ---------------------------------------------------------------------------
+
+
+class _FakeGemma4TextDecoderLayer(nn.Module):
+    pass
+
+
+class _FakeMixtralSparseMoeBlock(nn.Module):
+    pass
+
+
+class _FakeNemotronHMOE(nn.Module):
+    pass
+
+
+def test_get_expert_linear_names_gemma4():
+    assert get_expert_linear_names(_FakeGemma4TextDecoderLayer()) == [
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ]
+
+
+def test_get_expert_linear_names_mixtral():
+    assert get_expert_linear_names(_FakeMixtralSparseMoeBlock()) == ["w1", "w2", "w3"]
+
+
+def test_get_expert_linear_names_nemotron():
+    assert get_expert_linear_names(_FakeNemotronHMOE()) == ["up_proj", "down_proj"]