Commit 720ec27
[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly (#3011)
* [PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly
Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two
standalone swizzle kernels (rowwise + columnwise) whose only job was to
move scale factors into the layout cuBLAS LT consumes. The cast-fusion
kernel already had a `kEnableSwizzleSFOutput` switch for that, but the
framework never set the matching `with_gemm_swizzled_scales` flag on
NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through.
Changes:
* Single + grouped Hadamard cast-fusion kernels: drive
`kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`.
* NVFP4Quantizer create_tensor / convert_and_update_tensor /
bulk_allocate_nvfp4_tensors: set the flag when
`optimize_for_gemm && with_rht && shape eligible`, with eligibility
in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion
(rows%64==0 && cols%128==0 && SM100/110) shared by all three sites.
* Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper
in case a future low-level caller bypasses the gate.
The shape gate is part of this PR (not a follow-up) because LLaMA-class
shapes like (8192, 11328) have K%128==64. Without the gate the framework
would set the flag, dispatch would fall to the unfused path that can't
emit swizzled SF, and the process would abort. With the gate, ineligible
shapes silently fall back to the original code path.
Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median,
`quant + swizzle` path -- what te.Linear actually runs):
(8192, 5120) 108.6 -> 81.9 us 1.33x eligible
(8192, 11328) 236.3 -> 236.3 us 1.00x ineligible, gate clamped
(11328, 8192) 114.4 -> 93.2 us 1.23x eligible
(14336,16384) 232.1 -> 197.5 us 1.18x eligible
11/12 production-class shapes get 1.18x - 1.36x. The one ineligible
shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged
across all shapes -- the savings come entirely from eliminating the
standalone swizzle pass, not from a faster quant kernel.
Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py
Tests:
* new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py:
byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases
verifying the shape gate clamps correctly and that quantizer(x) on an
ineligible shape does not raise.
* tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added
optimize_for_gemm parametrization for the legacy grouped path.
* test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe
variant already had the wiring).
Signed-off-by: Cael Ling <caell@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [PyTorch] NVFP4 RHT cast-fusion: enforce group-wide quantizer config
Reviewer feedback: with_gemm_swizzled_scales was derived from
quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking
that other quantizers in the group agreed; if any later quantizer
had a different value, its tensors would be silently allocated with
the wrong SF layout.
Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
* adds an explicit comment block calling out the group-wide
identical-config assumption and which fields this PR enforces
vs. which are pre-existing;
* adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
and with_rht across the group (the two fields the
with_gemm_swizzled_scales gate depends on), with error messages
that print the offending tensor index and the disagreeing values;
* extracts the [0] reads into group_optimize_for_gemm and
group_with_rht locals so the same value feeds both the check
and the gate.
Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
Signed-off-by: Cael Ling <caell@nvidia.com>
* [PyTorch] NVFP4 RHT cast-fusion: address review feedback
Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
`first_logical_dim % 128 == 0` at entry. Shapes with rows in
{64, 192, 320, ...} would pass eligibility, set
`with_gemm_swizzled_scales=True`, and then hard-abort inside the
grouped kernel with an opaque NVTE_CHECK message. Adding a
`for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion`
selects the correct row alignment: 64 for the single-tensor kernel,
128 for the grouped variant. Only the bulk-allocation caller passes
`true`; the three single-tensor callers keep the default `false`.
Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
(`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
so the four call sites no longer pre-flatten and duplicate the
flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
`is_eligible_for_rht_cast_fusion` instead of inlining the same
predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
replaced with `get_2d_dims(input.shape())`. The shape/arch
eligibility now has a single source of truth.
Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
`create_tensor`, `convert_and_update_tensor`, and
`quantize_with_rht_unfused_helper`. Removed cross-references to
other functions/files, code narration of subsequent lines, the JAX
reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
brief sentence.
Signed-off-by: Cael Ling <caell@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Sync _with_gemm_swizzled_scales onto Python NVFP4Tensor in update path
Signed-off-by: Cael Ling <caell@nvidia.com>
* Add license header to profile_rht_cast_swizzle_fusion.py
Signed-off-by: Cael Ling <caell@nvidia.com>
* [PyTorch] NVFP4 RHT cast-fusion: fix stale swizzle shape-gate test
The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks
that mainline #3076 split apart. _with_gemm_swizzled_scales only controls
WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM
swizzles lazily at call time; when True, the tensor is pre-swizzled and the
GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or
cols%128!=0) ended quantize with the flag False. #3076 then added a
post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles
ineligible shapes and flips the flag back to True, so under optimize_for_gemm
the end-to-end flag is now True for all shapes. The old False expectations
encoded pre-#3076 behavior and started failing CI on (64,144), (128,144),
(48,128).
Split into two self-consistent tests:
- shape_gate: probes make_empty() (runs create_tensor only -- no quantize,
no fallback), so it observes the fused-kernel shape gate in isolation and
keeps the original True/False eligibility table.
- end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes
and must always yield _with_gemm_swizzled_scales=True (eligible via the
fused cast-fusion kernel, ineligible via the #3076 swizzle fallback).
Signed-off-by: Cael Ling <caell@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: cael-ling <caell@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent 3f64073 commit 720ec27
9 files changed
Lines changed: 759 additions & 36 deletions
File tree
- benchmarks
- tests/pytorch/nvfp4
- transformer_engine
- common/hadamard_transform
- pytorch/csrc
- extensions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
0 commit comments