
Commit dfe705a

support static NVFP4 HF export (#858)
## What does this PR do?

**Type of change:** New feature

**Overview:** Supports exporting `NVFP4StaticQuantizer` in the unified Hugging Face checkpoint, as a deployment path for PTQ algorithms such as MSE.

## Usage

```bash
# checkpoint generation
python examples/llm_ptq/hf_ptq.py --pyt_ckpt_path Qwen/Qwen3-8B --qformat nvfp4_mse --export_path test-Qwen3-8B-Instruct-MSE-FP8-sweep-FP4 --kv_cache_qformat none --trust_remote_code
```

## Testing

Tested the generated Qwen3 8B checkpoint with `trtllm serve` and the nv_eval example in `Model-Optimizer-Internal/examples/nv_eval`. NV eval results:

```
| Groups  |Version|Filter|n-shot| Metric    |   |Value |   |Stderr|
|---------|-------|------|------|-----------|---|-----:|---|-----:|
|mmlu_str |       |none  |      |exact_match|↑  |0.7186|±  |0.0036|
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

* **New Features**
  * Added support for static NVFP4 quantizers that utilize pre-computed calibration scales.
  * Introduced new NVFP4 W4A4 quantization configuration with optional FP8 scale sweep.
* **Performance Improvements**
  * Static quantizers now skip unnecessary dynamic scaling factor recalculation.
  * Unified quantization handling for improved consistency and efficiency.

---

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
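For reference, a minimal programmatic sketch of the `## Usage` command above. This is not code from this PR; it assumes the public `mtq.quantize` and `modelopt.torch.export.export_hf_checkpoint` entry points, and the calibration loop and export directory are placeholders.

```python
# Hypothetical programmatic equivalent of the hf_ptq.py command above (a sketch,
# not taken from this PR). hf_ptq.py uses a proper calibration dataset; the loop
# below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

model_id = "Qwen/Qwen3-8B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)

def calib_forward_loop(m):
    # Run a few prompts so the MSE weight-scale search can observe statistics.
    for prompt in ["Hello, world.", "Static NVFP4 export test."]:
        inputs = tokenizer(prompt, return_tensors="pt").to(m.device)
        with torch.no_grad():
            m(**inputs)

# "nvfp4_mse" in hf_ptq.py maps to the new NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG.
model = mtq.quantize(model, mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, calib_forward_loop)

export_hf_checkpoint(model, export_dir="test-Qwen3-8B-Instruct-MSE-FP8-sweep-FP4")
```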
1 parent 3fe7e65 commit dfe705a

11 files changed

Lines changed: 325 additions & 27 deletions


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions

```diff
@@ -15,6 +15,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to all the experts.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Add support for rotating the input before quantization for RHT.
+- Add support for advanced weight scale search for NVFP4 quantization and its export path.

 0.42 (2026-02-xx)
 ^^^^^^^^^^^^^^^^^
@@ -36,6 +37,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow.
 - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
+- Add support for advanced weight scale search for NVFP4 quantization and its export path.
 - Add PTQ support for Nemotron Parse.
 - Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/distillation>`_ for more details.
```

examples/llm_ptq/hf_ptq.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -81,6 +81,7 @@
     "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
     "nvfp4": mtq.NVFP4_DEFAULT_CFG,
     "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
+    "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
     "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
     "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
     "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
@@ -890,6 +891,7 @@ def quantize_main(
         "fp8",
         "nvfp4",
         "nvfp4_awq",
+        "nvfp4_mse",
         "w4a8_awq",
         "fp8_pb_wo",
         "w4a8_mxfp4_fp8",
```

modelopt/torch/export/quant_utils.py

Lines changed: 96 additions & 16 deletions

```diff
@@ -41,11 +41,12 @@
 from modelopt.torch.quantization.utils import (
     QuantizerAttrNames,
     quantizer_attr_names,
+    reduce_block_amax,
     weight_attr_names,
 )
 from modelopt.torch.utils import clear_cuda_cache

-from ..quantization.nn import SequentialQuantizer, TensorQuantizer
+from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer
 from .model_config import (
     KV_CACHE_FP8,
     KV_CACHE_INT8,
@@ -238,6 +239,36 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor


+def _get_nvfp4_block_size(
+    weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = ""
+) -> int:
+    """Return block size for NVFP4 from quantizer's block_sizes; raise if missing."""
+    prefix = f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''}"
+    block_sizes = weight_quantizer.block_sizes
+    if block_sizes is None:
+        raise ValueError(f"{prefix} has no block_sizes; cannot compute per-block amax from weight.")
+    block_size = block_sizes.get(-1) or block_sizes.get(weight.dim() - 1)
+    if block_size is None:
+        raise ValueError(
+            f"{prefix} block_sizes has no -1 or last-dim key; cannot compute per-block amax."
+        )
+    return block_size
+
+
+def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None:
+    """Set quantizer _amax buffer from tensor; copy in-place if same shape, else replace buffer."""
+    if (
+        hasattr(weight_quantizer, "_amax")
+        and weight_quantizer._amax is not None
+        and weight_quantizer._amax.shape == tensor.shape
+    ):
+        weight_quantizer._amax.data.copy_(tensor.to(weight_quantizer._amax.device))
+    else:
+        if hasattr(weight_quantizer, "_amax"):
+            delattr(weight_quantizer, "_amax")
+        weight_quantizer.register_buffer("_amax", tensor.clone().detach())
+
+
 def _ensure_weight_quantizer_calibrated(
     weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
 ) -> None:
@@ -246,11 +277,34 @@ def _ensure_weight_quantizer_calibrated(
     This is a lazy calibration pattern used during export when weight quantizers
     may not have been calibrated during the main calibration phase.

+    For NVFP4StaticQuantizer, _amax is per-block amax and _global_amax is the max over
+    blocks; both are computed from the weight when missing.
+
     Args:
         weight_quantizer: The weight quantizer to calibrate
         weight: The weight tensor to use for calibration
         module_name: Optional module name for better warning messages
     """
+    if isinstance(weight_quantizer, NVFP4StaticQuantizer):
+        need_per_block = not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None
+        need_global = (
+            not hasattr(weight_quantizer, "_global_amax") or weight_quantizer.global_amax is None
+        )
+        if not (need_per_block or need_global):
+            return
+        block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name)
+        warn(
+            f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. "
+            f"Computing per-block amax and global_amax from weights. This may occur if: "
+            f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
+        )
+        per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
+        if need_per_block:
+            _set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device))
+        if need_global:
+            weight_quantizer.global_amax = per_block_amax.max()
+        return
+
     if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
         warn(
             f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
```
```diff
@@ -299,7 +353,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         return get_scaling_factor(weight_quantizer[0])

     quantization_format = get_quantization_format(module)
-    # If NVFP4, we need to return quantized per_block scaling factors
+
     if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
@@ -318,9 +372,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
             weight_quantizer
         )
-        return NVFP4QTensor.get_weights_scaling_factor(
+        # Unified method handles both static and dynamic quantizers
+        return NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
+            weight_quantizer,
             weight,
-            weight_quantizer.block_sizes[-1],
             weight_scaling_factor_2.to(weight.device),
         )[0]

@@ -343,27 +398,24 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")

     quantization_format = get_quantization_format(module)

-    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
     if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
         weight = getattr(module, weight_name)
         module_name = f"{type(module).__name__}.{weight_name}"
         _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

-    if quantization_format in [
-        QUANTIZATION_NVFP4,
-        QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_NVFP4_SVDQUANT,
-    ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
-        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
-        # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        return weight_quantizer._amax.float() / 448.0
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # This is because the kernel dequantizes weight to fp8, which is in range 448.
+            return weight_quantizer._amax.float() / 448.0
+        else:
+            # Unified method handles both static and dynamic quantizers
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
```
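A short worked example of the W4A8 comment in the hunk above, assuming the usual NVFP4 conventions (FP4 E2M1 max 6.0, FP8 E4M3 max 448.0); the numbers are illustrative.

```python
# Why weight_scaling_factor_2 = amax / 448 keeps the per-block scale in range 448/6:
# the first-level NVFP4 block scale is block_amax / 6, and dividing it by
# wsf2 = global_amax / 448 yields at most 448 / 6 because block_amax <= global_amax.
FP4_MAX, FP8_MAX = 6.0, 448.0

global_amax = 2.5   # example tensor-level amax
block_amax = 1.0    # example per-block amax (<= global_amax)

wsf2 = global_amax / FP8_MAX             # the value returned for W4A8 above
per_block_scale = block_amax / FP4_MAX   # first-level NVFP4 block scale
stored_scale = per_block_scale / wsf2    # what gets stored alongside the weights
assert stored_scale <= FP8_MAX / FP4_MAX  # i.e. within 448 / 6
print(wsf2, stored_scale)
```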
```diff
@@ -735,7 +787,7 @@ def process_layer_quant_config(layer_config_dict):
         layer_config = {"quant_algo": "W8A16"}
     elif v == "int8_sq":
         layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
-    elif v == "nvfp4":
+    elif v in ["nvfp4", "nvfp4_static"]:
         layer_config = {
             "quant_algo": "NVFP4",
             "group_size": block_size_value,
@@ -1339,6 +1391,18 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         for module in modules:
             module.weight_quantizer[-1].amax = weight_amax

+    # Handle NVFP4StaticQuantizer: unify global_amax for fused layers
+    elif isinstance(modules[0].weight_quantizer, NVFP4StaticQuantizer):
+        global_amax_list = [
+            m.weight_quantizer.global_amax
+            for m in modules
+            if m.weight_quantizer.global_amax is not None
+        ]
+        if global_amax_list:
+            unified_global_amax = torch.max(torch.stack(global_amax_list))
+            for module in modules:
+                module.weight_quantizer.global_amax = unified_global_amax
+
     elif (
         modules[0].weight_quantizer.is_enabled
         and modules[0].weight_quantizer.amax is not None
@@ -1423,6 +1487,22 @@ def get_quant_config(
             if block_size == 0:
                 block_size = get_weight_block_size(module)

+            # Static NVFP4 uses pre-computed per-block scales from MSE calibration
+            if quantization_format == QUANTIZATION_NVFP4:
+                weight_quantizer = getattr(module, "weight_quantizer", None)
+                if weight_quantizer is None:
+                    # Try to get from first weight attribute
+                    for wn in weight_names:
+                        weight_quantizer = getattr(
+                            module, quantizer_attr_names(wn).weight_quantizer, None
+                        )
+                        if weight_quantizer is not None:
+                            break
+                if weight_quantizer is not None:
+                    is_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+                    if is_static:
+                        quantization_format = "nvfp4_static"
+
             # Construct per layer config dictionary
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
```

modelopt/torch/export/unified_export_hf.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -50,7 +50,11 @@
 from torch.distributed.fsdp import FSDPModule

 from modelopt.torch.quantization import set_quantizer_by_cfg_context
-from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
+from modelopt.torch.quantization.nn import (
+    NVFP4StaticQuantizer,
+    SequentialQuantizer,
+    TensorQuantizer,
+)
 from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

@@ -496,6 +500,11 @@ def _export_quantized_weight(
         expert_type in type(sub_module).__name__
         for expert_type in ["Llama4TextExperts", "GptOssExperts"]
     )
+    if is_bmm_expert_weight and isinstance(weight_quantizer, NVFP4StaticQuantizer):
+        raise ValueError(
+            "NVFP4StaticQuantizer with BMM-style expert weights (e.g. Llama4TextExperts, "
+            "GptOssExperts) is not yet supported."
+        )

     if quantization_format in [
         QUANTIZATION_NVFP4,
@@ -507,6 +516,7 @@
             weight, _ = maybe_transpose_expert_weight_dimensions(
                 weight, is_bmm_expert_weight=is_bmm_expert_weight
            )
+
             weight_scale = NVFP4QTensor.get_weights_scaling_factor(
                 weight,
                 block_size=block_size,
```

modelopt/torch/quantization/config.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -419,6 +419,28 @@
     "algorithm": "max",
 }

+NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
+}
+
 NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -751,6 +773,7 @@
     "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
     "MAMBA_MOE_FP8_CONSERVATIVE_CFG",
     "MAMBA_MOE_FP8_AGGRESSIVE_CFG",
+    "NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG",
 }

 BiasType = Literal["static", "dynamic"]
```
