Skip to content

Commit d39cf45

Browse files
meenchenkevalmorabia97
authored andcommitted
[NVBug 5702186] Fix awq model export for Gemma3 (#793)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** norms laers in Gemma that use (1 + weight) in forward, we will fold pre_quant_scale into the effective weight. That is to find folded w' subject to: `1 + w' = (1 + w) * s` => `w' = (1 + w) * s -1` ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ./scripts/huggingface_example.sh --model google/gemma-3-1b-it --quant int4_awq ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Improvements** * Enhanced quantization utilities to better handle various LayerNorm variants and normalization patterns, including support for weight-offset variants and zero-centered gamma configurations. * Optimized pre-quantization layer normalization fusion to apply conditional weight scaling strategies based on normalization type. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 304e81f commit d39cf45

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

modelopt/torch/export/quant_utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,16 @@ def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
11011101
setattr(linear_pqs_from, "fused_with_prequant", True)
11021102

11031103

1104+
def _layernorm_uses_weight_plus_one(module: torch.nn.Module) -> bool:
1105+
if any(
1106+
name in type(module).__name__
1107+
for name in ["LayerNorm1P", "GemmaRMSNorm", "Gemma2RMSNorm", "Gemma3RMSNorm"]
1108+
):
1109+
return True
1110+
1111+
return bool(hasattr(module, "zero_centered_gamma") and module.zero_centered_gamma)
1112+
1113+
11041114
def fuse_prequant_layernorm(
11051115
layernorm_module: torch.nn.Module,
11061116
modules: list[torch.Tensor],
@@ -1116,13 +1126,17 @@ def fuse_prequant_layernorm(
11161126
fused_bias = bias * avg_pre_quant_scale
11171127
layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
11181128
"""
1119-
layernorm_module.weight = torch.nn.Parameter(
1120-
layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
1129+
pre_quant_scale = getattr(modules[0].input_quantizer, "_pre_quant_scale").to(
1130+
layernorm_module.weight.device
11211131
)
1132+
if _layernorm_uses_weight_plus_one(layernorm_module):
1133+
# For norms that use (1 + weight) in forward, fold pre_quant_scale into the effective weight.
1134+
fused_weight = (layernorm_module.weight + 1.0) * pre_quant_scale - 1.0
1135+
else:
1136+
fused_weight = layernorm_module.weight * pre_quant_scale
1137+
layernorm_module.weight = torch.nn.Parameter(fused_weight.to(layernorm_module.weight.dtype))
11221138
if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None:
1123-
layernorm_module.bias = torch.nn.Parameter(
1124-
layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
1125-
)
1139+
layernorm_module.bias = torch.nn.Parameter(layernorm_module.bias * pre_quant_scale)
11261140
# Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
11271141
for module in modules:
11281142
delattr(module.input_quantizer, "_pre_quant_scale")

0 commit comments

Comments
 (0)