-
Notifications
You must be signed in to change notification settings - Fork 360
Add Gemma4 MoE quantization support #1219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4001a91
ecdaac7
cf17833
c79ebc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -301,6 +301,7 @@ def auto_quantize( | |
| auto_quantize_method="gradient", | ||
| auto_quantize_score_size=128, | ||
| auto_quantize_checkpoint=None, | ||
| full_model: torch.nn.Module | None = None, | ||
| ): | ||
| """Auto search quantization of multiple formats.""" | ||
|
|
||
|
|
@@ -338,23 +339,67 @@ def auto_quantize( | |
| for qformat in qformat_list | ||
| ), "One or more quantization formats provided are not supported for unified checkpoint export" | ||
|
|
||
| def loss_func(output, data): | ||
| # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` | ||
| # which contains the loss attribute. | ||
| return output.loss | ||
|
|
||
| if auto_quantize_method == "gradient": | ||
| # For gradient-based method, return full output with loss | ||
| def forward_step(model, batch): | ||
| return model(**batch) | ||
| elif auto_quantize_method == "kl_div": | ||
| # For KL divergence method, return only logits | ||
| def forward_step(model, batch): | ||
| return model(**batch).logits | ||
| # For VLMs like Gemma4, the extracted language_model is a base text model without | ||
| # lm_head, so it cannot produce logits or loss directly. In that case, use the | ||
| # full_model's lm_head to compute logits/loss from the language model's hidden states. | ||
| is_base_model = ( | ||
| full_model is not None | ||
| and language_model is not full_model | ||
| and not hasattr(language_model, "lm_head") | ||
| and hasattr(full_model, "lm_head") | ||
| ) | ||
|
|
||
| if is_base_model: | ||
| assert full_model is not None | ||
| lm_head = full_model.lm_head | ||
|
|
||
| def loss_func(output, data): | ||
| logits = lm_head(output.last_hidden_state) | ||
| labels = data["labels"] | ||
| shift_logits = logits[..., :-1, :].contiguous() | ||
| shift_labels = labels[..., 1:].contiguous() | ||
| return torch.nn.functional.cross_entropy( | ||
| shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) | ||
| ) | ||
|
|
||
| if auto_quantize_method == "gradient": | ||
|
|
||
| def forward_step(model, batch): | ||
| return model(**batch) | ||
|
|
||
| elif auto_quantize_method == "kl_div": | ||
|
|
||
| def forward_step(model, batch): | ||
| hidden_states = model(**batch).last_hidden_state | ||
| return lm_head(hidden_states) | ||
|
|
||
| else: | ||
| raise ValueError( | ||
| f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" | ||
| ) | ||
|
Comment on lines
+352
to
+379
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result:
This matches the general Transformers pattern: Sources: [1], [2] 🏁 Script executed: # First, let's see the full context of how is_base_model is set
cd examples/llm_ptq && head -400 hf_ptq.py | tail -100Repository: NVIDIA/Model-Optimizer Length of output: 3532 🏁 Script executed: # Check the make_calib_dataloader function to see if it includes labels for gradient mode
cd examples/llm_ptq && sed -n '282,293p' hf_ptq.pyRepository: NVIDIA/Model-Optimizer Length of output: 544 🏁 Script executed: # Look for where is_base_model is assigned in the file
# Look for where is_base_model is assigned in the file
rg "is_base_model\s*=" examples/llm_ptq/hf_ptq.py -B 5 -A 2

Repository: NVIDIA/Model-Optimizer
Length of output: 523

Strip `labels` from the batch before forwarding to the base model.

When using the base-model path with gradient auto-quantize, the dataloader includes `labels`, but the extracted base text model's forward does not accept a `labels` argument. The proposed fix is correct: define a helper to strip `labels` from the batch before calling the model.

Suggested fix:

if is_base_model:
assert full_model is not None
lm_head = full_model.lm_head
+
+ def _model_inputs(batch):
+ return {k: v for k, v in batch.items() if k != "labels"}
def loss_func(output, data):
logits = lm_head(output.last_hidden_state)
labels = data["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if auto_quantize_method == "gradient":
def forward_step(model, batch):
- return model(**batch)
+ return model(**_model_inputs(batch))
elif auto_quantize_method == "kl_div":
def forward_step(model, batch):
- hidden_states = model(**batch).last_hidden_state
+ hidden_states = model(**_model_inputs(batch)).last_hidden_state
+            return lm_head(hidden_states)

🤖 Prompt for AI Agents |
||
| else: | ||
| raise ValueError( | ||
| f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" | ||
| ) | ||
|
|
||
| def loss_func(output, data): | ||
| # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` | ||
| # which contains the loss attribute. | ||
| return output.loss | ||
|
|
||
| if auto_quantize_method == "gradient": | ||
| # For gradient-based method, return full output with loss | ||
|
|
||
| def forward_step(model, batch): | ||
| return model(**batch) | ||
|
|
||
| elif auto_quantize_method == "kl_div": | ||
| # For KL divergence method, return only logits | ||
|
|
||
| def forward_step(model, batch): | ||
| return model(**batch).logits | ||
|
|
||
| else: | ||
| raise ValueError( | ||
| f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" | ||
| ) | ||
|
|
||
| language_model, _ = mtq.auto_quantize( | ||
| language_model, | ||
|
|
@@ -1048,6 +1093,7 @@ def quantize_main( | |
| args, | ||
| language_model, | ||
| calib_dataloader, | ||
| full_model=full_model, | ||
| ) | ||
|
|
||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1202,6 +1202,19 @@ def unpack_weight(self): | |
| except ImportError: | ||
| pass | ||
|
|
||
| try: | ||
| from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts | ||
|
|
||
| # Gemma4TextExperts has the same fused 3D tensor layout as Qwen3_5MoeExperts | ||
| # (gate_up_proj, down_proj, hidden_dim, intermediate_dim, num_experts, act_fn) | ||
| # so we reuse _QuantQwen35MoeExperts which unfuses into per-expert nn.Linear layers. | ||
| if Gemma4TextExperts not in QuantModuleRegistry: | ||
| QuantModuleRegistry.register({Gemma4TextExperts: "hf.Gemma4TextExperts"})( | ||
| _QuantQwen35MoeExperts | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we rename this to something more generic?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have |
||
| ) | ||
| except ImportError: | ||
| pass | ||
|
|
||
|
|
||
| class _QuantGptOssExperts(_QuantFunctionalMixin): | ||
| """Quantized wrapper for `transformers.GptOssExperts`. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this VLM specific or Gemma4 specific?