diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index f19e82c5d4..984243a320 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -301,6 +301,7 @@ def auto_quantize(
     auto_quantize_method="gradient",
     auto_quantize_score_size=128,
     auto_quantize_checkpoint=None,
+    full_model: torch.nn.Module | None = None,
 ):
     """Auto search quantization of multiple formats."""
 
@@ -338,23 +339,67 @@ def auto_quantize(
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
 
-    def loss_func(output, data):
-        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-        # which contains the loss attribute.
-        return output.loss
-
-    if auto_quantize_method == "gradient":
-        # For gradient-based method, return full output with loss
-        def forward_step(model, batch):
-            return model(**batch)
-    elif auto_quantize_method == "kl_div":
-        # For KL divergence method, return only logits
-        def forward_step(model, batch):
-            return model(**batch).logits
+    # For VLMs like Gemma4, the extracted language_model is a base text model without
+    # lm_head, so it cannot produce logits or loss directly. In that case, use the
+    # full_model's lm_head to compute logits/loss from the language model's hidden states.
+    is_base_model = (
+        full_model is not None
+        and language_model is not full_model
+        and not hasattr(language_model, "lm_head")
+        and hasattr(full_model, "lm_head")
+    )
+
+    if is_base_model:
+        assert full_model is not None
+        lm_head = full_model.lm_head
+
+        def loss_func(output, data):
+            logits = lm_head(output.last_hidden_state)
+            labels = data["labels"]
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            return torch.nn.functional.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+            )
+
+        if auto_quantize_method == "gradient":
+
+            def forward_step(model, batch):
+                return model(**batch)
+
+        elif auto_quantize_method == "kl_div":
+
+            def forward_step(model, batch):
+                hidden_states = model(**batch).last_hidden_state
+                return lm_head(hidden_states)
+
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
+            )
     else:
-        raise ValueError(
-            f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
-        )
+
+        def loss_func(output, data):
+            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+            # which contains the loss attribute.
+            return output.loss
+
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+
+            def forward_step(model, batch):
+                return model(**batch)
+
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+
+            def forward_step(model, batch):
+                return model(**batch).logits
+
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
+            )
 
     language_model, _ = mtq.auto_quantize(
         language_model,
@@ -1048,6 +1093,7 @@ def quantize_main(
             args,
             language_model,
             calib_dataloader,
+            full_model=full_model,
         )
 
     else:
diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index 96ecf91e5b..6862f054eb 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -305,7 +305,14 @@ def is_moe(module: nn.Module) -> bool:
     if name.endswith("sparsemoeblock") or "moelayer" in name:
         return True
     # Explicit matches for non-standard naming
-    return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"])
+    if any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"]):
+        return True
+    # Structural detection: modules with router + experts (e.g. Gemma4TextDecoderLayer)
+    return (
+        hasattr(module, "router")
+        and hasattr(module, "experts")
+        and isinstance(module.experts, nn.Module)
+    )
 
 
 def is_quantlinear(module: nn.Module) -> bool:
@@ -983,6 +990,9 @@ def module_match_name_list(module, name_list):
     elif module_match_name_list(module, ["GptOssMoE"]):
         # GPT-OSS MoE modules use gate_up_proj and down_proj
         return ["gate_up_proj", "down_proj"]
+    elif module_match_name_list(module, ["Gemma4TextDecoderLayer"]):
+        # Gemma4 MoE experts are unfused into per-expert nn.Linear layers
+        return ["gate_proj", "down_proj", "up_proj"]
     else:
         # assuming w1, w2, w3 by default
         return ["w1", "w2", "w3"]
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index f0fd61798b..52edfb8f46 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -791,8 +791,10 @@ def _nvfp4_selective_quant_cfg(
 NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
     ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True
 )
-NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
-NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
+NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(
+    ["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"]
+)
+NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])
 NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])
 
 # DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index e5630d9340..c3f763e116 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -1202,6 +1202,19 @@ def unpack_weight(self):
 except ImportError:
     pass
 
+try:
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts
+
+    # Gemma4TextExperts has the same fused 3D tensor layout as Qwen3_5MoeExperts
+    # (gate_up_proj, down_proj, hidden_dim, intermediate_dim, num_experts, act_fn)
+    # so we reuse _QuantQwen35MoeExperts which unfuses into per-expert nn.Linear layers.
+    if Gemma4TextExperts not in QuantModuleRegistry:
+        QuantModuleRegistry.register({Gemma4TextExperts: "hf.Gemma4TextExperts"})(
+            _QuantQwen35MoeExperts
+        )
+except ImportError:
+    pass
+
 
 class _QuantGptOssExperts(_QuantFunctionalMixin):
     """Quantized wrapper for `transformers.GptOssExperts`.