Skip to content

Commit 720ec27

Browse files
[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly (#3011)
* [PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a `kEnableSwizzleSFOutput` switch for that, but the framework never set the matching `with_gemm_swizzled_scales` flag on NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through. Changes: * Single + grouped Hadamard cast-fusion kernels: drive `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`. * NVFP4Quantizer create_tensor / convert_and_update_tensor / bulk_allocate_nvfp4_tensors: set the flag when `optimize_for_gemm && with_rht && shape eligible`, with eligibility in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites. * Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper in case a future low-level caller bypasses the gate. The shape gate is part of this PR (not a follow-up) because LLaMA-class shapes like (8192, 11328) have K%128==64. Without the gate the framework would set the flag, dispatch would fall to the unfused path that can't emit swizzled SF, and the process would abort. With the gate, ineligible shapes silently fall back to the original code path. Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median, `quant + swizzle` path -- what te.Linear actually runs): (8192, 5120) 108.6 -> 81.9 us 1.33x eligible (8192, 11328) 236.3 -> 236.3 us 1.00x ineligible, gate clamped (11328, 8192) 114.4 -> 93.2 us 1.23x eligible (14336,16384) 232.1 -> 197.5 us 1.18x eligible 11/12 production-class shapes get 1.18x - 1.36x. The one ineligible shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged across all shapes -- the savings come entirely from eliminating the standalone swizzle pass, not from a faster quant kernel. Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py Tests: * new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py: byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases verifying the shape gate clamps correctly and that quantizer(x) on an ineligible shape does not raise. * tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added optimize_for_gemm parametrization for the legacy grouped path. * test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe variant already had the wiring). Signed-off-by: Cael Ling <caell@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [PyTorch] NVFP4 RHT cast-fusion: enforce group-wide quantizer config Reviewer feedback: with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout. Following the precedent of the split-quantize path at line 1276 (// Assume all quantizers have identical config), this commit: * adds an explicit comment block calling out the group-wide identical-config assumption and which fields this PR enforces vs. which are pre-existing; * adds an NVTE_CHECK loop enforcing identical optimize_for_gemm and with_rht across the group (the two fields the with_gemm_swizzled_scales gate depends on), with error messages that print the offending tensor index and the disagreeing values; * extracts the [0] reads into group_optimize_for_gemm and group_with_rht locals so the same value feeds both the check and the gate. Other from-[0] reads (rowwise_usage, row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype) are pre-existing assumptions and remain out of scope for this PR. Signed-off-by: Cael Ling <caell@nvidia.com> * [PyTorch] NVFP4 RHT cast-fusion: address review feedback Functional fix: - `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT eligibility check (`rows % 64 == 0`), but the grouped kernel asserts `first_logical_dim % 128 == 0` at entry. Shapes with rows in {64, 192, 320, ...} would pass eligibility, set `with_gemm_swizzled_scales=True`, and then hard-abort inside the grouped kernel with an opaque NVTE_CHECK message. Adding a `for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion` selects the correct row alignment: 64 for the single-tensor kernel, 128 for the grouped variant. Only the bulk-allocation caller passes `true`; the three single-tensor callers keep the default `false`. Refactors: - `is_eligible_for_rht_cast_fusion` now takes the full tensor shape (`std::vector<size_t>`) and flattens internally with `get_2d_dims`, so the four call sites no longer pre-flatten and duplicate the flatten rule. - `quantize_impl` delegates the shape/arch eligibility to `is_eligible_for_rht_cast_fusion` instead of inlining the same predicate, and its hand-rolled `rows = product(shape[:-1])` loop is replaced with `get_2d_dims(input.shape())`. The shape/arch eligibility now has a single source of truth. Comment cleanups: - Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`, `create_tensor`, `convert_and_update_tensor`, and `quantize_with_rht_unfused_helper`. Removed cross-references to other functions/files, code narration of subsequent lines, the JAX reference in PyTorch source, and the "see X for rationale" pattern. - Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single brief sentence. Signed-off-by: Cael Ling <caell@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Sync _with_gemm_swizzled_scales onto Python NVFP4Tensor in update path Signed-off-by: Cael Ling <caell@nvidia.com> * Add license header to profile_rht_cast_swizzle_fusion.py Signed-off-by: Cael Ling <caell@nvidia.com> * [PyTorch] NVFP4 RHT cast-fusion: fix stale swizzle shape-gate test The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks that mainline #3076 split apart. _with_gemm_swizzled_scales only controls WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM swizzles lazily at call time; when True, the tensor is pre-swizzled and the GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or cols%128!=0) ended quantize with the flag False. #3076 then added a post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles ineligible shapes and flips the flag back to True, so under optimize_for_gemm the end-to-end flag is now True for all shapes. The old False expectations encoded pre-#3076 behavior and started failing CI on (64,144), (128,144), (48,128). Split into two self-consistent tests: - shape_gate: probes make_empty() (runs create_tensor only -- no quantize, no fallback), so it observes the fused-kernel shape gate in isolation and keeps the original True/False eligibility table. - end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes and must always yield _with_gemm_swizzled_scales=True (eligible via the fused cast-fusion kernel, ineligible via the #3076 swizzle fallback). Signed-off-by: Cael Ling <caell@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Cael Ling <caell@nvidia.com> Signed-off-by: cael-ling <caell@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3f64073 commit 720ec27

9 files changed

Lines changed: 759 additions & 36 deletions
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output.
6+
7+
For each shape we measure two paths and two builds:
8+
9+
* path = "quant_only": just NVFP4Quantizer(x)
10+
* path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t)
11+
(this is what te.Linear -> tex.generic_gemm does right before the
12+
cuBLAS LT NVFP4 GEMM dispatch)
13+
14+
* build = "baseline": optimize_for_gemm=False
15+
-> quant kernel emits compact SF;
16+
tex.swizzle_scales_for_gemm_ launches the standalone
17+
swizzle_{row,col}_scaling_kernel pass before GEMM.
18+
* build = "swizzle_fusion": optimize_for_gemm=True
19+
-> quant kernel emits GEMM-swizzled SF directly (via the
20+
kEnableSwizzleSFOutput compile-time switch in
21+
row_cast_col_hadamard_transform_cast_fusion.cu);
22+
tex.swizzle_scales_for_gemm_ early-returns and the standalone
23+
swizzle pass disappears from the timeline.
24+
25+
The wall-clock delta on the "quant_plus_swizzle" path is the production
26+
saving of this PR.
27+
"""
28+
29+
import argparse
30+
import torch
31+
import pandas as pd
32+
import torch.utils.benchmark as benchmark
33+
34+
import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order
35+
import transformer_engine_torch as tex
36+
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
37+
38+
39+
def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
40+
q = NVFP4Quantizer(
41+
fp4_dtype=tex.DType.kFloat4E2M1,
42+
rowwise=True,
43+
columnwise=True,
44+
with_amax_reduction=False,
45+
amax_reduction_group=None,
46+
with_rht=True,
47+
with_post_rht_amax=True,
48+
with_random_sign_mask=True,
49+
)
50+
q.optimize_for_gemm = optimize_for_gemm
51+
return q
52+
53+
54+
def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float:
55+
"""Returns median wall-clock per call in microseconds."""
56+
timing = benchmark.Timer(
57+
stmt=stmt,
58+
globals=globals_dict,
59+
num_threads=1,
60+
).blocked_autorange(min_run_time=min_run_time)
61+
return timing.median * 1e6
62+
63+
64+
def run_shape(shape, min_run_time: float):
65+
M, K = shape
66+
assert M % 16 == 0 and K % 16 == 0, "Shape must be divisible by 16"
67+
68+
x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda")
69+
q_base = make_quantizer(optimize_for_gemm=False)
70+
q_swf = make_quantizer(optimize_for_gemm=True)
71+
72+
# quant_only path
73+
quant_only_base_us = _bench(
74+
stmt="q(x)",
75+
globals_dict={"q": q_base, "x": x},
76+
min_run_time=min_run_time,
77+
)
78+
quant_only_swf_us = _bench(
79+
stmt="q(x)",
80+
globals_dict={"q": q_swf, "x": x},
81+
min_run_time=min_run_time,
82+
)
83+
84+
# quant_plus_swizzle path (this is what te.Linear actually runs)
85+
quant_plus_swizzle_base_us = _bench(
86+
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
87+
globals_dict={"q": q_base, "x": x, "tex": tex},
88+
min_run_time=min_run_time,
89+
)
90+
quant_plus_swizzle_swf_us = _bench(
91+
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
92+
globals_dict={"q": q_swf, "x": x, "tex": tex},
93+
min_run_time=min_run_time,
94+
)
95+
96+
saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us
97+
speedup = (
98+
quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us
99+
if quant_plus_swizzle_swf_us > 0
100+
else float("inf")
101+
)
102+
103+
print(
104+
f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, "
105+
f"SUT={quant_only_swf_us:.2f}us | "
106+
f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, "
107+
f"SUT={quant_plus_swizzle_swf_us:.2f}us "
108+
f"-> saved {saved_us:.2f}us ({speedup:.2f}x)"
109+
)
110+
111+
return {
112+
"shape": shape,
113+
"M": M,
114+
"K": K,
115+
"quant_only_base_us": quant_only_base_us,
116+
"quant_only_swf_us": quant_only_swf_us,
117+
"quant_plus_swizzle_base_us": quant_plus_swizzle_base_us,
118+
"quant_plus_swizzle_swf_us": quant_plus_swizzle_swf_us,
119+
"saved_us": saved_us,
120+
"speedup": speedup,
121+
}
122+
123+
124+
# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears):
125+
# ncu -f -o swizzle_fusion --set=full \
126+
# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \
127+
# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile
128+
129+
130+
if __name__ == "__main__":
131+
parser = argparse.ArgumentParser()
132+
parser.add_argument(
133+
"--profile",
134+
action="store_true",
135+
help="Run only one shape for use with ncu/nsys; longer min_run_time",
136+
)
137+
parser.add_argument(
138+
"--min-run-time",
139+
type=float,
140+
default=2.0,
141+
help="Minimum total measured time per cell in seconds (benchmark.Timer)",
142+
)
143+
parser.add_argument(
144+
"--csv",
145+
type=str,
146+
default="benchmark_rht_cast_swizzle_fusion.csv",
147+
help="CSV output path",
148+
)
149+
args = parser.parse_args()
150+
151+
if args.profile:
152+
print("Profiling mode enabled (single shape).")
153+
shapes = [(8192, 4096)]
154+
min_run_time = max(5.0, args.min_run_time)
155+
else:
156+
shapes = [
157+
# production-class shapes
158+
(8192, 5120),
159+
(8192, 10240),
160+
(8192, 2560),
161+
(8192, 11328),
162+
(8192, 3584),
163+
(5120, 8192),
164+
(10240, 8192),
165+
(2560, 8192),
166+
(11328, 8192),
167+
(3584, 8192),
168+
(4096, 16384),
169+
(14336, 16384),
170+
]
171+
min_run_time = args.min_run_time
172+
173+
print(
174+
"NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. "
175+
f"min_run_time={min_run_time}s per cell, BF16 input, "
176+
"rowwise+columnwise SF, RHT=True+post_rht_amax."
177+
)
178+
rows = []
179+
for shape in shapes:
180+
print(f"Running {shape} ...")
181+
rows.append(run_shape(shape, min_run_time))
182+
183+
df = pd.DataFrame(rows)
184+
pd.set_option("display.max_columns", None)
185+
pd.set_option("display.width", 200)
186+
print()
187+
print(df.to_string(index=False))
188+
df.to_csv(args.csv, index=False)
189+
print(f"\nWrote {args.csv}")
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""
6+
Profile that the dedicated swizzle kernels (swizzle_{row,col}_scaling_kernel
7+
in transformer_engine/common/swizzle/swizzle.cu) disappear from the timeline
8+
when NVFP4 RHT cast-fusion emits SF in the GEMM-swizzled layout directly
9+
(optimize_for_gemm=True).
10+
11+
Test setup:
12+
- NVFP4 + RHT + post-RHT amax (same as te.Linear sets up internally)
13+
- rowwise=True AND columnwise=True (covers BOTH swizzle_row_scaling_kernel
14+
and swizzle_col_scaling_kernel; this is what tex.Linear's input quantizer
15+
needs during training because the rowwise tensor is used by the fwd GEMM
16+
and the columnwise tensor is used by the dgrad GEMM)
17+
- tex.swizzle_scales_for_gemm_(t) is what te.Linear -> tex.generic_gemm
18+
calls just before the cuBLAS LT NVFP4 GEMM dispatch
19+
"""
20+
21+
import torch
22+
import transformer_engine.pytorch as te # noqa: F401 must be first
23+
import transformer_engine_torch as tex
24+
from transformer_engine.pytorch import NVFP4Quantizer
25+
26+
27+
def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
28+
q = NVFP4Quantizer(
29+
fp4_dtype=tex.DType.kFloat4E2M1,
30+
rowwise=True,
31+
columnwise=True,
32+
with_amax_reduction=False,
33+
amax_reduction_group=None,
34+
with_rht=True,
35+
with_post_rht_amax=True,
36+
with_random_sign_mask=True,
37+
)
38+
q.optimize_for_gemm = optimize_for_gemm
39+
return q
40+
41+
42+
import re
43+
44+
# Match ONLY the standalone swizzle pass kernels in
45+
# transformer_engine/common/swizzle/swizzle.cu — NOT RHT cast-fusion kernels
46+
# whose mangled name happens to contain "Swizzle" because of the
47+
# `template <..., bool kEnableSwizzleSFOutput, ...>` parameter substring.
48+
STANDALONE_SWIZZLE_RE = re.compile(
49+
r"(?:multi_tensor_(?:un)?swizzle|(?:un)?swizzle)_(?:row|col)_scaling_kernel"
50+
)
51+
52+
53+
def dump_kernel_counts(prof, label: str) -> dict:
54+
print(f"\n=== {label} ===")
55+
counts: dict[str, int] = {}
56+
for ev in prof.events():
57+
if ev.device_type != torch.autograd.DeviceType.CUDA:
58+
continue
59+
counts[ev.name] = counts.get(ev.name, 0) + 1
60+
standalone_swizzle_total = 0
61+
for name, c in sorted(counts.items(), key=lambda kv: -kv[1]):
62+
marker = ""
63+
if STANDALONE_SWIZZLE_RE.search(name):
64+
marker = " <-- STANDALONE SWIZZLE PASS"
65+
standalone_swizzle_total += c
66+
# Truncate long mangled CUTLASS names for readability
67+
short = name if len(name) <= 110 else name[:107] + "..."
68+
print(f" {c:4d} {short}{marker}")
69+
print(f" -- standalone swizzle kernel total: {standalone_swizzle_total}")
70+
return counts
71+
72+
73+
def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20):
74+
q = make_quantizer(optimize_for_gemm=optimize_for_gemm)
75+
# warm-up
76+
for _ in range(3):
77+
t = q(x)
78+
tex.swizzle_scales_for_gemm_(t)
79+
torch.cuda.synchronize()
80+
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
81+
for _ in range(n_iters):
82+
t = q(x)
83+
tex.swizzle_scales_for_gemm_(t)
84+
torch.cuda.synchronize()
85+
return prof
86+
87+
88+
def main():
89+
torch.manual_seed(0)
90+
torch.cuda.manual_seed(0)
91+
device = "cuda"
92+
# Shape that hits the production RHT cast-fusion fast-path
93+
# (rows % 64 == 0, cols % 128 == 0, BF16, SM100/110).
94+
M, N = 8192, 4096
95+
x = torch.randn((M, N), dtype=torch.bfloat16, device=device)
96+
97+
print(f"Shape: M={M}, N={N}, dtype=bf16, RHT=True, post_rht_amax=True")
98+
print(f"iters: 20 (after 3 warm-up)")
99+
100+
prof_baseline = profile_path(optimize_for_gemm=False, x=x)
101+
counts_baseline = dump_kernel_counts(
102+
prof_baseline, "BASELINE: optimize_for_gemm=False (separate swizzle kernel)"
103+
)
104+
105+
prof_swf = profile_path(optimize_for_gemm=True, x=x)
106+
counts_swf = dump_kernel_counts(
107+
prof_swf, "SUT: optimize_for_gemm=True (quant emits swizzled SF directly)"
108+
)
109+
110+
print("\n=== VERDICT ===")
111+
base_swizzle = sum(c for n, c in counts_baseline.items() if STANDALONE_SWIZZLE_RE.search(n))
112+
swf_swizzle = sum(c for n, c in counts_swf.items() if STANDALONE_SWIZZLE_RE.search(n))
113+
print(f" baseline standalone swizzle kernel launches: {base_swizzle}")
114+
print(f" SUT standalone swizzle kernel launches: {swf_swizzle}")
115+
if swf_swizzle == 0 and base_swizzle > 0:
116+
print(
117+
" PASS: standalone swizzle pass disappears from timeline under optimize_for_gemm=True"
118+
)
119+
else:
120+
print(
121+
" FAIL: expected baseline > 0 and SUT == 0; check whether SUT actually "
122+
"set with_gemm_swizzled_scales=True on the output tensor"
123+
)
124+
125+
126+
if __name__ == "__main__":
127+
main()

0 commit comments

Comments
 (0)