Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,11 @@ def pre_quantize(
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# Strip leading padding tokens so the preview input shows real content
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
if model_type != "whisper" and tokenizer is not None and tokenizer.pad_token_id is not None:
first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
preview_input_ids = preview_input_ids[:, first_non_pad[0] :]
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated

# Generate preview before quantization
if args.skip_generate:
Expand Down Expand Up @@ -928,7 +933,7 @@ def input_decode(input_ids):
if processor is not None and isinstance(processor, WhisperProcessor):
return first_text_speech_dataset
elif tokenizer is not None:
return tokenizer.batch_decode(input_ids)
return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
else:
raise ValueError("The processor or tokenizer must be set")

Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"entropy"}:
if method == "entropy":
Comment thread
realAsma marked this conversation as resolved.
if cal.compute_amax(method) is not None:
module.load_calib_amax("entropy", **kwargs)
elif cal.compute_amax(**kwargs) is not None:
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/model_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, ".quant_summary.txt")
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ def forward(self, inputs):

return outputs

def _short_amax(self, fmt=".4f"):
def _short_amax(self, fmt=".2e"):
"""Short description of amax.

Returns:
Expand All @@ -1140,7 +1140,7 @@ def _short_amax(self, fmt=".4f"):
return "meta"
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
Expand Down
24 changes: 13 additions & 11 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
)
return weight_quantizer._amax.float() / (6.0 * 448.0)

@classmethod
def _cast_per_block_scale_to_fp8(cls, per_block_scale: torch.Tensor) -> torch.Tensor:
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
"""Clamp to FP8 E4M3FN representable range, then cast.

FP8 E4M3FN has no Inf and a smallest positive subnormal of ``2**-9`` (~0.00195).
Values below the min silently underflow to 0 (zero outputs at inference); values
above 448 cast to NaN.
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
"""
return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] The min-clamp at 2**-9 (smallest FP8 E4M3FN subnormal) is a reasonable lower bound, but it pushes legitimately-small block scales into the subnormal range, where FP8 E4M3FN has only 3 mantissa bits of resolution rather than the implicit-1 + 3-bit precision of normals. For a block whose true scale lies just below 2**-6 (the smallest normal), clamping is harmless — the cast already lands in subnormal territory. But for blocks whose true scale is much smaller, clamping to 2**-9 substantially over-states the per-block scale and inflates that block's quantization error. That's still strictly better than the underflow-to-0 it replaces (which silently zeros all 16 weights in the block at inference), so the change is right; just worth a one-line note in the docstring that "very small per-block scales are saturated up to the subnormal floor, which trades some block-level accuracy for the guarantee that no block silently outputs 0."

Also consider asserting input >= 0 in the helper — clamp(min=2**-9) would silently flip a negative value (which shouldn't ever exist for a per-block scale) into a positive subnormal and hide a real bug elsewhere.


@classmethod
def get_weights_scaling_factor_from_quantizer(
cls,
Expand Down Expand Up @@ -122,16 +132,9 @@ def get_weights_scaling_factor_from_quantizer(
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

# Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
# cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
# for an all-zero weight block) and global_amax is small, the pre-cast value
# explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
# value >= 480 casts to NaN — clamp first to keep the stored byte finite.
if not keep_high_precision:
per_block_scale = (
(per_block_scale * 448.0 / per_block_scale_max)
.clamp_(max=448.0)
.to(torch.float8_e4m3fn)
per_block_scale = cls._cast_per_block_scale_to_fp8(
per_block_scale * 448.0 / per_block_scale_max
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
)
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
return per_block_scale, weights_scaling_factor_2
else:
Expand Down Expand Up @@ -171,9 +174,8 @@ def get_weights_scaling_factor(
)
# Set all zero values in scale to 1.0
per_block_scale[per_block_scale == 0] = 1.0
# Convert to torch.float8_e4m3fn
if not keep_high_precision:
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
per_block_scale = cls._cast_per_block_scale_to_fp8(per_block_scale)
return per_block_scale, weights_scaling_factor_2

@classmethod
Expand Down
44 changes: 44 additions & 0 deletions modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

imports:
base_disable_all: configs/ptq/units/base_disable_all
default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
nvfp4: configs/numerics/nvfp4
nvfp4_static: configs/numerics/nvfp4_static

metadata:
Comment thread
Fridah-nv marked this conversation as resolved.
Outdated
recipe_type: ptq
description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for expert layers only (W4A4), no KV-cache quantization.
quantize:
algorithm:
method: mse
fp8_scale_sweep: true
layerwise: false
quant_cfg:
- $import: base_disable_all
- quantizer_name: '*mlp.experts*weight_quantizer'
cfg:
$import: nvfp4_static
- quantizer_name: '*mlp.experts*input_quantizer'
cfg:
$import: nvfp4
- quantizer_name: '*block_sparse_moe*weight_quantizer'
cfg:
$import: nvfp4_static
- quantizer_name: '*block_sparse_moe*input_quantizer'
cfg:
$import: nvfp4
- $import: default_disabled_quantizers
102 changes: 85 additions & 17 deletions tests/unit/torch/quantization/plugins/test_fused_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,27 +256,51 @@ def test_expert_index_recovery(self):
# Tests for export
# ---------------------------------------------------------------------------
class TestExportFusedExperts:
@staticmethod
def _cleanup_registry(mod_type):
if QuantModuleRegistry.get(mod_type) is not None:
QuantModuleRegistry.unregister(mod_type)

def test_export_creates_per_expert_submodules(self):
"""_export_fused_experts should create per-expert submodules with standard naming."""
import modelopt.torch.quantization as mtq
from modelopt.torch.export.moe_utils import _export_fused_experts

experts = _SyntheticFusedExperts()
expert_type = type(experts)
model = _TinyMoEModel()
expert_type = type(model.moe.experts)
self._cleanup_registry(expert_type)

# Manually register and convert
if QuantModuleRegistry.get(expert_type) is None:
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
_QuantFusedExperts
)
converted = QuantModuleRegistry.convert(experts)
quant_cfg = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*gate_up_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*down_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*gate_up_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
{
"quantizer_name": "*down_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
],
"algorithm": "max",
}

# Run a forward pass to calibrate (set amaxes)
seq_len = 16
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
with torch.no_grad():
converted(hidden_states, top_k_index, top_k_weights)
def forward_loop(m):
torch.manual_seed(0)
for _ in range(2):
x = torch.randn(1, 4, HIDDEN_DIM)
m(x)

mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
converted = model.moe.experts

_export_fused_experts(converted, torch.float16)

Expand All @@ -297,8 +321,7 @@ def test_export_creates_per_expert_submodules(self):
assert not hasattr(converted, "down_proj")
assert not hasattr(converted, "gate_up_proj_weight_quantizers")

if QuantModuleRegistry.get(expert_type) is not None:
QuantModuleRegistry.unregister(expert_type)
self._cleanup_registry(expert_type)

def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
"""gate_proj and up_proj must share weight_scale_2 even when an expert
Expand Down Expand Up @@ -899,3 +922,48 @@ def test_unrelated_dotted_number_unchanged(self):
_normalize_fused_experts_quantizer_name("moe.layers.3.gate.weight")
== "moe.layers.3.gate.weight"
)


# Verifies that MSE calibration discovers and calibrates every per-expert weight quantizer
# inside a fused-expert ModuleList (both gate_up_proj and down_proj, for all experts).
class TestFusedExpertsMSECalibration:
    """MSE calibration must populate ``amax`` on every per-expert weight quantizer."""

    @staticmethod
    def _cleanup_registry(mod_type):
        """Unregister ``mod_type`` from ``QuantModuleRegistry`` if it is registered."""
        if QuantModuleRegistry.get(mod_type) is not None:
            QuantModuleRegistry.unregister(mod_type)

    def test_mse_calibration_populates_all_expert_quantizers(self):
        """Quantize a tiny MoE model with ``algorithm="mse"`` and check every expert.

        Only the gate_up_proj and down_proj weight quantizers are enabled; after
        calibration each of the ``NUM_EXPERTS`` quantizers in both lists must expose a
        non-``None`` ``amax``.
        """
        import modelopt.torch.quantization as mtq

        model = _TinyMoEModel()
        expert_type = type(model.moe.experts)
        # Start from a clean registry so a stale registration cannot affect this test.
        self._cleanup_registry(expert_type)

        quant_cfg = {
            "quant_cfg": [
                {"quantizer_name": "*", "enable": False},
                {
                    "quantizer_name": "*gate_up_proj_weight_quantizer",
                    "cfg": {"num_bits": 8, "axis": None},
                },
                {
                    "quantizer_name": "*down_proj_weight_quantizer",
                    "cfg": {"num_bits": 8, "axis": None},
                },
            ],
            "algorithm": "mse",
        }

        def forward_loop(m):
            # Seeded so the calibration inputs are reproducible across runs.
            torch.manual_seed(0)
            for _ in range(2):
                m(torch.randn(1, 4, HIDDEN_DIM))

        try:
            mtq.quantize(model, quant_cfg, forward_loop=forward_loop)

            experts = model.moe.experts
            for idx in range(NUM_EXPERTS):
                assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
                    f"gate_up_proj_weight_quantizers[{idx}] not calibrated — Bug 1 regression"
                )
                assert experts.down_proj_weight_quantizers[idx].amax is not None, (
                    f"down_proj_weight_quantizers[{idx}] not calibrated"
                )
        finally:
            # Always undo the registration, even when quantize() or an assertion fails,
            # so later tests do not inherit it.
            self._cleanup_registry(expert_type)
111 changes: 111 additions & 0 deletions tests/unit/torch/quantization/test_nvfp4_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""

from types import SimpleNamespace

import torch

from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor

# Bounds the tests below expect every per-block scale to respect after the FP8 cast.
_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive FP8 E4M3FN subnormal
_FP8_E4M3FN_MAX = 448.0  # upper bound asserted by the overflow tests (FP8 E4M3FN max)


class TestNVFP4ScaleClamping:
    """Per-block weight scales outside the FP8 E4M3FN range must be clamped.

    A scale must never come back as 0 (underflow) or NaN (overflow) after the FP8 cast.
    """

    def test_no_zero_scales_for_tiny_weights(self):
        """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
        block_size = 16
        tiny_weight = torch.full((4, block_size), 1e-10)
        # wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
        wsf2 = torch.tensor(1.0)

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
        per_block_scale_f32 = per_block_scale.float()

        assert (per_block_scale_f32 > 0).all(), (
            f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. "
            "FP8 scale underflow clamping likely regressed."
        )
        assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
            "Per-block scales below FP8 minimum subnormal found after cast."
        )

    def test_normal_weights_unaffected_by_clamp(self):
        """Weights with typical magnitudes must yield positive, finite, in-range scales."""
        block_size = 16
        torch.manual_seed(42)
        normal_weight = torch.randn(8, block_size)

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size)
        per_block_scale_f32 = per_block_scale.float()
        assert (per_block_scale_f32 > 0).all(), "Normal weights produced zero scales."
        assert torch.isfinite(per_block_scale_f32).all(), (
            f"Normal weights produced non-finite scales: {per_block_scale_f32.tolist()}"
        )
        assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
            f"Normal weights produced scales above 448: {per_block_scale_f32.tolist()}"
        )

    def test_mixed_weight_no_zeros(self):
        """Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales."""
        block_size = 16
        # Seeded so the "normal" half of the tensor is reproducible.
        torch.manual_seed(0)
        weight = torch.cat(
            [
                torch.randn(4, block_size),
                torch.full((4, block_size), 1e-12),
            ],
            dim=0,
        )

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size)
        per_block_scale_f32 = per_block_scale.float()
        assert (per_block_scale_f32 > 0).all(), (
            "Zero scales in mixed-magnitude tensor after FP8 cast."
        )
        assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
            "Per-block scales below FP8 minimum subnormal in mixed-magnitude tensor."
        )

    def test_helper_clamps_overflow_to_max(self):
        """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
        oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
        out = NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float()
        assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
        assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"
        # Every input from index 1 onward is >= 448, so each must land exactly on the max.
        assert (out[1:] == _FP8_E4M3FN_MAX).all(), (
            f"Values >= 448 did not saturate to 448: {out.tolist()}"
        )

    def test_helper_clamps_underflow_to_min(self):
        """Values below the FP8 subnormal must clamp up, not collapse to 0."""
        tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
        out = NVFP4QTensor._cast_per_block_scale_to_fp8(tiny).float()
        assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"
        # All four inputs sit below the floor, so each must come back as exactly the floor.
        assert (out == _FP8_E4M3FN_MIN).all(), (
            f"Sub-floor values were not clamped up to 2**-9: {out.tolist()}"
        )

    def test_static_path_no_nan_when_block_amax_zero(self):
        """Static path must not emit NaN when a block's amax is 0.

        When a block's amax is 0 (all-zero weights), the `[==0]=1.0` safety net and a
        small global_amax push the pre-cast value above 448. Without the max clamp,
        fp8_e4m3fn would cast it to NaN — regression for the export-time NaN reported
        on this PR.
        """
        block_size = 16
        # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
        global_amax = torch.tensor(0.01)
        # One block with amax=0 (triggers safety net), three normal blocks.
        per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
        torch.manual_seed(0)
        weight = torch.randn(1, 4 * block_size)
        q = SimpleNamespace(
            global_amax=global_amax,
            _amax=per_block_amax,
            block_sizes={-1: block_size},
        )

        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
        per_block_scale_f32 = per_block_scale.float()
        assert torch.isfinite(per_block_scale_f32).all(), (
            f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
        )
        assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
            f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
        )
Loading