Skip to content

Commit b75755d

Browse files
committed
Merge remote-tracking branch 'origin/main' into shengliangx/local-jsonl-dataset
2 parents 9e0cf09 + 1d796f9 commit b75755d

7 files changed

Lines changed: 677 additions & 12 deletions

File tree

modelopt/torch/kernels/quantization/gemm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
# fp4_kernel works on any CUDA GPU with triton
3333
from .fp4_kernel import *
3434
from .fp8_kernel import *
35+
from .nvfp4_fp8_sweep import *
3536

3637
# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
3738
if torch.cuda.get_device_capability() >= (8, 9):
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Single source of truth for the NVFP4 FP8 scale-candidate set.
17+
18+
Pure PyTorch, no Triton dependency, so it can be imported from both the kernel
19+
wrapper (which is triton-gated) and the reference Python sweep in the
20+
:class:`NVFP4MSECalibrator` (which must work without triton too).
21+
"""
22+
23+
import torch
24+
25+
26+
def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
27+
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
28+
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
29+
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
30+
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
31+
return fp8_values[valid_mask] / 448.0
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.
17+
18+
Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
19+
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
20+
and emits the per-block ``best_amax`` directly.
21+
22+
The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
23+
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
24+
the per-block scale is the identity, so the kernel can use
25+
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
26+
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).
27+
28+
Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
29+
"""
30+
31+
import torch
32+
import triton
33+
import triton.language as tl
34+
35+
from ._fp8_scale_candidates import fp8_scale_candidates
36+
from .nvfp4_quant import fp4_round_magnitude
37+
38+
__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]
39+
40+
41+
# Autotune search space for the FP8-candidate sweep kernel, keyed on N_BLOCKS.
# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs (too few programs to occupy all of them).
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]
50+
51+
52+
@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    """Per-NVFP4-block argmin over FP8 scale candidates, one tile of blocks per program.

    Each program owns BLOCKS_PER_PROGRAM consecutive NVFP4 blocks, evaluates
    every candidate's quantization MSE on them, and writes the winning amax
    (``global_amax * candidate``) per block.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    # Mask off tail blocks when N_BLOCKS is not a multiple of BLOCKS_PER_PROGRAM.
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    # ``other=0.0`` makes masked lanes contribute zero loss for every candidate.
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running argmin state, one slot per block in the tile.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        # 6.0 is the NVFP4 (E2M1) maximum magnitude, mapping amax -> per-block scale.
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero
        # (global_amax = max|w|), so the loss is zero for every candidate either way.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale_safe
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        # Strict '<' keeps the lowest-index winner on ties, matching a sequential sweep.
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
103+
104+
105+
def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.

    Raises:
        ValueError: If ``x`` is not on CUDA, ``block_size`` is not a positive int,
            or ``x.numel()`` is not divisible by ``block_size``.
    """
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
    total = x.numel()
    if total % block_size != 0:
        raise ValueError(f"x.numel() ({total}) is not divisible by block_size ({block_size}).")

    device = x.device
    n_blocks = total // block_size
    flat_weights = x.contiguous().view(-1)
    cand = fp8_scale_candidates(device).to(dtype=torch.float32)
    # Detach + reshape(1): the kernel loads the amax as a 1-element fp32 buffer.
    amax_scalar = global_amax.detach().to(device=device, dtype=torch.float32).reshape(1)
    out = torch.empty(n_blocks, dtype=torch.float32, device=device)

    def grid(meta):
        # One program per tile of BLOCKS_PER_PROGRAM blocks (autotuned).
        return (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)

    # Pin the launch to x's device so the autotuner benchmarks on the right GPU.
    with torch.cuda.device(device):
        _fp8_scale_sweep_kernel[grid](
            flat_weights,
            cand,
            amax_scalar,
            out,
            n_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(cand.numel()),
        )
    return out

modelopt/torch/quantization/calib/mse.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Calibrator that returns the MSE amax of all collected tensors."""
1717

1818
import math
19+
import os
1920
from collections.abc import Callable
2021

2122
import torch
@@ -172,7 +173,15 @@ def compute_amax(self, verbose: bool = False):
172173

173174

174175
class NVFP4MSECalibrator(MseCalibrator):
175-
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
176+
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization.
177+
178+
Uses a fused Triton kernel as an internal fast path on the first ``collect`` call
179+
when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard
180+
blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are
181+
importable. Falls back to the reference 126-step Python sweep otherwise (custom
182+
``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the
183+
fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``).
184+
"""
176185

177186
def __init__(
178187
self,
@@ -185,16 +194,86 @@ def __init__(
185194
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
186195
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
187196
self._global_amax = global_amax
197+
# Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax.
198+
self._best_amax_fast: torch.Tensor | None = None
188199

189200
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
190201
if candidates.ndim != 0: # Called during final compute amax
191202
candidates = candidates.view_as(self._initial_amax)
192203
return torch.ones_like(self._initial_amax) * self._global_amax * candidates
193204

194205
    def _generate_candidates(self, device: torch.device) -> torch.Tensor:
        """Generate the 126 valid FP8 E4M3 scale candidates.

        Delegates to the shared helper so the reference sweep and the fused
        Triton kernel always agree on the candidate set.
        """
        # Deferred (function-scope) import: this module itself never imports the
        # kernel package at top level, presumably so it stays importable when
        # that package's dependency chain is unavailable — confirm with __init__.
        from modelopt.torch.kernels.quantization.gemm._fp8_scale_candidates import (
            fp8_scale_candidates,
        )

        return fp8_scale_candidates(device)
212+
213+
def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool:
214+
"""Whether the Triton fast path is usable for this ``collect`` input.
215+
216+
The kernel produces the final per-block amax in one shot, so it's only usable
217+
when the caller wants the standard squared-error sweep on a single CUDA tensor
218+
whose layout already matches the per-block amax.
219+
"""
220+
if self._error_func is not None:
221+
return False
222+
if not x.is_cuda:
223+
return False
224+
if os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") == "0":
225+
return False
226+
if self._initial_amax is None:
227+
return False
228+
if x.ndim != 2 or x.shape[0] != int(self._initial_amax.numel()):
229+
return False
230+
try:
231+
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep # noqa: F401
232+
except ImportError:
233+
return False
234+
return True
235+
236+
@torch.no_grad()
237+
def collect(self, x: torch.Tensor):
238+
"""Collect input statistics. Uses the Triton fast path when eligible."""
239+
if self._best_amax_fast is not None:
240+
raise RuntimeError(
241+
"NVFP4MSECalibrator: the Triton fast path produced a final amax on a "
242+
"previous collect() call; multi-collect after the fast path is not "
243+
"supported. Call reset() to start a fresh cycle, set "
244+
"MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force "
245+
"the reference path for activation-style accumulation."
246+
)
247+
# Fast path is eligible only on the first call, before the reference accumulator
248+
# has produced any state.
249+
if self._losses_sum is None and self._can_use_triton_fast_path(x):
250+
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep
251+
252+
best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
253+
# Match the original shape/dtype of the initial amax so downstream
254+
# load_calib_amax behaves identically to the reference path.
255+
self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to(
256+
self._initial_amax.dtype
257+
)
258+
return
259+
super().collect(x)
260+
261+
@torch.no_grad()
262+
def compute_amax(self, verbose: bool = False):
263+
"""Return the per-block amax — from the fast path if it ran, else from the reference sweep."""
264+
if self._best_amax_fast is not None:
265+
return self._best_amax_fast
266+
return super().compute_amax(verbose=verbose)
267+
268+
def reset(self):
269+
"""Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable.
270+
271+
``MseCalibrator.reset()`` intentionally drops ``_initial_amax`` to free memory in
272+
the multi-step search, but the NVFP4 per-block amax is shape ``[num_blocks]`` —
273+
small enough to keep so a follow-up ``collect()`` can run again on the same
274+
calibrator instance.
275+
"""
276+
self._best_amax_fast = None
277+
self._losses_sum = None
278+
self._candidates = None
279+
self._amax = None

modelopt/torch/quantization/model_calib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,9 @@ def mse_calibrate(
391391
continue
392392

393393
if fp8_scale_sweep and is_nvfp4_static:
394-
# Replace calibrator with NVFP4MSECalibrator
394+
# NVFP4MSECalibrator internally selects a fused Triton kernel for
395+
# the standard squared-error sweep; set MODELOPT_NVFP4_TRITON_SWEEP=0
396+
# to force the reference Python sweep for debugging.
395397
module._calibrator = NVFP4MSECalibrator(
396398
amax=initial_amax,
397399
axis=module._calibrator._axis,

0 commit comments

Comments
 (0)