|
| 1 | +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# See LICENSE for license information. |
| 4 | + |
| 5 | +"""Bench NVFP4 per-token K1+K2 quant vs per-tensor RHT+SR baseline. |
| 6 | +
|
| 7 | +Quant-only (no GEMM). Both sides time the K1 (amax) + K2 (cast) composite on |
| 8 | +activation A, rowwise+columnwise. Requires bf16 input, M % 128 == 0, K % 128 == 0. |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import argparse |
| 14 | +import math |
| 15 | +import statistics |
| 16 | +import sys |
| 17 | +from dataclasses import dataclass |
| 18 | +from typing import Callable, List, Tuple |
| 19 | + |
| 20 | +import torch |
| 21 | + |
| 22 | +# Import transformer_engine first so libtransformer_engine.so is dlopen'd |
| 23 | +# before transformer_engine_torch tries to resolve its typeinfo symbols. |
| 24 | +import transformer_engine.pytorch as te # noqa: F401 |
| 25 | +import transformer_engine_torch as tex |
| 26 | +from transformer_engine.pytorch import NVFP4Quantizer |
| 27 | + |
| 28 | + |
def cuda_time_ms(fn: Callable[[], None], *, warmup: int = 5, iters: int = 50) -> float:
    """Return the median per-call time of ``fn`` in milliseconds.

    ``fn`` is first invoked ``warmup`` times without timing.  Each of the
    ``iters`` timed invocations is then bracketed by its own pair of CUDA
    events, and the median of the per-pair ``elapsed_time`` readings is
    returned.

    Args:
        fn: Zero-argument callable to time.
        warmup: Number of untimed invocations run before measuring.
        iters: Number of timed invocations.

    Returns:
        Median of the ``iters`` event-to-event readings, in ms.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    # One (begin, finish) event pair per timed call; the readings are taken
    # only after the final synchronize below.
    event_pairs = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in range(iters)
    ]
    for begin, finish in event_pairs:
        begin.record()
        fn()
        finish.record()
    torch.cuda.synchronize()

    return statistics.median([begin.elapsed_time(finish) for begin, finish in event_pairs])
| 44 | + |
| 45 | + |
def cuda_graph_time_ms(
    fn: Callable[[], object], *, warmup: int = 5, iters: int = 50
) -> float:
    """Return the median ``replay()`` time (ms) of ``fn`` captured as a CUDA Graph.

    ``fn`` is invoked ``warmup`` times on a side stream, captured once into a
    ``torch.cuda.CUDAGraph``, and the graph is then replayed ``iters`` times
    with every replay bracketed by a pair of CUDA events.

    Args:
        fn: Zero-argument callable to capture and replay.
        warmup: Number of side-stream invocations before capture.
        iters: Number of timed replays.

    Returns:
        Median per-replay reading in ms, or ``nan`` when warmup or capture
        raises (a note is printed to stderr in that case).
    """
    try:
        warm_stream = torch.cuda.Stream()
        warm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warm_stream):
            for _ in range(warmup):
                fn()
        torch.cuda.current_stream().wait_stream(warm_stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn()
    except Exception as err:
        # Best effort: report why the graph path is unavailable and let the
        # caller carry on with a nan placeholder.
        print(f" [graph capture skipped: {type(err).__name__}: {err}]", file=sys.stderr)
        return float("nan")

    event_pairs = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in range(iters)
    ]
    for begin, finish in event_pairs:
        begin.record()
        graph.replay()
        finish.record()
    torch.cuda.synchronize()

    return statistics.median([begin.elapsed_time(finish) for begin, finish in event_pairs])
| 78 | + |
| 79 | + |
def _make_baseline_quantizer() -> NVFP4Quantizer:
    """Build the ``NVFP4Quantizer`` used as the per-tensor baseline.

    The configuration enables rowwise and columnwise usage, RHT with post-RHT
    amax, stochastic rounding and the random sign mask; amax reduction and 2D
    quantization are switched off.
    """
    config = dict(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_2d_quantization=False,
        stochastic_rounding=True,
        with_random_sign_mask=True,
    )
    return NVFP4Quantizer(**config)
| 94 | + |
| 95 | + |
def _has_sm100() -> bool:
    """Return True when CUDA is available and the device's major capability is >= 10."""
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 10
    return False
| 101 | + |
| 102 | + |
@dataclass
class ShapeBench:
    """Timing record for one (M, K) shape; every ``t_*`` field is in ms.

    ``pt`` denotes the per-token path and ``pten`` the per-tensor path.  The
    ``_g`` suffix marks readings taken under CUDA Graph replay; the unsuffixed
    fields come from the eager pybind calls.
    """

    # Shape of the benchmarked activation.
    M: int
    K: int
    # Eager (pybind) full K1+K2 timings.
    t_pt: float
    t_pten: float
    # CUDA Graph replay timings.
    t_pt_g: float
    t_pten_g: float
| 111 | + |
| 112 | + |
def _bench_shape(M: int, K: int, *, device: torch.device) -> ShapeBench:
    """Benchmark the per-tensor and per-token quant paths on one (M, K) input.

    A random bf16 ``(M, K)`` activation is quantized by both paths.  Each path
    is timed eagerly with ``cuda_time_ms`` and then under CUDA Graph replay
    with ``cuda_graph_time_ms``; the four readings are returned together.

    Args:
        M: Row count of the activation.
        K: Column count of the activation.
        device: Device on which the input and all output buffers are allocated.

    Returns:
        A ``ShapeBench`` carrying the four timings in ms.
    """
    act = torch.randn((M, K), dtype=torch.bfloat16, device=device)

    # Per-tensor side: baseline quantizer and its preallocated output.
    baseline_q = _make_baseline_quantizer()
    baseline_out = baseline_q.make_empty(act.shape, dtype=torch.bfloat16, device=device)

    # Per-token side: output buffers, with one scale byte per BLOCK_K=16
    # elements (1x16 e4m3 inner SF) and two FP4 values packed per data byte.
    BLOCK_K = 16
    row_amax = torch.empty((M,), dtype=torch.float32, device=device)
    col_amax = torch.empty((K,), dtype=torch.float32, device=device)
    row_data = torch.empty((M, K // 2), dtype=torch.uint8, device=device)
    row_scale = torch.empty((M, K // BLOCK_K), dtype=torch.uint8, device=device)
    col_data = torch.empty((K, M // 2), dtype=torch.uint8, device=device)
    col_scale = torch.empty((K, M // BLOCK_K), dtype=torch.uint8, device=device)

    def _run_per_tensor():
        tex.quantize(act, baseline_q, baseline_out, None)

    def _run_per_token():
        # NOTE(review): the two trailing booleans presumably enable the
        # rowwise and columnwise outputs -- confirm against the binding.
        tex.nvfp4_per_token_quantize(
            act,
            row_data,
            row_scale,
            row_amax,
            col_data,
            col_scale,
            col_amax,
            True,
            True,
        )

    # Measurement order is kept fixed: eager baseline, eager per-token,
    # graph baseline, graph per-token.
    eager_per_tensor = cuda_time_ms(_run_per_tensor)
    eager_per_token = cuda_time_ms(_run_per_token)
    graph_per_tensor = cuda_graph_time_ms(_run_per_tensor)
    graph_per_token = cuda_graph_time_ms(_run_per_token)

    return ShapeBench(
        M=M,
        K=K,
        t_pt=eager_per_token,
        t_pten=eager_per_tensor,
        t_pt_g=graph_per_token,
        t_pten_g=graph_per_tensor,
    )
| 144 | + |
| 145 | + |
# Default sweep: every (M, K) pair with M a power of two from 1024 (2**10) to
# 32768 (2**15) and K in {2048, 4096, 8192} -- 6 x 3 = 18 shapes, matching
# bench_nvfp4_per_token_group.py.
_M_VALUES: Tuple[int, ...] = tuple(1 << exp for exp in range(10, 16))
_K_VALUES: Tuple[int, ...] = tuple(1 << exp for exp in range(11, 14))
_DEFAULT_SHAPES: Tuple[Tuple[int, int], ...] = tuple(
    [(rows, cols) for rows in _M_VALUES for cols in _K_VALUES]
)
| 152 | + |
| 153 | + |
def _parse_shape(s: str) -> Tuple[int, int]:
    """Parse an ``MxK`` command-line shape into an ``(M, K)`` tuple of ints.

    The separator may be ``x`` or ``X``.  Both dimensions must be positive
    multiples of 128, the alignment this benchmark's module docstring states
    as a requirement; rejecting other values here gives a clear CLI error
    instead of mis-sized output buffers or a kernel failure later on.

    Args:
        s: Raw argument text, e.g. ``"4096x4096"``.

    Returns:
        ``(M, K)`` as a tuple of two ints.

    Raises:
        argparse.ArgumentTypeError: If ``s`` is not two ``x``-separated
            integers, or either value is not a positive multiple of 128.
    """
    alignment = 128

    parts = s.lower().split("x")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(f"Shape must be MxK, got '{s}'")

    try:
        m, k = (int(part) for part in parts)
    except ValueError:
        # Surface a shape-specific message rather than int()'s generic one.
        raise argparse.ArgumentTypeError(
            f"Shape must be MxK with integer M and K, got '{s}'"
        ) from None

    if m <= 0 or k <= 0 or m % alignment != 0 or k % alignment != 0:
        raise argparse.ArgumentTypeError(
            f"M and K must be positive multiples of {alignment}, got '{s}'"
        )
    return m, k
| 159 | + |
| 160 | + |
def _ratio(num: float, den: float) -> float:
    """Return ``num / den``, or ``nan`` when the quotient is not meaningful.

    The result is ``nan`` when ``den`` is zero or negative, or when either
    argument is itself ``nan``.
    """
    unusable = den <= 0 or math.isnan(num) or math.isnan(den)
    return float("nan") if unusable else num / den
| 165 | + |
| 166 | + |
def main() -> int:
    """Run the benchmark sweep and print a results table.

    Shapes come from ``--shapes`` (``MxK`` tokens) or, when the flag is
    absent, from ``_DEFAULT_SHAPES``.  All shapes are benchmarked first and
    the table is printed afterwards, with a blank line between groups of
    rows that share the same M.

    Returns:
        0 after printing the table, or 1 when the device check in
        ``_has_sm100`` fails (a SKIP note is written to stderr).
    """
    cli = argparse.ArgumentParser(
        description="Benchmark NVFP4 per-token K1+K2 quant vs per-tensor production NVFP4."
    )
    cli.add_argument(
        "--shapes",
        type=_parse_shape,
        nargs="+",
        default=None,
        help=(
            "Shapes to bench, in MxK form (e.g. 4096x4096). "
            "Default: an internally-chosen production-shape sweep."
        ),
    )
    opts = cli.parse_args()

    if not _has_sm100():
        print("SKIP: NVFP4 per-token quant requires SM100 (Blackwell).", file=sys.stderr)
        return 1

    device = torch.device("cuda")
    shapes = list(opts.shapes or _DEFAULT_SHAPES)

    records: List[ShapeBench] = []
    for rows, cols in shapes:
        records.append(_bench_shape(rows, cols, device=device))

    def _ratio_text(value: float) -> str:
        # nan is shown bare; real ratios get two decimals and an "x" suffix.
        return "nan" if math.isnan(value) else f"{value:.2f}x"

    # Three column groups (shape, eager, graph) joined by the same divider.
    header = " |".join(
        [
            f"{'M':>7} {'K':>6}",
            f"{'per-token':>10} {'per-tensor':>11} {'ratio':>8}",
            f"{'per-token(Graph)':>17} {'per-tensor(Graph)':>18} {'ratio(Graph)':>13}",
        ]
    )
    print(header)
    print("-" * len(header))

    last_m = None
    for rec in records:
        # Blank separator line whenever M changes between consecutive rows.
        if last_m is not None and rec.M != last_m:
            print()
        last_m = rec.M

        eager_ratio = _ratio_text(_ratio(rec.t_pt, rec.t_pten))
        graph_ratio = _ratio_text(_ratio(rec.t_pt_g, rec.t_pten_g))
        row = " |".join(
            [
                f"{rec.M:>7} {rec.K:>6}",
                f"{rec.t_pt:>10.4f} {rec.t_pten:>11.4f} {eager_ratio:>8}",
                f"{rec.t_pt_g:>17.4f} {rec.t_pten_g:>18.4f} {graph_ratio:>13}",
            ]
        )
        print(row)

    return 0
| 214 | + |
| 215 | + |
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    raise SystemExit(main())
0 commit comments