Commit e407987

remove special case of NVFP4 static quantizer in general export logic
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 9725c34 commit e407987

2 files changed

Lines changed: 19 additions & 25 deletions

modelopt/torch/export/quant_utils.py

Lines changed: 9 additions & 14 deletions
@@ -300,18 +300,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
 
     quantization_format = get_quantization_format(module)
 
-    # Handle NVFP4 variants (static or dynamic)
-    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
-    if is_nvfp4_static or quantization_format in [
+    if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
-        if not is_nvfp4_static:
-            module_name = f"{type(module).__name__}.{weight_name}"
-            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
 
     if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
@@ -347,18 +344,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
 
     quantization_format = get_quantization_format(module)
 
-    is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
-    if is_nvfp4_static or quantization_format in [
+    if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
-        if not is_nvfp4_static:
-            weight = getattr(module, weight_name)
-            module_name = f"{type(module).__name__}.{weight_name}"
-            _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+        # Calibrate weight quantizer if amax is not set
+        weight = getattr(module, weight_name)
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
 
     if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
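For reference, `_ensure_weight_quantizer_calibrated` is called in both hunks above but its body is not part of this commit. A minimal sketch of what such a fallback might do, assuming the quantizer exposes a writable `amax` attribute and that a per-tensor max over the weight is an acceptable calibration; the function name below is illustrative, not the modelopt implementation:

import torch


def ensure_weight_quantizer_calibrated_sketch(weight_quantizer, weight, module_name):
    # Illustrative sketch only: if the quantizer already carries an amax
    # (e.g. it was calibrated ahead of export), leave it untouched.
    if getattr(weight_quantizer, "amax", None) is not None:
        return
    # Otherwise derive amax from the weight tensor so the export code can
    # compute NVFP4 scaling factors for this module.
    with torch.no_grad():
        weight_quantizer.amax = weight.detach().abs().amax()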

modelopt/torch/export/unified_export_hf.py

Lines changed: 10 additions & 11 deletions
@@ -495,6 +495,11 @@ def _export_quantized_weight(
         expert_type in type(sub_module).__name__
         for expert_type in ["Llama4TextExperts", "GptOssExperts"]
     )
+    if is_bmm_expert_weight and isinstance(weight_quantizer, NVFP4StaticQuantizer):
+        warnings.warn(
+            "NVFP4StaticQuantizer with BMM-style expert weights (e.g. Llama4TextExperts, "
+            "GptOssExperts) is not yet supported; export may produce incorrect results."
+        )
 
     if quantization_format in [
         QUANTIZATION_NVFP4,
@@ -507,17 +512,11 @@ def _export_quantized_weight(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )
 
-        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
-        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
-        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
-
-        if not is_nvfp4_static:
-            # For dynamic NVFP4, compute scales from weights
-            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-                weight,
-                block_size=block_size,
-                weights_scaling_factor_2=weight_scale_2,
-            )[0]
+        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+            weight,
+            block_size=block_size,
+            weights_scaling_factor_2=weight_scale_2,
+        )[0]
 
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
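For reference, the `NVFP4QTensor.get_weights_scaling_factor(...)` call that the export path now always takes implements NVFP4's two-level weight scaling. A rough sketch of the arithmetic, assuming an FP4 max magnitude of 6.0, an FP8 (E4M3) max of 448.0, and a weight whose element count is divisible by `block_size`; the function and variable names below are illustrative, not the modelopt source:

import torch


def nvfp4_weight_scales_sketch(weight: torch.Tensor, block_size: int = 16):
    # Illustrative sketch of NVFP4 two-level scaling, not the modelopt implementation.
    fp4_max, fp8_max = 6.0, 448.0
    # Second-level (per-tensor) scale: amax / (6 * 448), so that first-level
    # scales land in FP8 range (cf. the w4a8 comment above, where it is amax / 448).
    weight_scale_2 = weight.detach().abs().amax().float() / (fp4_max * fp8_max)
    # First-level (per-block) scales: block amax / 6, normalized by the per-tensor scale.
    blocks = weight.detach().float().reshape(-1, block_size)
    weight_scale = (blocks.abs().amax(dim=-1) / fp4_max) / weight_scale_2
    return weight_scale, weight_scale_2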

0 commit comments
