Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def auto_quantize(
auto_quantize_method="gradient",
auto_quantize_score_size=128,
auto_quantize_checkpoint=None,
full_model: torch.nn.Module | None = None,
):
"""Auto search quantization of multiple formats."""

Expand Down Expand Up @@ -338,23 +339,67 @@ def auto_quantize(
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss

if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss
def forward_step(model, batch):
return model(**batch)
elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits
def forward_step(model, batch):
return model(**batch).logits
# For VLMs like Gemma4, the extracted language_model is a base text model without
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this VLM specific or Gemma4 specific?

# lm_head, so it cannot produce logits or loss directly. In that case, use the
# full_model's lm_head to compute logits/loss from the language model's hidden states.
is_base_model = (
full_model is not None
and language_model is not full_model
and not hasattr(language_model, "lm_head")
and hasattr(full_model, "lm_head")
)

if is_base_model:
assert full_model is not None
lm_head = full_model.lm_head

def loss_func(output, data):
logits = lm_head(output.last_hidden_state)
labels = data["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

if auto_quantize_method == "gradient":

def forward_step(model, batch):
return model(**batch)

elif auto_quantize_method == "kl_div":

def forward_step(model, batch):
hidden_states = model(**batch).last_hidden_state
return lm_head(hidden_states)

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)
Comment on lines +352 to +379
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

In Hugging Face Transformers, do base text model classes such as `Gemma4TextModel` accept a `labels` argument in `forward`, or is `labels` only supported on `*ForCausalLM` wrappers with an `lm_head`?

💡 Result:

labels is not generally accepted by the base (backbone) text model classes like Gemma4TextModel. For Gemma 4 specifically, Gemma4TextModel.forward(...) has no labels parameter and returns hidden states / caches (no LM loss computation). [1]

labels is supported on the task/head wrappers like Gemma4ForCausalLM (and multimodal Gemma4ForConditionalGeneration), whose forward(...) does include labels and returns a loss when labels is provided—because these classes include the language-modeling head (lm_head). [1]

This matches the general Transformers pattern: `*Model` backbone classes typically don't take `labels`; the `*For...` head models do. [2]

Sources: [1], [2]


🏁 Script executed:

# First, let's see the full context of how is_base_model is set
cd examples/llm_ptq && head -400 hf_ptq.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 3532


🏁 Script executed:

# Check the make_calib_dataloader function to see if it includes labels for gradient mode
cd examples/llm_ptq && sed -n '282,293p' hf_ptq.py

Repository: NVIDIA/Model-Optimizer

Length of output: 544


🏁 Script executed:

# Look for where is_base_model is assigned in the file
rg "is_base_model\s*=" examples/llm_ptq/hf_ptq.py -B 5 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 523


Strip labels from batch before passing to the extracted base model.

When using the base-model path with gradient auto-quantize, the dataloader includes labels in batches (set by make_calib_dataloader() when auto_quantize_method == "gradient"). However, base text models like Gemma4TextModel do not accept a labels parameter in forward()—only the *ForCausalLM wrappers do. This will cause the gradient auto-quantize path to fail with a TypeError before loss_func() can use the labels.

The proposed fix is correct: define a helper to strip labels and other non-input keys, then use it in both forward_step implementations.

Suggested fix
     if is_base_model:
         assert full_model is not None
         lm_head = full_model.lm_head
+
+        def _model_inputs(batch):
+            return {k: v for k, v in batch.items() if k != "labels"}

         def loss_func(output, data):
             logits = lm_head(output.last_hidden_state)
             labels = data["labels"]
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             return torch.nn.functional.cross_entropy(
                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
             )

         if auto_quantize_method == "gradient":

             def forward_step(model, batch):
-                return model(**batch)
+                return model(**_model_inputs(batch))

         elif auto_quantize_method == "kl_div":

             def forward_step(model, batch):
-                hidden_states = model(**batch).last_hidden_state
+                hidden_states = model(**_model_inputs(batch)).last_hidden_state
                 return lm_head(hidden_states)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 352 - 379, The gradient/kl_div
base-model path passes the raw batch (which includes "labels") into the
extracted base model, causing a TypeError because base models like
Gemma4TextModel don't accept labels; add a small helper (e.g., sanitize_batch or
strip_non_inputs) that removes "labels" and any non-forward kwargs from the
batch, then call that helper inside both forward_step implementations referenced
in the is_base_model branch (where full_model, lm_head, loss_func, forward_step
and auto_quantize_method are defined) so the model receives only valid forward
inputs while loss_func still reads labels from the original batch.

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss

if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss

def forward_step(model, batch):
return model(**batch)

elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits

def forward_step(model, batch):
return model(**batch).logits

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

language_model, _ = mtq.auto_quantize(
language_model,
Expand Down Expand Up @@ -1048,6 +1093,7 @@ def quantize_main(
args,
language_model,
calib_dataloader,
full_model=full_model,
)

else:
Expand Down
12 changes: 11 additions & 1 deletion modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,14 @@ def is_moe(module: nn.Module) -> bool:
if name.endswith("sparsemoeblock") or "moelayer" in name:
return True
# Explicit matches for non-standard naming
return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"])
if any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"]):
return True
# Structural detection: modules with router + experts (e.g. Gemma4TextDecoderLayer)
return (
hasattr(module, "router")
and hasattr(module, "experts")
and isinstance(module.experts, nn.Module)
)


def is_quantlinear(module: nn.Module) -> bool:
Expand Down Expand Up @@ -983,6 +990,9 @@ def module_match_name_list(module, name_list):
elif module_match_name_list(module, ["GptOssMoE"]):
# GPT-OSS MoE modules use gate_up_proj and down_proj
return ["gate_up_proj", "down_proj"]
elif module_match_name_list(module, ["Gemma4TextDecoderLayer"]):
# Gemma4 MoE experts are unfused into per-expert nn.Linear layers
return ["gate_proj", "down_proj", "up_proj"]
else:
# assuming w1, w2, w3 by default
return ["w1", "w2", "w3"]
Expand Down
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,8 +791,10 @@ def _nvfp4_selective_quant_cfg(
NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True
)
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(
["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"]
)
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])
NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])

# DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
Expand Down
13 changes: 13 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,19 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts

# Gemma4TextExperts has the same fused 3D tensor layout as Qwen3_5MoeExperts
# (gate_up_proj, down_proj, hidden_dim, intermediate_dim, num_experts, act_fn)
# so we reuse _QuantQwen35MoeExperts which unfuses into per-expert nn.Linear layers.
if Gemma4TextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Gemma4TextExperts: "hf.Gemma4TextExperts"})(
_QuantQwen35MoeExperts
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rename this to something more generic?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have _QuantFusedExperts instead of _QuantQwen35MoeExperts after this PR: #1187

)
except ImportError:
pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
Expand Down
Loading