Commit 2a581c7

v4.5 tag update.
1 parent aef66e5 commit 2a581c7

222 files changed

Lines changed: 36641 additions & 8070 deletions


CHANGELOG.md

Lines changed: 42 additions & 1 deletion
```diff
@@ -2,18 +2,59 @@
 # CUTLASS 4.x
 
-## [4.5.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-03-27)
+## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01)
 
 ### CuTe DSL
+* New features
+  - New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and no longer need to invoke `tma_partition()`; the bulk of the S2T initialization can also be removed.
+  - MXF8F6F4 mixed-precision support
+    - BlockScaled MMA now supports MXF8*MXF4 and MXF8*MXF6
+    - Block Scaled MMA for SM120 now works on Spark
+  - EFC broadcast semantics support
+    - EFC epilogue functions can now broadcast and remap tensor modes via the `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). This covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations.
+  - Initial linter support: improved type hints on CuTe DSL APIs to support static type checkers such as MyPy
+  - `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, on both the plain and the tvm-ffi path
+  - `cute.copy` now supports user-specified loop unrolling
 * Bug fixes and improvements
   - Improved source-code correlation for profiling/debugging
+  - Fixed an aarch64 segfault with tvm-ffi
+  - Reorganized the CuTe DSL examples/tutorials for better discoverability
+* More examples of authoring peak-performance kernels
+  - MoE examples
+    - A new style of grouped GEMM that aligns with torch's `grouped_mm` and `scaled_grouped_mm` interfaces.
+    - Expert-wise tensormap descriptors are set up by a cheap helper kernel (~2 us) to avoid long latency when switching tiles; the kernel structure is much closer to a normal GEMM.
+    - Compared to torch_210_cu13, very few problems show worse performance on B200:
+      - mxfp8_2dx3d: avg 1.29x speedup
+      - mxfp8_2dx2d: avg 1.41x speedup
+      - nvfp4_2dx3d: avg 1.11x speedup
+      - nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
+      - bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
+      - bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
+    - Note: performance is measured with the torch profiler; this implementation includes the helper kernel plus the main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.
+* API changes
+  - `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` from blackwell_helpers.py. Please specify `a_dtype` and `b_dtype` separately instead.
 
 ### CUTLASS C++
+* Add 2SM MMA instruction support to the mixed TMA+CpAsync SM100 vanilla GEMM kernels.
+  - Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
+  - Uses TMA multicast for the A tile when the cluster size along the N mode is non-trivial.
+  - Uses an additional barrier (`mma_trampoline_barrier`) to track cp.async arrivals in both CTAs.
+  - Changes are included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
+* Add support for 128x32xK and 128x64xK tile sizes in the SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement in Blackwell SM121-related kernels.
+* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
+* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
 * Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green-context SM partitioning
   - Enables launching GEMM on a stream with a partial SM allocation.
 * Fix some kernel issues:
   - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates
   - Fix CUTLASS clang build issues
+  - Fix the atomicCAS read-modify-write loop in `ConstSubbyteReference`
+  - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in the subbyte reference
+  - Remove `PipelineStorage` shadowing in the SM100 complex epilogue
+  - Fix a build issue in the SM90 epilogue fusion visitor (TMA warp-specialized)
 * Fix some profiler issues:
   - Add missing reference kernels for the blockwise GEMM profiler
 * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
```
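The EFC `remap_modes` broadcast/remap semantics described above can be modeled with a small pure-Python sketch. This is not the CuTe DSL API; `remap_index` and its spec encoding (with `None` standing in for `:`) are hypothetical, chosen only to illustrate how `:` broadcasts and integers select source modes:

```python
# Toy model of the remap semantics: given a mode spec in which None plays
# the role of ':' (a broadcast dimension) and integers select source mode
# indices, compute the source coordinate that backs a destination coordinate.

def remap_index(spec, dst_coord):
    """Map a destination coordinate back to a source coordinate.

    spec      -- tuple like (None, 0, 1): None broadcasts, ints pick source modes
    dst_coord -- coordinate in the remapped (destination) tensor
    """
    n_src_modes = len([m for m in spec if m is not None])
    src = [0] * n_src_modes
    for dst_mode, m in enumerate(spec):
        if m is None:
            continue  # broadcast: every destination index reads the same source
        src[m] = dst_coord[dst_mode]
    return tuple(src)

# Row broadcast (':' in mode 0): destination (5, 3) reads source (3,)
assert remap_index((None, 0), (5, 3)) == (3,)
# Transpose (mode permutation): destination (2, 7) reads source (7, 2)
assert remap_index((1, 0), (2, 7)) == (7, 2)
```

Under this toy model, scalar broadcast is the all-`None` spec, and arbitrary permutations fall out of the integer ordering.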

LICENSE.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,5 +30,5 @@ Certain files within this repository are subject to separate licensing terms:
 - The files located in the `python/CuTeDSL` directory are licensed under the
   NVIDIA End User License Agreement (EULA). Please refer to
-  https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
+  https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
   for the full terms.
```

README.md

Lines changed: 45 additions & 4 deletions
```diff
@@ -3,7 +3,7 @@
 # CUTLASS 4.5.0
 
-_CUTLASS 4.5.0 - March 2026_
+_CUTLASS 4.5.0 - May 2026_
 
 CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
 and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -45,16 +45,57 @@ To get started quickly - please refer :
 # What's New in CUTLASS 4.5
 
-### CuTe DSL
+## CuTe DSL
```

The bullets added under "What's New in CUTLASS 4.5" are identical to the CHANGELOG.md additions shown above (CuTe DSL new features, bug fixes and improvements, MoE examples, API changes, and the CUTLASS C++ items); the only other difference is that the `### CuTe DSL` and `### CUTLASS C++` headings are promoted to `##`.
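The `dataclasses.dataclass` support mentioned above concerns passing plain dataclasses through CuTe DSL JIT compilation. A minimal sketch of the Python-side pattern; `GemmConfig` is a hypothetical example type, and the `cute.compile` call itself is omitted so the snippet stays self-contained:

```python
# A plain (here frozen) dataclass of the kind that, per the notes above,
# can now flow through JIT compilation and cute.compile on both the plain
# and tvm-ffi paths. GemmConfig and its fields are illustrative only.
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class GemmConfig:
    tile_m: int = 128
    tile_n: int = 64
    tile_k: int = 256
    unroll: int = 4  # cf. the new user-specified loop unrolling for cute.copy

cfg = GemmConfig(tile_n=128)

# A frozen dataclass is hashable and comparable by value, so it can key a
# compilation cache keyed on kernel parameters.
cache = {cfg: "compiled-kernel-placeholder"}
assert cache[GemmConfig(tile_n=128)] == "compiled-kernel-placeholder"
assert asdict(cfg)["tile_m"] == 128
```

Freezing the dataclass is a design choice in this sketch: value-equality and hashability are what make such a config usable as a cache key.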

examples/13_two_tensor_op_fusion/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -48,7 +48,7 @@ foreach(FUSION_CONV_EXAMPLE
   fused_two_convs_f16_sm80_shmem
   fused_two_convs_s8_sm75_rf
   fused_two_convs_s8_sm75_shmem
-  fused_two_convs_s8_sm80_rf
+  # fused_two_convs_s8_sm80_rf # disabled: fails to build
   fused_two_convs_s8_sm80_shmem
 )
```

examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -195,7 +195,7 @@ class GemmUniversal<
     }
     implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
     implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
-    implementable &= TileScheduler::can_implement(args.scheduler);
+    implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info);
 
     return implementable;
   }
```

examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu

Lines changed: 10 additions & 17 deletions
```diff
@@ -216,7 +216,7 @@ struct Options {
   float alpha, beta;
   int iterations;
   int m, n, k;
-  int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
+  int cluster_m, cluster_n;
   using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
   using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
   DecompositionMode decomposition_mode;
@@ -240,10 +240,8 @@ struct Options {
     m(256), n(256), k(16384),
     alpha(1.f), beta(0.f),
     iterations(10),
-    preferred_cluster_m(4),
-    preferred_cluster_n(4),
-    fallback_cluster_m(2),
-    fallback_cluster_n(1),
+    cluster_m(2),
+    cluster_n(1),
     decomposition_mode(DecompositionMode::Heuristic),
     reduction_mode(ReductionMode::Deterministic),
     splits(1)
@@ -265,10 +263,8 @@ struct Options {
     cmd.get_cmd_line_argument("beta", beta, 0.f);
     cmd.get_cmd_line_argument("iterations", iterations);
     cmd.get_cmd_line_argument("splits", splits, 1);
-    cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
-    cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
-    cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
-    cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
+    cmd.get_cmd_line_argument("cluster_m", cluster_m, 2);
+    cmd.get_cmd_line_argument("cluster_n", cluster_n, 1);
 
     // Parse decomposition mode
     std::string decomp_mode;
@@ -303,10 +299,8 @@ struct Options {
       << "  --k=<int>                   Sets the K extent of the GEMM\n"
       << "  --alpha=<f32>               Epilogue scalar alpha\n"
       << "  --beta=<f32>                Epilogue scalar beta\n"
-      << "  --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
-      << "  --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
-      << "  --fallback_cluster_m=<str>  Sets the M extent of fallback cluster shape\n"
-      << "  --fallback_cluster_n=<str>  Sets the N extent of fallback cluster shape\n"
+      << "  --cluster_m=<str>           Sets the M extent of the cluster shape\n"
+      << "  --cluster_n=<str>           Sets the N extent of the cluster shape\n"
      << "  --decomposition=<str>       Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n"
      << "  --reduction=<str>           Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";
@@ -424,8 +418,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
     {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
   };
 
-  arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n, 1);
-  arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n, 1);
+  arguments.hw_info.cluster_shape = dim3(options.cluster_m, options.cluster_n, 1);
+  arguments.hw_info.cluster_shape_fallback = dim3(options.cluster_m, options.cluster_n, 1);
 
   arguments.scheduler.splits = options.splits;
   arguments.scheduler.decomposition_mode = options.decomposition_mode;
@@ -498,8 +492,7 @@ int run(Options &options) {
 
   std::cout << "Stream-K GEMM with"
     << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
-    << " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
-    << " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n"
+    << " Cluster = (" << options.cluster_m << ", " << options.cluster_n << ", 1)\n"
     << " Decomposition_mode=" << options.decomposition_mode_str()
     << " Split_count=" << options.splits
     << " Reduction_mode=" << options.reduction_mode_str()
```
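The decomposition modes listed in this example's options (SplitK, StreamK, DataParallel) differ in how the K extent of an output tile is shared among workers and then reduced. A toy split-K dot product in plain Python illustrates the idea; it models only the partition-and-reduce structure, not the actual scheduler:

```python
def split_k_dot(a, b, splits):
    """Toy split-K: partition the K dimension of a dot product into
    `splits` chunks computed independently (as separate workers would),
    then reduce the partial sums in a fixed chunk order, which is what
    makes the reduction deterministic."""
    k = len(a)
    chunk = (k + splits - 1) // splits  # ceil-divide K among workers
    partials = []
    for s in range(splits):
        lo, hi = s * chunk, min((s + 1) * chunk, k)
        partials.append(sum(x * y for x, y in zip(a[lo:hi], b[lo:hi])))
    return sum(partials)  # deterministic reduction: fixed order

a = [1.0, 2.0, 3.0, 4.0]
b = [5.0, 6.0, 7.0, 8.0]
# The split count must not change the result.
assert split_k_dot(a, b, splits=1) == split_k_dot(a, b, splits=3) == 70.0
```

A nondeterministic reduction mode would instead accumulate partials in whatever order workers finish, which for floating point can change low-order bits run to run.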

examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -536,7 +536,11 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
   // Each thread owns a single row
+#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
+  using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
+#else
   using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x;    // 4x32 threads with 128 cols of 32b elem
+#endif
   using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x;  // 4x32 threads with 128 cols of 8b elem
   using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
 
@@ -573,6 +577,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   }
 
   ElementQK old_row_max = row_max;
+#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
+  auto pos = tTMEM_LOADcS(0);
+  if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
+    float curr_max = tiled_tmem_load.get_max();
+    row_max = ::fmax(row_max, curr_max);
+  }
+  else
+#endif
   {
     // compute rowmax
     float row_max_0 = row_max;
```
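The guarded branch added above folds a hardware-provided tile maximum (`get_max()`) into the running row maximum. The surrounding update is the standard online-softmax rescaling step used in FMHA mainloops, sketched here in plain Python as a model of the math only, not of the TMEM code:

```python
import math

def online_softmax_step(state, tile_scores):
    """One online-softmax update: fold a new tile of attention scores into
    the running (row_max, row_sum) state, rescaling the previously
    accumulated sum whenever the maximum grows (the role that row_max and
    old_row_max play in the kernel above)."""
    old_max, old_sum = state
    tile_max = max(tile_scores)
    new_max = max(old_max, tile_max)          # cf. row_max = fmax(row_max, curr_max)
    correction = math.exp(old_max - new_max)  # rescale the old partial sum
    new_sum = old_sum * correction + sum(math.exp(s - new_max) for s in tile_scores)
    return new_max, new_sum

state = (-math.inf, 0.0)
for tile in [[0.5, 2.0], [1.0, 3.0]]:
    state = online_softmax_step(state, tile)

# The streamed result matches a one-shot softmax denominator over all scores.
all_scores = [0.5, 2.0, 1.0, 3.0]
m = max(all_scores)
assert state[0] == m
assert abs(state[1] - sum(math.exp(s - m) for s in all_scores)) < 1e-12
```

The point of the hardware `get_max()` path is that the maximum of a tile arrives for free with the load, so the `new_max` step above needs no per-element scan.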

examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -540,7 +540,11 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
   // Each thread owns a single row
+#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
+  using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
+#else
   using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>;     // 4x32 threads with 128 cols of 8b elem
+#endif
   using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>;  // 4x32 threads with 128 cols of 8b elem
   using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x;  // 4x32 threads with 2 cols of 32b elem
 
@@ -577,6 +581,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   }
 
   ElementQK old_row_max = row_max;
+#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
+  auto pos = tTMEM_LOADcS(0);
+  if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
+    float curr_max = tiled_tmem_load.get_max();
+    row_max = ::fmax(row_max, curr_max);
+  }
+  else
+#endif
   {
     // compute rowmax
     float row_max_0 = row_max;
```

examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu

Lines changed: 34 additions & 9 deletions
```diff
@@ -213,12 +213,15 @@ auto make_iterator(T* ptr) {
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
-struct ExampleRunner {
+template <
   // Type of kernel schedule to generate
-  using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100;
+  class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
   // Type of epilogue schedule to generate
-  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  static constexpr bool FuseQuantization = false;
+  class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
+  class ClusterShapeMNK = Shape<_1, _1, _1>,
+  bool FuseQuantization = false
+>
+struct ExampleRunner {
 
   using LayoutATag = cutlass::layout::RowMajor;
   using LayoutBTag = cutlass::layout::ColumnMajor;
@@ -238,10 +241,8 @@ struct ExampleRunner {
   using ElementCompute = float;
   using ElementScalar = float;
 
-
-
-  using ClusterShapeMNK = Shape<_1,_1,_1>;
-  using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
+  static constexpr int TileM = cute::is_base_of_v<cutlass::gemm::KernelSchedule2Sm, MainloopScheduleType> ? 256 : 128;
+  using MmaTileMNK = Shape<Int<TileM>,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
 
   static constexpr int AlignmentA = 32;
   static constexpr int AlignmentB = 32;
@@ -712,10 +713,34 @@ int main(int argc, char const **args) {
   hw_info.device_id = 0;
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
 
-  std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
+  std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl;
   ExampleRunner runner_mixed_tma_cpasync;
   runner_mixed_tma_cpasync.run(options, hw_info);
 
+  std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load, 1SM, 2x2 cluster:" << std::endl;
+  ExampleRunner<
+    cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
+    cutlass::epilogue::collective::EpilogueScheduleAuto,
+    Shape<_2, _2, _1>
+  > runner_mixed_tma_cpasync_1sm_2x2;
+  runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);
+
+  std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x1 cluster:" << std::endl;
+  ExampleRunner<
+    cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
+    cutlass::epilogue::collective::EpilogueScheduleAuto,
+    Shape<_2, _1, _1>
+  > runner_mixed_tma_cpasync_2sm_2x1;
+  runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info);
+
+  std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x4 cluster:" << std::endl;
+  ExampleRunner<
+    cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
+    cutlass::epilogue::collective::EpilogueScheduleAuto,
+    Shape<_2, _4, _1>
+  > runner_mixed_tma_cpasync_2sm_2x4;
+  runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);
+
 #endif
 
   return 0;
```
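The MoE examples in this commit (and the grouped-GEMM style mentioned in the changelog) target the semantics of torch's `grouped_mm`: each expert multiplies its own slice of A by its own B matrix. A plain-Python reference of those semantics; the function name, shapes, and the `offsets` encoding are illustrative assumptions, not the torch or CUTLASS API:

```python
def grouped_mm_ref(a_rows, b_per_group, offsets):
    """Reference semantics of a grouped GEMM: rows of A are partitioned
    into contiguous groups by `offsets`, and group g is multiplied by its
    own matrix b_per_group[g]. Pure-Python lists stand in for tensors."""
    out = []
    for g, b in enumerate(b_per_group):
        lo, hi = offsets[g], offsets[g + 1]
        cols = list(zip(*b))  # columns of this group's B matrix
        for row in a_rows[lo:hi]:
            out.append([sum(x * w for x, w in zip(row, col)) for col in cols])
    return out

# Two "experts": group 0 owns rows 0..1 of A, group 1 owns row 2.
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
B0 = [[1.0, 0.0], [0.0, 1.0]]  # identity
B1 = [[2.0, 0.0], [0.0, 2.0]]  # 2 * identity
C = grouped_mm_ref(A, [B0, B1], offsets=[0, 2, 3])
assert C == [[1.0, 2.0], [3.0, 4.0], [10.0, 12.0]]
```

The per-expert tensormap helper kernel described in the changelog exists precisely because each group above has its own B pointer (and possibly shape), which the hardware descriptors must be switched to match.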
