Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"entropy"}:
if method == "entropy":
Comment thread
realAsma marked this conversation as resolved.
if cal.compute_amax(method) is not None:
module.load_calib_amax("entropy", **kwargs)
elif cal.compute_amax(**kwargs) is not None:
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/model_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, ".quant_summary.txt")
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ def forward(self, inputs):

return outputs

def _short_amax(self, fmt=".4f"):
def _short_amax(self, fmt=".2e"):
"""Short description of amax.

Returns:
Expand All @@ -1140,7 +1140,7 @@ def _short_amax(self, fmt=".4f"):
return "meta"
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
Expand Down
30 changes: 18 additions & 12 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@
__all__ = ["NVFP4QTensor"]


def _cast_per_block_scale_to_fp8(
per_block_scale: torch.Tensor,
per_block_scale_max: torch.Tensor | None = None,
) -> torch.Tensor:
"""Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN.

When ``per_block_scale_max`` is provided, first rescales as
``per_block_scale.float() * 448 / per_block_scale_max`` — the static-export
path needs this because the ``[==0]=1.0`` safety net combined with a small
``global_amax`` can drive the rescaled value above 448 (see PR #1397).
"""
if per_block_scale_max is not None:
per_block_scale = per_block_scale.float() * 448.0 / per_block_scale_max
return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)


class NVFP4QTensor(BaseQuantizedTensor):
"""Implements the INT4 quantization on tensors for more efficient storage or computation.

Expand Down Expand Up @@ -122,17 +138,8 @@ def get_weights_scaling_factor_from_quantizer(
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

# Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
# cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
# for an all-zero weight block) and global_amax is small, the pre-cast value
# explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
# value >= 480 casts to NaN — clamp first to keep the stored byte finite.
if not keep_high_precision:
per_block_scale = (
(per_block_scale * 448.0 / per_block_scale_max)
.clamp_(max=448.0)
.to(torch.float8_e4m3fn)
)
per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale, per_block_scale_max)
return per_block_scale, weights_scaling_factor_2
else:
# Dynamic path: compute from weight tensor
Expand Down Expand Up @@ -171,9 +178,8 @@ def get_weights_scaling_factor(
)
# Set all zero values in scale to 1.0
per_block_scale[per_block_scale == 0] = 1.0
# Convert to torch.float8_e4m3fn
if not keep_high_precision:
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale)
return per_block_scale, weights_scaling_factor_2

@classmethod
Expand Down
57 changes: 40 additions & 17 deletions tests/unit/torch/quantization/plugins/test_fused_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,27 +256,51 @@ def test_expert_index_recovery(self):
# Tests for export
# ---------------------------------------------------------------------------
class TestExportFusedExperts:
@staticmethod
def _cleanup_registry(mod_type):
if QuantModuleRegistry.get(mod_type) is not None:
QuantModuleRegistry.unregister(mod_type)

def test_export_creates_per_expert_submodules(self):
"""_export_fused_experts should create per-expert submodules with standard naming."""
import modelopt.torch.quantization as mtq
from modelopt.torch.export.moe_utils import _export_fused_experts

experts = _SyntheticFusedExperts()
expert_type = type(experts)
model = _TinyMoEModel()
expert_type = type(model.moe.experts)
self._cleanup_registry(expert_type)

# Manually register and convert
if QuantModuleRegistry.get(expert_type) is None:
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
_QuantFusedExperts
)
converted = QuantModuleRegistry.convert(experts)
quant_cfg = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*gate_up_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*down_proj_input_quantizer",
"cfg": {"num_bits": 8, "axis": None},
},
{
"quantizer_name": "*gate_up_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
{
"quantizer_name": "*down_proj_weight_quantizer",
"cfg": {"num_bits": 8, "axis": 0},
},
],
"algorithm": "max",
}

# Run a forward pass to calibrate (set amaxes)
seq_len = 16
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
with torch.no_grad():
converted(hidden_states, top_k_index, top_k_weights)
def forward_loop(m):
torch.manual_seed(0)
for _ in range(2):
x = torch.randn(1, 4, HIDDEN_DIM)
m(x)

mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
converted = model.moe.experts

_export_fused_experts(converted, torch.float16)

Expand All @@ -297,8 +321,7 @@ def test_export_creates_per_expert_submodules(self):
assert not hasattr(converted, "down_proj")
assert not hasattr(converted, "gate_up_proj_weight_quantizers")

if QuantModuleRegistry.get(expert_type) is not None:
QuantModuleRegistry.unregister(expert_type)
self._cleanup_registry(expert_type)

def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
"""gate_proj and up_proj must share weight_scale_2 even when an expert
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/torch/quantization/test_nvfp4_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""

from types import SimpleNamespace

import torch

from modelopt.torch.quantization.qtensor.nvfp4_tensor import (
NVFP4QTensor,
_cast_per_block_scale_to_fp8,
)

_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive FP8 E4M3FN subnormal
_FP8_E4M3FN_MAX = 448.0


class TestNVFP4ScaleClamping:
"""Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN."""

def test_no_zero_scales_for_tiny_weights(self):
"""Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
block_size = 16
tiny_weight = torch.full((4, block_size), 1e-10)
# wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
wsf2 = torch.tensor(1.0)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
per_block_scale_f32 = per_block_scale.float()

assert (per_block_scale_f32 > 0).all(), (
f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. "
Comment thread
Fridah-nv marked this conversation as resolved.
"FP8 scale underflow clamping likely regressed."
)
assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
"Per-block scales with zero values found after FP8 cast "
"(below the FP8 E4M3FN subnormal minimum — clamp would have prevented this)."
)

def test_normal_weights_unaffected_by_clamp(self):
"""Weights with typical magnitudes must not be affected by the underflow clamp."""
block_size = 16
torch.manual_seed(42)
normal_weight = torch.randn(8, block_size)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size)
assert (per_block_scale.float() > 0).all(), "Normal weights produced zero scales."

def test_mixed_weight_no_zeros(self):
"""Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales."""
block_size = 16
weight = torch.cat(
[
torch.randn(4, block_size),
torch.full((4, block_size), 1e-12),
],
dim=0,
)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size)
assert (per_block_scale.float() > 0).all(), (
"Zero scales in mixed-magnitude tensor after FP8 cast."
)

def test_helper_clamps_overflow_to_max(self):
"""Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
out = _cast_per_block_scale_to_fp8(oversized).float()
assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"

def test_helper_clamps_underflow_to_min(self):
"""Values below the FP8 subnormal must clamp up, not collapse to 0."""
tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
out = _cast_per_block_scale_to_fp8(tiny).float()
assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"

def test_static_path_no_nan_when_block_amax_zero(self):
"""Static path: zero-amax block + small global_amax must clamp to 448, not cast to NaN."""
block_size = 16
# global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
global_amax = torch.tensor(0.01)
# One block with amax=0 (triggers safety net), three normal blocks.
per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
weight = torch.randn(1, 4 * block_size)
q = SimpleNamespace(
global_amax=global_amax,
_amax=per_block_amax,
block_sizes={-1: block_size},
)

per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
per_block_scale_f32 = per_block_scale.float()
assert torch.isfinite(per_block_scale_f32).all(), (
f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
)
assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
)
Loading