
Commit a9c8ccf

[Quantization] Address review feedback round 3 on FP8 sweep
Three changes from realAsma's latest review:

- nvfp4_fp8_sweep kernel: use ``scale_safe`` rather than ``scale`` in the per-candidate diff so the divisor and multiplier match. Numerically equivalent on real inputs (the only case where ``scale_safe`` differs from ``scale`` is ``global_amax == 0``, in which case ``w_abs`` is also zero, so the loss is zero either way), but more consistent. A pure-PyTorch sketch of this loss appears below the change summary.
- Extract ``fp8_scale_candidates`` into a triton-free module ``_fp8_scale_candidates.py`` so the calibrator's reference sweep and the Triton kernel wrapper share one definition. This removes the duplicate copy in ``NVFP4MSECalibrator._generate_candidates``.
- Parity test: extend ``test_parity_random_weights`` to cover bf16 and fp16 in addition to fp32 by parametrizing on dtype, so the canonical parity grid (3 seeds × 3 num_blocks) is now exercised on every supported dtype. The smaller ``test_parity_dtypes`` was folded into this, since it was a strict subset.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 95b8a95 commit a9c8ccf

4 files changed: 48 additions & 43 deletions
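For the first bullet above, a minimal pure-PyTorch illustration of the per-candidate loss. This is a sketch, not the kernel or the test's _run_reference helper; round_to_fp4_magnitude below is a hypothetical, simplified stand-in for fp4_round_magnitude that ignores exact tie-breaking.

import torch

def round_to_fp4_magnitude(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: snap each value to the nearest FP4 (E2M1) magnitude.
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device)
    return grid[(x.unsqueeze(-1) - grid).abs().argmin(dim=-1)]

def candidate_loss(w_abs: torch.Tensor, c: float, global_amax: float) -> torch.Tensor:
    """Per-block squared error for one FP8 scale candidate ``c``."""
    scale = c * global_amax / 6.0
    # global_amax == 0 implies w_abs == 0, so the loss is 0 for every candidate;
    # using scale_safe as both divisor and multiplier keeps the two uses consistent.
    scale_safe = 1.0 if scale == 0.0 else scale
    q_mag = round_to_fp4_magnitude(w_abs / scale_safe)
    diff = w_abs - q_mag * scale_safe
    return (diff * diff).sum(dim=-1)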

modelopt/torch/kernels/quantization/gemm/_fp8_scale_candidates.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Single source of truth for the NVFP4 FP8 scale-candidate set.
+
+ Pure PyTorch, no Triton dependency, so it can be imported from both the kernel
+ wrapper (which is triton-gated) and the reference Python sweep in the
+ :class:`NVFP4MSECalibrator` (which must work without triton too).
+ """
+
+ import torch
+
+
+ def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
+     """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
+     uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
+     fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
+     valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
+     return fp8_values[valid_mask] / 448.0
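A quick sanity-check sketch of the new helper (my example, not part of the commit): of the 128 uint8 bit patterns 0x00 through 0x7F, 0x00 decodes to +0 and 0x7F to NaN in E4M3FN, which is why exactly 126 positive finite candidates remain.

import torch

from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import fp8_scale_candidates

cands = fp8_scale_candidates()          # defaults to CPU
assert cands.numel() == 126             # 0x00 (+0) and 0x7F (NaN) are excluded
assert (cands > 0).all() and torch.isfinite(cands).all()
assert cands.max().item() == 1.0        # 448 / 448: the largest finite E4M3 value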

modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py

Lines changed: 4 additions & 11 deletions
@@ -32,19 +32,12 @@
  import triton
  import triton.language as tl

+ from ._fp8_scale_candidates import fp8_scale_candidates
  from .nvfp4_quant import fp4_round_magnitude

  __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


- def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
-     """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
-     uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
-     fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
-     valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
-     return fp8_values[valid_mask] / 448.0
-
-
  # Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
  # BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
  # The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64

@@ -93,11 +86,11 @@ def _fp8_scale_sweep_kernel(
      for k in tl.static_range(NUM_CANDIDATES):
          c = tl.load(candidates_ptr + k).to(tl.float32)
          scale = c * global_amax / 6.0
-         # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
-         # the same for every candidate, so any best_idx is fine.
+         # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero
+         # (global_amax = max|w|), so the loss is zero for every candidate either way.
          scale_safe = tl.where(scale == 0.0, 1.0, scale)
          q_mag = fp4_round_magnitude(w_abs / scale_safe)
-         diff = w_abs - q_mag * scale
+         diff = w_abs - q_mag * scale_safe
          loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
          is_better = loss < best_loss
          best_loss = tl.where(is_better, loss, best_loss)

modelopt/torch/quantization/calib/mse.py

Lines changed: 5 additions & 10 deletions
@@ -203,17 +203,12 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
          return torch.ones_like(self._initial_amax) * self._global_amax * candidates

      def _generate_candidates(self, device: torch.device) -> torch.Tensor:
-         """Generate 126 valid FP8 E4M3 scale candidates.
+         """Generate the 126 valid FP8 E4M3 scale candidates."""
+         from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import (
+             fp8_scale_candidates,
+         )

-         Kept in sync with ``fp8_scale_candidates`` in
-         ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3
-         spec is fixed, and the parity test exercises both paths against each other.
-         """
-         uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
-         fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
-         valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
-         fp8_values = fp8_values[valid_mask]
-         return fp8_values / 448.0
+         return fp8_scale_candidates(device)

      def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool:
          """Whether the Triton fast path is usable for this ``collect`` input.

tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

Lines changed: 8 additions & 22 deletions
@@ -85,14 +85,17 @@ def _run_triton(x, per_block_amax, global_amax):
  @requires_triton
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
  @pytest.mark.parametrize("seed", [0, 1, 2])
  @pytest.mark.parametrize("num_blocks", [4, 64, 1024])
- def test_parity_random_weights(seed, num_blocks):
-     """Triton sweep must produce the exact same per-block amax as the reference."""
+ def test_parity_random_weights(seed, num_blocks, dtype):
+     """Triton sweep must produce the exact same per-block amax as the reference,
+     across every dtype supported by the NVFP4 quantizer (fp32, fp16, bf16)."""
      torch.manual_seed(seed)
      device = "cuda"
-     x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=torch.float32)
-     per_block_amax = x.abs().amax(dim=-1)
+     x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype)
+     # Promote to fp32 for the per-block amax (matches what max_calibrate produces).
+     per_block_amax = x.float().abs().amax(dim=-1)
      global_amax = per_block_amax.max()

      ref = _run_reference(x, per_block_amax, global_amax)

@@ -102,29 +105,12 @@ def test_parity_random_weights(seed, num_blocks):
      # Both pick from the same 126-element discrete candidate set, so any disagreement
      # would show up as a non-zero diff (not a small float epsilon). Demand exact match.
      assert torch.equal(ref, tri), (
-         f"Triton sweep diverged from reference: max |diff| = "
+         f"Triton sweep diverged from reference (dtype={dtype}): max |diff| = "
          f"{(ref - tri).abs().max().item():.3e}, "
          f"differing blocks = {(ref != tri).sum().item()} / {num_blocks}"
      )


- @requires_triton
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
- def test_parity_dtypes(dtype):
-     """Sweep must agree across the dtypes supported by the NVFP4 quantizer."""
-     torch.manual_seed(42)
-     device = "cuda"
-     num_blocks = 256
-     x = torch.randn(num_blocks, BLOCK_SIZE, device=device, dtype=dtype)
-     # Promote to fp32 for the per-block amax (matches what max_calibrate produces).
-     per_block_amax = x.float().abs().amax(dim=-1)
-     global_amax = per_block_amax.max()
-
-     ref = _run_reference(x, per_block_amax, global_amax)
-     tri = _run_triton(x, per_block_amax, global_amax)
-     assert torch.equal(ref, tri)
-
-
  @requires_triton
  def test_quantized_output_matches():
      """Round-tripping x through the chosen amax should give the same fake-quant result."""
