-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathdense_gemm_persistent_dynamic.py
More file actions
1915 lines (1702 loc) · 72.1 KB
/
dense_gemm_persistent_dynamic.py
File metadata and controls
1915 lines (1702 loc) · 72.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Optional, Tuple, Type, Union
from functools import lru_cache
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
from cutlass.utils import is_fp8_dtype, create_cute_tensor_for_fp8
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
from cutlass.cute.nvgpu import cpasync, tcgen05
"""
A high-performance cluster launch control (CLC) dynamic persistent batched dense GEMM example
for the NVIDIA Blackwell SM100 architecture using CuTe DSL.
The CLC dynamic persistent scheduling technique performs dynamic load balancing:
it adapts to the SMs that are actually available rather than to a statically selected number. To support this,
a new instruction is introduced to query for a new tile to compute. This new instruction is similar
to programmatic multicast in the context of clusters, in that the same starting tile ID for a given cluster
is broadcast to all threadblocks in the cluster.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel>`_.
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
This GEMM kernel supports the following features:
- Utilizes the Tensor Memory Accelerator (TMA) for efficient memory operations
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
- Implements TMA multicast with cluster to reduce L2 memory traffic
- Supports CLC dynamic persistent tile scheduling for near-perfect load balancing
- Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
This GEMM works as follows:
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
3. EPILOGUE warp:
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
- Type convert C matrix to output type.
- Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
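e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
4. SCHED warp (first CTA of each cluster only): issues CLC queries for the next tile to compute;
the other warps consume the CLC response to advance to their next tile.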
SM100 tcgen05.mma instructions operate as follows:
- Read matrix A from SMEM
- Read matrix B from SMEM
- Write accumulator to TMEM
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
The input arguments to this example are the same as those of dense_gemm.py.
.. code-block:: bash
python examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py \
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 8192,8192,8192,1 \
--use_tma_store --use_2cta_instrs
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py \
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 8192,8192,8192,1 \
--use_tma_store --use_2cta_instrs \
--warmup_iterations 1 --iterations 10 --skip_ref_check
The constraints are the same as in dense_gemm.py:
* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2);
see the PersistentDenseGemmKernel class documentation below for the valid dtype combinations
* A/B tensor must have the same data type
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
* Mma tiler N must be 32-256, step 32
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
* Cluster shape M must be multiple of 2 if use_2cta_instrs=True
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
i.e., the number of elements is a multiple of 4, 8, and 16 for TFloat32,
Float16/BFloat16, and Int8/Uint8/Float8, respectively.
* OOB tiles are not allowed when TMA store is disabled
"""
def _compute_stages(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: Tuple[int, int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    smem_capacity: int,
    occupancy: int,
    use_tma_store: bool,
    c_smem_layout: Union[cute.Layout, None],
) -> Tuple[int, int, int]:
    """Choose pipeline depths for the accumulator, the A/B operands and the C epilogue.

    The heuristic is:

    1. Use two accumulator stages. Start the C pipeline at two stages when TMA
       store is enabled, otherwise at zero.
    2. Take the per-CTA shared-memory budget (``smem_capacity // occupancy``).
       Subtract a fixed 1 KiB reservation for barriers/helpers and the bytes
       of the initial C stages. Divide the remainder by the size of one
       A+B stage to get the A/B stage count.
    3. With TMA store only, turn whatever shared memory is still unused into
       extra C stages.

    :param tiled_mma: Tiled MMA that the A/B shared-memory layouts are derived from.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: (M, N, K) shape of the MMA tiler.
    :type mma_tiler_mnk: tuple[int, int, int]
    :param a_dtype: Element type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Element type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param c_dtype: Element type of the output C.
    :type c_dtype: type[cutlass.Numeric]
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM; the capacity is split evenly between them.
    :type occupancy: int
    :param use_tma_store: Whether C is written back through shared memory with TMA.
    :type use_tma_store: bool
    :param c_smem_layout: Single-stage C shared-memory layout, or None when TMA
        store is disabled.
    :type c_smem_layout: Union[cute.Layout, None]
    :return: ``(num_acc_stage, num_ab_stage, num_c_stage)``.
    :rtype: tuple[int, int, int]
    """
    acc_stages = 2
    c_stages = 2 if use_tma_store else 0

    # Shared-memory footprint of a single A stage plus a single B stage.
    a_layout_1stage = utils.sm100.make_smem_layout_a(
        tiled_mma, mma_tiler_mnk, a_dtype, 1
    )
    b_layout_1stage = utils.sm100.make_smem_layout_b(
        tiled_mma, mma_tiler_mnk, b_dtype, 1
    )
    ab_stage_bytes = cute.size_in_bytes(a_dtype, a_layout_1stage) + cute.size_in_bytes(
        b_dtype, b_layout_1stage
    )

    # Fixed reservation for mbarriers and other small shared-memory helpers.
    reserved_bytes = 1024
    # NOTE(review): c_smem_layout is None when TMA store is off. In that case
    # this relies on cute.size_in_bytes accepting a None layout. Its value is
    # irrelevant there because it is multiplied by zero C stages — confirm.
    c_stage_bytes = cute.size_in_bytes(c_dtype, c_smem_layout)
    # Per-CTA bytes that are not available to the A/B pipeline.
    fixed_bytes = reserved_bytes + c_stage_bytes * c_stages

    # Fill the per-CTA budget with as many whole A/B stages as fit.
    ab_stages = (smem_capacity // occupancy - fixed_bytes) // ab_stage_bytes

    if use_tma_store:
        # Spend the leftover shared memory (over all resident CTAs) on extra C stages.
        unused_bytes = (
            smem_capacity
            - occupancy * ab_stage_bytes * ab_stages
            - occupancy * fixed_bytes
        )
        c_stages += unused_bytes // (occupancy * c_stage_bytes)

    return acc_stages, ab_stages, c_stages
class PersistentDenseGemmKernel:
"""This class implements batched matrix multiplication (C = A x B) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use the tcgen05 MMA variant with cta_group=2
:type use_2cta_instrs: bool
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:param use_tma_store: Whether to store results with the Tensor Memory Accelerator (TMA); otherwise C is stored directly from registers to global memory
:type use_tma_store: bool
:note: In current version, A and B tensor must have the same data type
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
:note: Supported A/B data types:
- TFloat32
- Float16/BFloat16
- Int8/Uint8
- Float8E4M3FN/Float8E5M2
:note: Supported accumulator data types:
- Float32 (for all floating point A/B data types)
- Float16 (only for fp16 and fp8 A/B data types)
- Int32 (only for uint8/int8 A/B data types)
:note: Supported C data types:
- Float32 (for float32 and int32 accumulator data types)
- Int32 (for float32 and int32 accumulator data types)
- Float16/BFloat16 (for fp16 and fp8 accumulator data types)
- Int8/Uint8 (for uint8/int8 accumulator data types)
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
:note: Constraints:
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
- MMA tiler N must be 32-256, step 32
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
**Example:**
gemm = PersistentDenseGemmKernel(
acc_dtype=cutlass.Float32,
use_2cta_instrs=True,
mma_tiler_mn=(128, 128),
cluster_shape_mn=(2, 2)
)
gemm(a, b, c, max_active_clusters, stream)
"""
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    use_tma_store: bool,
):
    """Store the static (input-independent) configuration of the kernel.

    The configuration covers:

    1. MMA instruction settings (tcgen05):
       - ``acc_dtype``: accumulator data type.
       - ``mma_tiler_mn``: (M, N) shape of the MMA instruction tiler. The K
         extent is filled in later by ``_setup_attributes``.
       - ``use_2cta_instrs``: whether the tcgen05 MMA variant with
         ``cta_group=2`` is used.
    2. Cluster shape:
       - ``cluster_shape_mn``: (ClusterM, ClusterN) shape of the CTA cluster.
    3. Output C store mode:
       - ``use_tma_store``: whether C is stored with the Tensor Memory
         Accelerator (TMA) or with plain stores.

    It also fixes the warp-role assignment: warps 0-3 run the epilogue, warp 4
    the MMA, warp 5 the TMA loads and warp 6 the tile scheduler. From that it
    derives the CTA size and the named-barrier ids.

    :param acc_dtype: Data type of the accumulator.
    :type acc_dtype: type[cutlass.Numeric]
    :param use_2cta_instrs: True to use the cta_group=2 MMA variant.
    :type use_2cta_instrs: bool
    :param mma_tiler_mn: (M, N) shape of the MMA instruction.
    :type mma_tiler_mn: Tuple[int, int]
    :param cluster_shape_mn: (ClusterM, ClusterN) shape of the cluster.
    :type cluster_shape_mn: Tuple[int, int]
    :param use_tma_store: Use TMA (True) or plain stores (False) for the output C tensor.
    :type use_tma_store: bool
    """
    # ---- MMA instruction configuration (tcgen05) ----
    self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
    self.use_2cta_instrs = use_2cta_instrs
    self.mma_tiler_mn = mma_tiler_mn
    # K is a placeholder of 1 here; _setup_attributes replaces the whole tuple
    # once the MMA instruction's K extent is known.
    self.mma_tiler = (*mma_tiler_mn, 1)
    if use_2cta_instrs:
        self.cta_group = tcgen05.CtaGroup.TWO
    else:
        self.cta_group = tcgen05.CtaGroup.ONE

    # ---- Cluster, output store and target hardware ----
    self.cluster_shape_mn = cluster_shape_mn
    self.use_tma_store = use_tma_store
    self.arch = "sm_100"
    self.occupancy = 1

    # ---- Warp specialization: role -> warp index within the CTA ----
    self.epilogue_warp_id = (0, 1, 2, 3)
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    self.sched_warp_id = 6
    all_role_warps = (
        self.mma_warp_id,
        self.tma_warp_id,
        self.sched_warp_id,
        *self.epilogue_warp_id,
    )
    # One 32-thread warp per role slot.
    self.threads_per_cta = 32 * len(all_role_warps)

    # ---- Named-barrier ids: epilogue sync, TMEM alloc sync, TMEM dealloc sync ----
    self.epilog_sync_bar_id = 1
    self.tmem_alloc_sync_bar_id = 2
    self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
    """Build the tcgen05 tiled MMA for the current operand configuration.

    It reads ``a_dtype``, ``a_major_mode`` and ``b_major_mode``, which are set
    by ``__call__``. It also reads ``acc_dtype``, ``cta_group`` and the (M, N)
    part of ``mma_tiler``.

    :return: The tiled MMA returned by ``utils.sm100.make_trivial_tiled_mma``.
    """
    # Only the (M, N) extent of the tiler feeds the MMA atom.
    tiler_mn = self.mma_tiler[:2]
    return utils.sm100.make_trivial_tiled_mma(
        self.a_dtype,
        self.a_major_mode,
        self.b_major_mode,
        self.acc_dtype,
        self.cta_group,
        tiler_mn,
    )
def _setup_attributes(self):
    """Derive every kernel attribute that depends on the GEMM inputs.

    This must run after ``__call__`` has recorded the operand dtypes, the A/B
    major modes and the C layout. It populates, in order:

    - ``mma_tiler`` (K set to 4x the MMA instruction K) and
      ``cta_tile_shape_mnk`` (the MMA tile's M divided by the CTAs per MMA atom);
    - ``cluster_layout_vmnk`` and the A/B multicast CTA counts and flags;
    - ``epi_tile``: a computed sub-tile with TMA store, otherwise the whole CTA
      tile in (M, N);
    - ``smem_capacity`` and the ACC / A-B / C / CLC stage counts;
    - the staged A/B shared-memory layouts, plus the staged C layout with TMA
      store (None otherwise);
    - ``num_tmem_alloc_cols``.
    """
    tiled_mma = self._create_tiled_mma()

    # Each k-tile spans four MMA instructions along K.
    k_insts_per_tile = 4
    inst_k = cute.size(tiled_mma.shape_mnk, mode=[2])
    tile_m, tile_n = self.mma_tiler[0], self.mma_tiler[1]
    self.mma_tiler = (tile_m, tile_n, inst_k * k_insts_per_tile)

    # thr_id's size is the number of CTAs sharing one MMA atom; the MMA tile's
    # M extent is split between them.
    ctas_per_mma = cute.size(tiled_mma.thr_id.shape)
    self.cta_tile_shape_mnk = (
        self.mma_tiler[0] // ctas_per_mma,
        self.mma_tiler[1],
        self.mma_tiler[2],
    )

    # Cluster layout in (V, M, N, K) form, V being the CTA index inside one MMA atom.
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*self.cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,),
    )
    vmnk_shape = self.cluster_layout_vmnk.shape
    # A is shared across the cluster's N mode, B across its M mode.
    self.num_mcast_ctas_a = cute.size(vmnk_shape[2])
    self.num_mcast_ctas_b = cute.size(vmnk_shape[1])
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    # Epilogue tile, plus a one-stage C layout used only to size the C pipeline.
    c_smem_layout = None
    if cutlass.const_expr(self.use_tma_store):
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk,
            self.use_2cta_instrs,
            self.c_layout,
            self.c_dtype,
        )
        c_smem_layout = utils.sm100.make_smem_layout_epi(
            self.c_dtype, self.c_layout, self.epi_tile, 1
        )
    else:
        self.epi_tile = self.cta_tile_shape_mnk[:2]

    # Pipeline depths: A/B/C in shared memory, ACC in tensor memory.
    self.smem_capacity = utils.get_smem_capacity_in_bytes()
    self.num_acc_stage, self.num_ab_stage, self.num_c_stage = _compute_stages(
        tiled_mma,
        self.mma_tiler,
        self.a_dtype,
        self.b_dtype,
        self.c_dtype,
        self.smem_capacity,
        self.occupancy,
        self.use_tma_store,
        c_smem_layout,
    )

    # The CLC (tile-scheduler) pipeline is single-buffered.
    self.num_clc_stage = 1
    assert self.num_clc_stage == 1, "Only single-stage CLC pipeline is supported"

    # Multi-stage shared-memory layouts for the operands and (optionally) C.
    self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
        tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
    )
    self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
        tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
    )
    self.c_smem_layout_staged = (
        utils.sm100.make_smem_layout_epi(
            self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
        )
        if self.use_tma_store
        else None
    )

    # Tensor-memory columns needed for all accumulator stages.
    self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
        tiled_mma, self.mma_tiler, self.num_acc_stage, self.arch
    )
@cute.jit
def __call__(
    self,
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
    stream: cuda.CUstream,
    epilogue_op: cutlass.Constexpr = lambda x: x,
):
    """Set up and launch the batched GEMM (C = A x B, with ``epilogue_op``
    applied elementwise to the output).

    Steps:

    - Record the A/B/C element types, the A/B MMA major modes and the C layout,
      then derive the input-dependent attributes via ``_setup_attributes``.
    - Build TMA load atoms/tensors for A and B. When ``use_tma_store`` is set,
      also build a TMA store atom/tensor for C.
    - Compute the tile-scheduler parameters and the grid via ``_compute_grid``.
    - Launch ``kernel`` on ``stream``.

    :param a: Input tensor A
    :type a: cute.Tensor
    :param b: Input tensor B
    :type b: cute.Tensor
    :param c: Output tensor C
    :type c: cute.Tensor
    :param max_active_clusters: Maximum number of active clusters. Not referenced
        by this method; ``_compute_grid`` is called without it.
    :type max_active_clusters: cutlass.Constexpr
    :param stream: CUDA stream the kernel is launched on
    :type stream: cuda.CUstream
    :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
    :type epilogue_op: cutlass.Constexpr
    :raises TypeError: If the element types of A and B differ.

    NOTE(review): an AssertionError for OOB tiles with TMA store disabled was
    previously documented here. No such check is visible in this method —
    confirm where that constraint is enforced.
    """
    # Setup static attributes before smem/grid/tma computation
    self.a_dtype: Type[cutlass.Numeric] = a.element_type
    self.b_dtype: Type[cutlass.Numeric] = b.element_type
    self.c_dtype: Type[cutlass.Numeric] = c.element_type
    self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
    self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
    self.c_layout = utils.LayoutEnum.from_tensor(c)
    # A and B must share one element type (compile-time check)
    if cutlass.const_expr(self.a_dtype != self.b_dtype):
        raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
    tiled_mma = self._create_tiled_mma()
    # Setup attributes that depend on the GEMM inputs (needs the dtypes/layouts above)
    self._setup_attributes()
    # Number of CTAs sharing one MMA atom
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)
    # Setup TMA load for A
    a_op = utils.sm100.cluster_shape_to_tma_atom_A(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    # Single-stage slice of the staged A smem layout
    a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
    tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
        a_op,
        a,
        a_smem_layout,
        self.mma_tiler,
        tiled_mma,
        self.cluster_layout_vmnk.shape,
        # Float32 inputs are loaded with TFloat32 as the internal type
        internal_type=(
            cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
        ),
    )
    # Setup TMA load for B
    b_op = utils.sm100.cluster_shape_to_tma_atom_B(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    # Single-stage slice of the staged B smem layout
    b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
    tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
        b_op,
        b,
        b_smem_layout,
        self.mma_tiler,
        tiled_mma,
        self.cluster_layout_vmnk.shape,
        # Float32 inputs are loaded with TFloat32 as the internal type
        internal_type=(
            cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
        ),
    )
    # Bytes per A/B pipeline stage, scaled by the CTAs per MMA atom; the kernel
    # passes this as the A/B pipeline's tx_count
    a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
    b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
    self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
    # Response size is 4B * 4 elements
    self.num_clc_response_bytes = 16
    # Setup TMA store for C (only when use_tma_store; otherwise both stay None)
    tma_atom_c = None
    tma_tensor_c = None
    if cutlass.const_expr(self.use_tma_store):
        # Per-stage epilogue smem layout (modes 0 and 1 of the staged C layout)
        epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
        tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile
        )
    # Compute tile-scheduler params and grid size
    self.tile_sched_params, grid = self._compute_grid(
        c, self.cta_tile_shape_mnk, self.cluster_shape_mn
    )
    # Launch the kernel on `stream`; the kernel receives the TMA tensor for C
    # when TMA store is on, otherwise the raw C tensor
    self.kernel(
        tiled_mma,
        tma_atom_a,
        tma_tensor_a,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_c,
        tma_tensor_c if self.use_tma_store else c,
        self.cluster_layout_vmnk,
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.c_smem_layout_staged,
        self.epi_tile,
        self.tile_sched_params,
        epilogue_op,
    ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=(*self.cluster_shape_mn, 1),
        stream=stream,
    )
    return
# GPU device kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: Optional[cute.CopyAtom],
mC_mnl: cute.Tensor,
cluster_layout_vmnk: cute.Layout,
a_smem_layout_staged: cute.ComposedLayout,
b_smem_layout_staged: cute.ComposedLayout,
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
epi_tile: cute.Tile,
tile_sched_params: utils.ClcDynamicPersistentTileSchedulerParams,
epilogue_op: cutlass.Constexpr,
):
"""
GPU device kernel performing the Persistent batched GEMM computation.
"""
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
#
# Prefetch tma desc
#
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
if cutlass.const_expr(self.use_tma_store):
cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
#
# Setup cta/thread coordinates
#
# Coords inside cluster
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
is_first_cta_in_cluster = cta_rank_in_cluster == 0
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
# Coord inside cta
tidx, _, _ = cute.arch.thread_idx()
#
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
#
# Define shared storage for kernel
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilogue_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# Initialize clc_pipeline (barrier) and states
clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
cluster_size = cute.size(self.cluster_shape_mn)
num_clc_consumer_threads = 32 * (
1 + cluster_size * (1 + len(self.epilogue_warp_id) + 1)
)
clc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_clc_consumer_threads
)
clc_pipeline = pipeline.PipelineClcFetchAsync.create(
barrier_storage=storage.clc_mbar_ptr.data_ptr(),
num_stages=self.num_clc_stage,
producer_group=clc_pipeline_producer_group,
consumer_group=clc_pipeline_consumer_group,
tx_count=self.num_clc_response_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem_dealloc_barrier = None
if cutlass.const_expr(not self.use_tma_store):
tmem_dealloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_dealloc_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
)
# Cluster arrive after barrier init
pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# Initial clc response pointer
clc_response_ptr = storage.clc_response.data_ptr()
clc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_clc_stage
)
#
# Setup smem tensor A/B/C
#
# (MMA, MMA_M, MMA_K, STAGE)
sA = smem.allocate_tensor(
element_type=self.a_dtype,
layout=a_smem_layout_staged.outer,
byte_alignment=128,
swizzle=a_smem_layout_staged.inner,
)
# (MMA, MMA_N, MMA_K, STAGE)
sB = smem.allocate_tensor(
element_type=self.b_dtype,
layout=b_smem_layout_staged.outer,
byte_alignment=128,
swizzle=b_smem_layout_staged.inner,
)
#
# Compute multicast mask for A/B buffer full
#
a_full_mcast_mask = None
b_full_mcast_mask = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
)
#
# Local_tile partition global tensors
#
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)
#
# Partition global/shared tensor for TMA load A/B
#
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
a_cta_layout,
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
b_cta_layout,
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
#
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
#
# (MMA, MMA_M, MMA_K, STAGE)
tCrA = tiled_mma.make_fragment_A(sA)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
#
# Cluster wait before tensor memory alloc
#
pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
#
# Construct the scheduler
#
tile_sched = utils.ClcDynamicPersistentTileScheduler.create(
tile_sched_params,
cute.arch.block_idx(),
cute.arch.grid_dim(),
clc_response_ptr,
)
work_tile = tile_sched.initial_work_tile_info()
#
# Specialized TMA load warp
#
if warp_idx == self.tma_warp_id:
#
# Persistent tile scheduling loop
#
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), RestK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), RestK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
#
# Tma load loop
#
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
# TMA load A/B
cute.copy(
tma_atom_a,
tAgA_slice[(None, handle.count)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=a_full_mcast_mask,
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, handle.count)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=b_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
#
# Advance to next tile
#
clc_pipeline.consumer_wait(clc_consumer_state)
work_tile = tile_sched.get_current_work()
clc_pipeline.consumer_release(clc_consumer_state)
clc_consumer_state.advance()
#
# Wait A/B buffer empty
#
ab_producer.tail()
#
# Sched warp
#
if warp_idx == self.sched_warp_id and is_first_cta_in_cluster:
#
# Persistent tile scheduling loop
#
clc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.ProducerConsumer, self.num_clc_stage
)
while work_tile.is_valid_tile:
#
# Advance to next tile
#
clc_pipeline.producer_acquire(clc_producer_state)
mbarrier_addr = clc_pipeline.producer_get_barrier(clc_producer_state)
tile_sched.advance_to_next_work(mbarrier_addr)
clc_producer_state.advance()
clc_pipeline.consumer_wait(clc_consumer_state)
work_tile = tile_sched.get_current_work()
clc_pipeline.consumer_release(clc_consumer_state)
clc_consumer_state.advance()
clc_pipeline.producer_tail(clc_producer_state)
#
# Specialized MMA warp
#
if warp_idx == self.mma_warp_id:
#
# Retrieving tensor memory ptr and make accumulator tensor
#
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
#
# Persistent tile scheduling loop
#
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
#
# Wait for accumulator buffer empty
#
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
#
# Reset the ACCUMULATE field for each tile
#
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
#
# Mma mainloop
#
for k_tile in range(k_tile_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
# tCtAcc += tCrA * tCrB
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kblk_crd],
tCrB[kblk_crd],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
handle.release()
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
peek_ab_full_status = cutlass.Boolean(1)