
Commit f61e982

add unit tests, update _ensure_weight_quantizer_calibrated to handle NVFP4StaticQuantizer, update changelog
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent e407987 commit f61e982

8 files changed

Lines changed: 163 additions & 5 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.
 - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
+- Add support for advanced weight scale search for NVFP4 quantization and its export path.

 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 0 deletions
@@ -868,6 +868,7 @@ def quantize_main(
         "fp8",
         "nvfp4",
         "nvfp4_awq",
+        "nvfp4_mse",
         "w4a8_awq",
         "fp8_pb_wo",
         "w4a8_mxfp4_fp8",

modelopt/torch/export/quant_utils.py

Lines changed: 54 additions & 0 deletions
@@ -41,6 +41,7 @@
 from modelopt.torch.quantization.utils import (
     QuantizerAttrNames,
     quantizer_attr_names,
+    reduce_block_amax,
     weight_attr_names,
 )
 from modelopt.torch.utils import clear_cuda_cache
@@ -238,6 +239,36 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor


+def _get_nvfp4_block_size(
+    weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = ""
+) -> int:
+    """Return the NVFP4 block size from the quantizer's block_sizes; raise if missing."""
+    prefix = f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''}"
+    block_sizes = weight_quantizer.block_sizes
+    if block_sizes is None:
+        raise ValueError(f"{prefix} has no block_sizes; cannot compute per-block amax from weight.")
+    block_size = block_sizes.get(-1) or block_sizes.get(weight.dim() - 1)
+    if block_size is None:
+        raise ValueError(
+            f"{prefix} block_sizes has no -1 or last-dim key; cannot compute per-block amax."
+        )
+    return block_size
+
+
+def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None:
+    """Set the quantizer's _amax buffer from tensor; copy in place if shapes match, else replace the buffer."""
+    if (
+        hasattr(weight_quantizer, "_amax")
+        and weight_quantizer._amax is not None
+        and weight_quantizer._amax.shape == tensor.shape
+    ):
+        weight_quantizer._amax.data.copy_(tensor.to(weight_quantizer._amax.device))
+    else:
+        if hasattr(weight_quantizer, "_amax"):
+            delattr(weight_quantizer, "_amax")
+        weight_quantizer.register_buffer("_amax", tensor.clone().detach())
+
+
 def _ensure_weight_quantizer_calibrated(
     weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
 ) -> None:
@@ -246,11 +277,34 @@ def _ensure_weight_quantizer_calibrated(
     This is a lazy calibration pattern used during export when weight quantizers
     may not have been calibrated during the main calibration phase.

+    For NVFP4StaticQuantizer, _amax is the per-block amax and _global_amax is the
+    max over all blocks; both are computed from the weight when missing.
+
     Args:
         weight_quantizer: The weight quantizer to calibrate
         weight: The weight tensor to use for calibration
         module_name: Optional module name for better warning messages
     """
+    if isinstance(weight_quantizer, NVFP4StaticQuantizer):
+        need_per_block = not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None
+        need_global = (
+            not hasattr(weight_quantizer, "_global_amax") or weight_quantizer.global_amax is None
+        )
+        if not (need_per_block or need_global):
+            return
+        block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name)
+        warn(
+            f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. "
+            f"Computing per-block amax and global_amax from weights. This may occur if some experts "
+            f"were not activated during calibration (expected for MoE models); try increasing --calib_size."
+        )
+        per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
+        if need_per_block:
+            _set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device))
+        if need_global:
+            weight_quantizer.global_amax = per_block_amax.max()
+        return
+
     if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
         warn(
             f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "

tests/_test_utils/torch/export/utils.py

Lines changed: 12 additions & 0 deletions
@@ -119,6 +119,18 @@ def forward(self, x):
             "axis": None,
             "enable": True,
         },
+        "*.2.weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*.2.input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
         "default": {"enable": False},
     },
     "algorithm": "max",

tests/gpu/torch/export/test_export.py

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ def test_get_scaling_factor(
         (
             partial_nvfp4_config,
             {
-                "exclude_modules": ["linears.0", "linears.2"],
+                "exclude_modules": ["linears.0"],
                 "group_size": 16,
                 "kv_cache_quant_algo": None,
                 "quant_algo": "NVFP4",

tests/gpu/torch/export/test_export_weight_gpu.py

Lines changed: 61 additions & 2 deletions
@@ -17,16 +17,17 @@

 import torch
 import torch.nn as nn
-from _test_utils.torch.export.utils import ToyModel, partial_w4a8_config
+from _test_utils.torch.export.utils import ToyModel, partial_nvfp4_config, partial_w4a8_config
 from torch.nn import functional as F
 from torch.nn import init

 import modelopt.torch.quantization as mtq
 from modelopt.torch.export.unified_export_hf import _export_quantized_weight
+from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
 from modelopt.torch.quantization.nn.modules.quant_module import QuantModule, QuantModuleRegistry
 from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
 from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
-from modelopt.torch.quantization.utils import quantizer_attr_names
+from modelopt.torch.quantization.utils import quantizer_attr_names, reduce_block_amax


 class ToyLinear(nn.Module):
@@ -121,3 +122,61 @@ def test_export_per_block_quantized_weight():
     assert hasattr(model.linears[2], quantizer_attrs.output_quantizer)
     assert not getattr(model.linears[2], quantizer_attrs.output_quantizer).is_enabled
     assert not hasattr(model.linears[2], quantizer_attrs.output_scale)
+
+
+def test_export_nvfp4_static_weight_dynamic_vs_static_match():
+    """Dynamic vs. static NVFP4 export: exported weight and scales match even when amaxes are
+    cleared on one layer (lazy calibration via _ensure_weight_quantizer_calibrated refills them from weights).
+    """
+    device = "cuda"
+    dims = [32, 32, 32, 32]
+    block_size = 16
+    calib_input = torch.randn(1, 4, 32, device=device)
+    nvfp4_layer_indices = [1, 2]  # layers with NVFP4 enabled in partial_nvfp4_config
+
+    torch.manual_seed(42)
+    model_dynamic = ToyModel(dims=dims).to(device)
+    mtq.quantize(model_dynamic, partial_nvfp4_config, lambda x: x(calib_input))
+
+    torch.manual_seed(42)
+    model_static = ToyModel(dims=dims).to(device)
+    mtq.quantize(model_static, partial_nvfp4_config, lambda x: x(calib_input))
+
+    # Convert NVFP4 layers to NVFP4StaticQuantizer with per-block and global amax
+    for idx in nvfp4_layer_indices:
+        layer = model_static.linears[idx]
+        weight = layer.weight.data
+        per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
+        tq = layer.weight_quantizer
+        if hasattr(tq, "_amax"):
+            delattr(tq, "_amax")
+        tq.register_buffer("_amax", per_block_amax.to(weight.device).clone().detach())
+        NVFP4StaticQuantizer.from_tensor_quantizer(tq, global_amax=per_block_amax.max())
+
+    # Clear amaxes on layer 1 to exercise lazy calibration during export
+    for linear, is_static in [(model_dynamic.linears[1], False), (model_static.linears[1], True)]:
+        wq = linear.weight_quantizer
+        if hasattr(wq, "_amax"):
+            delattr(wq, "_amax")
+        if is_static and hasattr(wq, "_global_amax"):
+            delattr(wq, "_global_amax")
+
+    quantizer_attrs = quantizer_attr_names("weight")
+    for idx in nvfp4_layer_indices:
+        _export_quantized_weight(model_dynamic.linears[idx], torch.float32, "weight")
+        _export_quantized_weight(model_static.linears[idx], torch.float32, "weight")
+
+    for idx in nvfp4_layer_indices:
+        dyn_linear = model_dynamic.linears[idx]
+        sta_linear = model_static.linears[idx]
+        assert torch.equal(dyn_linear.weight, sta_linear.weight), (
+            f"Layer {idx}: exported NVFP4 weight should match (dynamic vs static)"
+        )
+        assert torch.allclose(
+            getattr(dyn_linear, quantizer_attrs.weight_scale).float(),
+            getattr(sta_linear, quantizer_attrs.weight_scale).float(),
+        ), f"Layer {idx}: weight_scale should match"
+        assert torch.allclose(
+            getattr(dyn_linear, quantizer_attrs.weight_scale_2).float(),
+            getattr(sta_linear, quantizer_attrs.weight_scale_2).float(),
+        ), f"Layer {idx}: weight_scale_2 should match"

tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
     [
         ("fp8", "tiny_llama-fp8", True, False, True, True, False),
         ("nvfp4", "tiny_llama-nvfp4", True, False, True, True, False),
+        ("nvfp4_mse", "tiny_llama-nvfp4-mse", True, False, True, True, False),
         ("nvfp4_awq", "tiny_llama-nvfp4-awq", True, False, True, True, False),
         ("int4_awq", "tiny_llama-int4-awq", True, False, True, True, False),
         ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False),

tests/unit/torch/export/test_get_quantization.py

Lines changed: 32 additions & 2 deletions
@@ -15,11 +15,22 @@

 import pytest
 import torch
-from _test_utils.torch.export.utils import ToyModel, partial_fp8_config, partial_w4a8_config
+from _test_utils.torch.export.utils import (
+    ToyModel,
+    partial_fp8_config,
+    partial_nvfp4_config,
+    partial_w4a8_config,
+)

 import modelopt.torch.quantization as mtq
 from modelopt.torch.export.layer_utils import get_quantization_format
-from modelopt.torch.export.model_config import QUANTIZATION_FP8, QUANTIZATION_W4A8_AWQ
+from modelopt.torch.export.model_config import (
+    QUANTIZATION_FP8,
+    QUANTIZATION_NVFP4,
+    QUANTIZATION_W4A8_AWQ,
+)
+from modelopt.torch.export.quant_utils import get_quant_config
+from modelopt.torch.quantization.nn import NVFP4StaticQuantizer


 @pytest.mark.parametrize(
@@ -30,3 +41,22 @@ def test_get_quantization_format(config, expected):
     model = ToyModel()
     mtq.quantize(model, config, lambda x: x(torch.randn(1, 4, 10)))
     assert get_quantization_format(model) == expected
+
+
+def test_nvfp4_static_quantizer_export():
+    """NVFP4StaticQuantizer: get_quantization_format returns NVFP4 and get_quant_config returns the export config."""
+    model = ToyModel()
+    mtq.quantize(model, partial_nvfp4_config, lambda x: x(torch.randn(1, 4, 10)))
+
+    # Convert all weight quantizers to NVFP4StaticQuantizer
+    for module in model.modules():
+        tq = getattr(module, "weight_quantizer", None)
+        if tq is not None and hasattr(tq, "_amax") and not isinstance(tq, NVFP4StaticQuantizer):
+            global_amax = tq._amax.max() if tq._amax.dim() > 0 else tq._amax
+            NVFP4StaticQuantizer.from_tensor_quantizer(tq, global_amax=global_amax)
+
+    assert get_quantization_format(model) == QUANTIZATION_NVFP4
+
+    quant_config = get_quant_config(model)
+    assert quant_config["quantization"]["quant_algo"] == "NVFP4"
+    assert quant_config["quantization"]["group_size"] == 16
