# CUTLASS 4.5.0

_CUTLASS 4.5.0 - May 2026_

CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA.

# What's New in CUTLASS 4.5

### CuTe DSL

* New features
  - New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and no longer need to invoke `tma_partition()`; it also removes the bulk of S2T initialization.
  - MXF8F6F4 mixed-precision support
    - BlockScaled MMA now supports MXF8*MXF4 and MXF8*MXF6
  - Block Scaled MMA for SM120 now works on Spark
  - EFC broadcast semantics support
    - EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). This covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations.
  - Initial linter support: improved type hints on CuTe DSL APIs to support static type checkers such as MyPy
  - `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, on both the plain and tvm-ffi paths
  - `cute.copy` now supports user-specified loop unrolling
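The mode remap/broadcast semantics above can be mirrored in plain NumPy. The sketch below is illustrative only: `remap_modes_ref` and its tuple-based spec encoding are hypothetical stand-ins for the DSL's `C.remap_modes[...]` subscript syntax, with `None` playing the role of `:` (a broadcast dimension) and integers selecting source mode indices.

```python
import numpy as np

def remap_modes_ref(src, spec, out_shape):
    # Hypothetical reference for remap-modes semantics: each entry of `spec`
    # describes one output mode. An integer i selects source mode i; None
    # marks a broadcast dimension (like ':' in the DSL subscript syntax).
    selected = [m for m in spec if m is not None]
    permuted = np.transpose(src, selected)        # arbitrary mode permutation
    shape = []
    it = iter(permuted.shape)
    for m in spec:
        shape.append(1 if m is None else next(it))  # size-1 axis where broadcast
    return np.broadcast_to(permuted.reshape(shape), out_shape)

# Row broadcast: a length-4 vector replicated across 3 rows.
v = np.arange(4.0)
row = remap_modes_ref(v, (None, 0), (3, 4))

# Transpose expressed as a mode permutation.
m = np.arange(6.0).reshape(2, 3)
t = remap_modes_ref(m, (1, 0), (3, 2))
```

The same pattern covers scalar broadcast (all entries `None`) and column broadcast (`(0, None)`).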
+
8
19
* Bug fixing and improvements
9
20
- Improved source code correlation for profiling/debugging
21
+
- Fixed an aarch64 segfault issue with tvm-ffi
22
+
- Re-organization for CuTe DSL examples/tutorials for better discoverability
23
+
24
+
* More examples of authoring peak-performance kernels
  - MoE examples
    - A new style of grouped GEMM that aligns with torch's grouped_mm and scaled_grouped_mm interfaces.
    - Expert-wise tensormap descriptor setup is done by a cheap helper kernel (~2 us) to avoid long latency on tile switches; the kernel structure is much closer to a normal GEMM.
    - Compared to torch_210_cu13, very few problems show worse perf on B200:
      - mxfp8_2dx3d: avg 1.29x speedup
      - mxfp8_2dx2d: avg 1.41x speedup
      - nvfp4_2dx3d: avg 1.11x speedup
      - nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
      - bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
      - bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
    - Note: perf is measured with the torch profiler; this implementation includes the helper kernel + main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.
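For readers unfamiliar with the grouped-GEMM interface these examples align with, here is a shape-level NumPy sketch. The name `grouped_mm_ref` is illustrative, not torch's or CUTLASS's API; torch's grouped_mm operates on device tensors with offset tensors, and this only mirrors the "2dx3d" per-expert matmul semantics (a 2D stacked-token operand against a 3D per-expert weight operand).

```python
import numpy as np

def grouped_mm_ref(a, b, offs):
    # a:    (total_tokens, K)     -- token rows for all experts, concatenated ("2d")
    # b:    (num_experts, K, N)   -- one weight matrix per expert ("3d")
    # offs: cumulative row offsets, one per expert group
    # Each contiguous group of rows is multiplied by its expert's weights.
    out = np.empty((a.shape[0], b.shape[2]), dtype=a.dtype)
    start = 0
    for e, end in enumerate(offs):
        out[start:end] = a[start:end] @ b[e]
        start = end
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 8)).astype(np.float32)    # 10 tokens, K=8
b = rng.standard_normal((3, 8, 4)).astype(np.float32)  # 3 experts, N=4
offs = [4, 7, 10]                                      # expert row boundaries
c = grouped_mm_ref(a, b, offs)                         # -> shape (10, 4)
```

The fused kernel avoids this Python-level loop entirely: each tile looks up its expert's tensormap (prepared by the helper kernel) and proceeds like a normal GEMM.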
* API changes
  - `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` in `blackwell_helpers.py`. Specify `a_dtype` and `b_dtype` separately instead.
### CUTLASS C++
* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels.
  - Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
  - Uses TMA multicast for the A tile when the cluster size along the N mode is non-trivial.
  - Uses an additional barrier (`mma_trampoline_barrier`) to track cp.async arrivals in both CTAs.
  - Changes are included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
* Add support for 128x32xK and 128x64xK tile sizes in the SM120 blockscaled MMA collective builders, yielding up to a 30% performance improvement in Blackwell SM121-related kernels.
* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) demonstrating green-context SM partitioning.
  - Enables launching GEMM on a stream with a partial SM allocation.
* Fix several kernel issues:
  - Fix `l2_capacity=0` handling in Blackwell SM100/SM120 kernel templates
  - Fix CUTLASS clang build issues
  - Fix the atomicCAS read-modify-write loop in `ConstSubbyteReference`
  - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in the subbyte reference
  - Remove `PipelineStorage` shadowing in the SM100 complex epilogue