
Commit dedd0a0

Edwardf0t1 and danielkorzekwa authored and committed
Fix a nvfp4 weight amax attribute issue during export (#785)
## What does this PR do?

**Type of change:** Bugfix

**Overview:** Fix an nvfp4 weight amax attribute issue during export, especially when the calibration size is small.

Context: sgl-project/sglang#14677 (comment)

## Usage

```sh
python3 hf_ptq.py --pyt_ckpt_path /home/scratch.jingyux_coreai/kimi-k2/models/Kimi-K2-Thinking-BF16 --qformat nvfp4_mlp_only --export_path /home/omniml_data_3/zhiyuc/checkpoints/Kimi-K2-Thinking-NVFP4 --kv_cache_qformat none --calib_size 20 --trust_remote_code --dataset cnn_dailymail
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

## Bug Fixes

* Improved weight quantizer calibration to ensure quantizers are properly initialized with calibration statistics before computing scaling factors.
* Enhanced reliability and consistency of quantized model exports.
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 6d35ffe · commit dedd0a0

1 file changed: modelopt/torch/export/quant_utils.py

Lines changed: 44 additions & 2 deletions
```diff
@@ -236,6 +236,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor
 
 
+def _ensure_weight_quantizer_calibrated(
+    weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
+) -> None:
+    """Calibrate weight quantizer if amax is not set.
+
+    This is a lazy calibration pattern used during export when weight quantizers
+    may not have been calibrated during the main calibration phase.
+
+    Args:
+        weight_quantizer: The weight quantizer to calibrate
+        weight: The weight tensor to use for calibration
+        module_name: Optional module name for better warning messages
+    """
+    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+        warn(
+            f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
+            f"Computing amax from weights. This may occur if "
+            f"some experts were not activated during calibration (expected for MoE models); "
+            f"try increasing --calib_size."
+        )
+        weight_quantizer.reset_amax()
+        enable_stats_collection(weight_quantizer)
+        weight_quantizer(weight)
+        finish_stats_collection(weight_quantizer)
+
+
 def get_activation_scaling_factor(
     module: nn.Module, input_quantizer_name: str = "input_quantizer"
 ) -> torch.Tensor:
@@ -279,6 +304,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
            # This is because the kernel dequantizes weight to fp8, which is in range 448.
@@ -307,13 +336,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if weight_quantizer is None:
         return None
 
-    if get_quantization_format(module) in [
+    quantization_format = get_quantization_format(module)
+
+    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+        QUANTIZATION_W4A8_NVFP4_FP8,
+    ]:
+        weight = getattr(module, weight_name)
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
+    if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
        # This is because the kernel dequantizes weight to fp8, which is in range 448.
        return weight_quantizer._amax.float() / 448.0
```
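The guard added in `_ensure_weight_quantizer_calibrated` can be exercised in isolation with a stub quantizer. This is a minimal sketch, not modelopt's API: `StubQuantizer`, `ensure_calibrated`, and the two stand-in helpers below are hypothetical; only the calibrate-if-amax-is-missing pattern mirrors the diff.

```python
import warnings


class StubQuantizer:
    """Hypothetical stand-in for modelopt's TensorQuantizer; tracks only amax."""

    def __init__(self):
        self._amax = None
        self._collecting = False

    def reset_amax(self):
        self._amax = None

    def __call__(self, weight):
        # While stats collection is enabled, record the running max of |w|.
        if self._collecting:
            observed = max(abs(w) for w in weight)
            self._amax = observed if self._amax is None else max(self._amax, observed)
        return weight


def enable_stats_collection(quantizer):  # stand-in for the modelopt helper
    quantizer._collecting = True


def finish_stats_collection(quantizer):  # stand-in for the modelopt helper
    quantizer._collecting = False


def ensure_calibrated(quantizer, weight, module_name=""):
    """Mirror of the PR's lazy-calibration guard: calibrate only if amax is missing."""
    if getattr(quantizer, "_amax", None) is None:
        warnings.warn(
            f"Weight quantizer{f' for {module_name}' if module_name else ''} "
            "was not calibrated; computing amax from weights."
        )
        quantizer.reset_amax()
        enable_stats_collection(quantizer)
        quantizer(weight)
        finish_stats_collection(quantizer)


q = StubQuantizer()
ensure_calibrated(q, [-3.0, 1.5, 2.0], "Expert.weight")
print(q._amax)  # largest |weight| seen: 3.0
```

A second call with a fresh weight tensor is a no-op, because the guard only fires when `_amax` is unset; that is what makes the export-time calibration safe for modules that were already calibrated.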

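The inline comment that `weight_scaling_factor_2` for w4a8 "needs to be amax/448, so that the wsf is in range 448/6" can be checked with quick arithmetic. The sketch below is a hedged reading of that comment: the constants are the standard float8_e4m3 and nvfp4 maxima, but the per-block scale formula is an assumption inferred from the comment, not taken from the codebase.

```python
# Why wsf2 = amax / 448 keeps the per-block scale within fp8 range.
FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3
FP4_MAX = 6.0    # largest magnitude representable in nvfp4

amax = 224.0                       # hypothetical per-tensor weight amax
wsf2 = amax / FP8_MAX              # second-level scale, as in the diff
block_amax = 224.0                 # a block's local amax; at most `amax`
wsf = block_amax / FP4_MAX / wsf2  # assumed per-block scale in fp8 units

# Because block_amax <= amax, wsf never exceeds 448 / 6:
assert wsf <= FP8_MAX / FP4_MAX
print(wsf)  # at most 448 / 6 ≈ 74.67
```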