Skip to content

Commit 928ab1c

Browse files
cael-lingzhongbozhu
andcommitted
Add NVFP4 per-token GEMM, fused grouped amax, cast, tests and benches
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast pair and ships pytest correctness + sweep benches against the per-tensor RHT+SR production baseline. * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused grouped kernel, reusing the single-tensor 4-stage TMA pipeline. * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale kernel for the cuBLASLt NVFP4 dequantize step (may be updated due to 2D quantization of W). * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++ grouped bulk binding and per-token GEMM entry; thin pybind layer. * pytorch/custom_recipes/{gemm_nvfp4_per_token, quantization_nvfp4_per_token_group}.py: Python wrappers. * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal cast tests + bf16-close GEMM tests. * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA Graphs columns, ratio against per-tensor RHT+SR baseline. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
1 parent 80ea313 commit 928ab1c

17 files changed

Lines changed: 5157 additions & 2 deletions
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline.
6+
7+
Quant-only (no GEMM). Both sides time the K1 (amax) + K2 (cast) composite on
8+
activation A, rowwise+columnwise. Requires bf16 input, M % 128 == 0, K % 128 == 0.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import argparse
14+
import math
15+
import statistics
16+
import sys
17+
from dataclasses import dataclass
18+
from typing import Callable, List, Tuple
19+
20+
import torch
21+
22+
# Import transformer_engine first so libtransformer_engine.so is dlopen'd
23+
# before transformer_engine_torch tries to resolve its typeinfo symbols.
24+
import transformer_engine.pytorch as te # noqa: F401
25+
import transformer_engine_torch as tex
26+
from transformer_engine.pytorch import NVFP4Quantizer
27+
28+
29+
def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> float:
    """Return the median per-call time of ``fn`` in milliseconds.

    ``fn`` is first invoked ``warmup`` times untimed, then ``iters`` times with a
    pair of CUDA timing events recorded around each call. The reported value is
    the median of the per-call ``Event.elapsed_time`` readings.

    Args:
        fn: Zero-argument callable to time.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls.

    Returns:
        Median elapsed time between the start/end events, in ms.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    event_pairs = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in range(iters)
    ]
    for begin, finish in event_pairs:
        begin.record()
        fn()
        finish.record()

    # Wait for every recorded event to complete before reading the timings.
    torch.cuda.synchronize()
    return statistics.median(begin.elapsed_time(finish) for begin, finish in event_pairs)
44+
45+
46+
def cuda_graph_time_ms(
    fn: Callable[[], object], *, warmup: int = 5, iters: int = 50
) -> float:
    """Return the median ``g.replay()`` time of ``fn`` captured into a CUDA Graph.

    ``fn`` is warmed up ``warmup`` times on a side stream, captured once into a
    ``torch.cuda.CUDAGraph``, and the graph replay is then timed ``iters`` times
    with ``cuda_time_ms`` (kernel-only floor).

    Args:
        fn: Zero-argument callable to capture and replay.
        warmup: Number of eager calls made on the side stream before capture.
        iters: Number of timed graph replays.

    Returns:
        Median replay time in ms, or ``nan`` if warmup or capture raises.
    """
    try:
        # Warm up on a side stream, ordered after the current stream's work.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(warmup):
                fn()
        torch.cuda.current_stream().wait_stream(side)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn()
    except Exception as e:  # Deliberate best-effort: report and fall back to nan.
        print(f" [graph capture skipped: {type(e).__name__}: {e}]", file=sys.stderr)
        return float("nan")

    # Reuse the shared event-timing loop instead of duplicating it here.
    # warmup=0 keeps the measurement identical to timing the replays directly.
    return cuda_time_ms(graph.replay, warmup=0, iters=iters)
78+
79+
80+
def _make_baseline_quantizer() -> NVFP4Quantizer:
    """Build the per-tensor baseline quantizer (RHT + SR + random sign mask).

    The quantizer produces both rowwise and columnwise output, with amax
    reduction and 2D quantization switched off.
    """
    options = {
        "fp4_dtype": tex.DType.kFloat4E2M1,
        # Emit both layouts.
        "rowwise": True,
        "columnwise": True,
        # No amax reduction, so no reduction group either.
        "with_amax_reduction": False,
        "amax_reduction_group": None,
        # RHT with post-RHT amax; no 2D quantization.
        "with_rht": True,
        "with_post_rht_amax": True,
        "with_2d_quantization": False,
        # SR plus random sign mask.
        "stochastic_rounding": True,
        "with_random_sign_mask": True,
    }
    return NVFP4Quantizer(**options)
94+
95+
96+
def _has_sm100() -> bool:
    """Report whether CUDA is available and the device capability major is >= 10."""
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 10
    return False
101+
102+
103+
@dataclass
class ShapeBench:
    """Timing record for one benchmarked (M, K) shape; all times are in ms."""

    # Benchmarked shape.
    M: int
    K: int
    # Full K1+K2 quant, eager pybind call.
    t_pt: float  # per-token
    t_pten: float  # per-tensor
    # Same composites under CUDA Graphs replay.
    t_pt_g: float  # per-token
    t_pten_g: float  # per-tensor
111+
112+
113+
def _bench_shape(M: int, K: int, *, device: torch.device) -> ShapeBench:
    """Time per-tensor vs per-token K1+K2 quant at one (M, K) shape.

    Each path is timed eagerly with ``cuda_time_ms`` and under CUDA Graph
    replay with ``cuda_graph_time_ms``, using a random bf16 (M, K) input.

    Args:
        M: Number of rows of the activation; must be a positive multiple of 128.
        K: Number of columns of the activation; must be a positive multiple of 128.
        device: Device on which the input and output buffers are allocated.

    Returns:
        A ``ShapeBench`` holding the four median timings in ms.

    Raises:
        ValueError: If ``M`` or ``K`` is not a positive multiple of 128.
    """
    # The module docstring states M % 128 == 0 and K % 128 == 0 as requirements.
    # Fail early with a clear message: the buffer shapes below use floor
    # division and would otherwise be silently truncated for other sizes.
    if M <= 0 or K <= 0 or M % 128 != 0 or K % 128 != 0:
        raise ValueError(
            f"NVFP4 quant bench requires positive M and K divisible by 128, got M={M}, K={K}"
        )

    a = torch.randn((M, K), dtype=torch.bfloat16, device=device)

    # Per-tensor quantizer + A output tensor.
    quantizer = _make_baseline_quantizer()
    dst_a = quantizer.make_empty(a.shape, dtype=torch.bfloat16, device=device)

    # Per-token A-side buffers: BLOCK_K=16 (1x16 e4m3 inner SF).
    BLOCK_K = 16
    # One fp32 value per row / per column (presumably the per-token amax
    # outputs -- confirm against the binding).
    ra_a = torch.empty((M,), dtype=torch.float32, device=device)
    ca_a = torch.empty((K,), dtype=torch.float32, device=device)
    # Packed outputs: two 4-bit values per uint8, one scale byte per BLOCK_K elements.
    q_row_a = torch.empty((M, K // 2), dtype=torch.uint8, device=device)
    s_dec_row_a = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device)
    q_col_a = torch.empty((K, M // 2), dtype=torch.uint8, device=device)
    s_dec_col_a = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device)

    def _baseline_quant_fn():
        tex.quantize(a, quantizer, dst_a, None)

    def _pt_full_quant_fn():
        # Trailing booleans presumably enable rowwise and columnwise output
        # -- confirm against the pybind signature.
        tex.nvfp4_per_token_quantize(
            a, q_row_a, s_dec_row_a, ra_a, q_col_a, s_dec_col_a, ca_a, True, True,
        )

    t_pten = cuda_time_ms(_baseline_quant_fn)
    t_pt = cuda_time_ms(_pt_full_quant_fn)
    t_pten_g = cuda_graph_time_ms(_baseline_quant_fn)
    t_pt_g = cuda_graph_time_ms(_pt_full_quant_fn)

    return ShapeBench(M=M, K=K, t_pt=t_pt, t_pten=t_pten, t_pt_g=t_pt_g, t_pten_g=t_pten_g)
144+
145+
146+
# Default 6x3 sweep matching bench_nvfp4_per_token_group.py:
# M in {1024..32768} (powers of two), K in {2048, 4096, 8192}.
_M_VALUES: Tuple[int, ...] = tuple(1 << exp for exp in range(10, 16))
_K_VALUES: Tuple[int, ...] = tuple(1 << exp for exp in range(11, 14))
# M-major order: every K is visited for one M before moving to the next M.
_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple(
    (rows, cols) for rows in _M_VALUES for cols in _K_VALUES
)
152+
153+
154+
def _parse_shape(s: str) -> Tuple[int, int]:
    """Parse an ``MxK`` command-line shape string into an ``(M, K)`` tuple.

    The separator may be a lower- or upper-case ``x``.

    Args:
        s: Shape string such as ``"4096x4096"``.

    Returns:
        The two dimensions as positive integers.

    Raises:
        argparse.ArgumentTypeError: If ``s`` does not consist of exactly two
            ``x``-separated positive integers.
    """
    parts = s.lower().split("x")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(f"Shape must be MxK, got '{s}'")
    try:
        m, k = (int(p) for p in parts)
    except ValueError:
        # Surface a shape-specific message instead of int()'s generic error.
        raise argparse.ArgumentTypeError(
            f"Shape must be MxK with integer M and K, got '{s}'"
        ) from None
    if m <= 0 or k <= 0:
        raise argparse.ArgumentTypeError(f"Shape dimensions must be positive, got '{s}'")
    return m, k
159+
160+
161+
def _ratio(num: float, den: float) -> float:
    """Return ``num / den``, or nan when either is nan or ``den`` is not positive."""
    # ``den > 0`` is False for nan as well as for zero and negative values.
    usable = den > 0 and not math.isnan(num)
    return num / den if usable else float("nan")
165+
166+
167+
def main() -> int:
    """Run the quant benchmark sweep and print a comparison table.

    Parses ``--shapes`` (defaulting to ``_DEFAULT_SHAPES``), benchmarks each
    shape with ``_bench_shape`` and prints one table row per shape. Rows are
    printed and flushed as soon as each shape finishes, so results already
    measured remain visible if a later shape fails.

    Returns:
        0 on success, 1 if the device check in ``_has_sm100`` fails.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4."
    )
    parser.add_argument(
        "--shapes", type=_parse_shape, nargs="+", default=None,
        help="Shapes to bench, in MxK form (e.g. 4096x4096). "
        "Default: an internally-chosen production-shape sweep.",
    )
    args = parser.parse_args()

    if not _has_sm100():
        print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr)
        return 1

    device = torch.device("cuda")
    shapes: List[Tuple[int, int]] = list(args.shapes) if args.shapes else list(_DEFAULT_SHAPES)

    header = (
        f"{'M':>7} {'K':>6}"
        " |"
        f"{'per-token':>10} {'per-tensor':>11} {'ratio':>8}"
        " |"
        f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>18} {'ratio(Graph)':>13}"
    )
    print(header)
    print("-" * len(header))

    last_m = None
    for M, K in shapes:
        rec = _bench_shape(M, K, device=device)

        # Blank line between groups of rows that share the same M.
        if last_m is not None and rec.M != last_m:
            print()
        last_m = rec.M

        # Ratios are per-token time over per-tensor time.
        ratio = _ratio(rec.t_pt, rec.t_pten)
        ratio_g = _ratio(rec.t_pt_g, rec.t_pten_g)
        ratio_s = "nan" if math.isnan(ratio) else f"{ratio:.2f}x"
        ratio_g_s = "nan" if math.isnan(ratio_g) else f"{ratio_g:.2f}x"
        print(
            f"{rec.M:>7} {rec.K:>6}"
            " |"
            f"{rec.t_pt:>10.4f} {rec.t_pten:>11.4f} {ratio_s:>8}"
            " |"
            f"{rec.t_pt_g:>17.4f} {rec.t_pten_g:>18.4f} {ratio_g_s:>13}",
            flush=True,
        )

    return 0
214+
215+
216+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

0 commit comments

Comments
 (0)