@@ -98,19 +98,23 @@ epilogue_predication(ThrMMA<Args...> const& thr_mma,
9898 }
9999}
100100
101- template <class Alpha , class TRC , class RCLayout ,
101+ template <class ... Args,
102+ class Alpha , class TRC , class RCLayout ,
102103 class Beta , class TSC , class SCLayout ,
103104 class CLoadTransformOp , class CStoreTransformOp ,
104- class SmemCopyOpC >
105+ class SmemCopyLdOpC , class SmemCopyStOpC >
105106CUTE_HOST_DEVICE
106107void
107- epilogue_no_predication (Alpha const & alpha,
108+ epilogue_no_predication (uint32_t thread_idx,
109+ ThrMMA<Args...> const & thr_mma,
110+ Alpha const & alpha,
108111 Tensor<TRC , RCLayout> & tCrC,
109112 Beta const & beta,
110- Tensor<TSC , SCLayout> & tCsC ,
113+ Tensor<TSC , SCLayout> & sC ,
111114 CLoadTransformOp const & sC_load_op , // transforms C values before use in GEMM
112115 CStoreTransformOp const & sC_store_op , // transforms results before they are stored to C
113- SmemCopyOpC const & sC_copy_op )
116+ SmemCopyLdOpC const & sC_copy_ld_op ,
117+ SmemCopyStOpC const & sC_copy_st_op )
114118{
115119 using InputTypeC = typename TSC ::value_type;
116120 using ComputeTypeC = typename TRC ::value_type;
@@ -125,18 +129,33 @@ epilogue_no_predication(Alpha const& alpha,
125129 CUTE_GCC_UNREACHABLE ;
126130 } ();
127131
128- Tensor tCrDi = make_fragment_like (tCsC);
129132 Tensor tCrD = make_fragment_like (tCrC);
133+ Tensor tCrDi = make_fragment_like<InputTypeC>(tCrD);
134+
130135 if (!isBetaZero) {
131- copy (sC_copy_op , tCsC, tCrDi);
136+ auto smem_tiled_copy_C = make_tiled_copy_C (Copy_Atom<SmemCopyLdOpC, InputTypeC>{}, thr_mma);
137+ auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice (thread_idx);
138+ Tensor tCsC = smem_thr_copy_C.partition_S (sC );
139+ Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D (tCrDi);
140+ CUTE_STATIC_ASSERT_V (size<1 >(tCsC) == size<1 >(tCrDi_copy_view)); // CPY_M
141+ CUTE_STATIC_ASSERT_V (size<2 >(tCsC) == size<2 >(tCrDi_copy_view)); // CPY_N
142+ copy (smem_tiled_copy_C, tCsC, tCrDi_copy_view);
143+
132144 // Transform C on/after load
133145 cute::transform (tCrDi, tCrD, sC_load_op );
134146 }
135147 // C = alpha * (A * B) + beta * C
136148 axpby (alpha, tCrC, beta, tCrD);
137149 // Transform C before/on store
138150 cute::transform (tCrD, tCrDi, sC_store_op );
139- copy (sC_copy_op , tCrDi, tCsC);
151+
152+ auto smem_tiled_copy_C = make_tiled_copy_C (Copy_Atom<SmemCopyStOpC, InputTypeC>{}, thr_mma);
153+ auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice (thread_idx);
154+ Tensor tCsC = smem_thr_copy_C.partition_D (sC );
155+ Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S (tCrDi);
156+ CUTE_STATIC_ASSERT_V (size<1 >(tCsC) == size<1 >(tCrDi_copy_view)); // CPY_M
157+ CUTE_STATIC_ASSERT_V (size<2 >(tCsC) == size<2 >(tCrDi_copy_view)); // CPY_N
158+ copy (smem_tiled_copy_C, tCrDi_copy_view, tCsC);
140159}
141160
142161// Predicated Cooperative GEMM
@@ -283,23 +302,23 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
283302
284303 // Create register tensors for the MMA to operate on
285304 Tensor tCrA = thr_mma.partition_fragment_A (sA ); // (MMA,MMA_M,MMA_K)
305+ Tensor tCrAi = make_fragment_like<InputTypeA>(tCrA);
286306 Tensor tCrB = thr_mma.partition_fragment_B (sB ); // (MMA,MMA_N,MMA_K)
307+ Tensor tCrBi = make_fragment_like<InputTypeB>(tCrB);
287308
288309 using CopyOpAType = SmemCopyOpA;
289310 using CopyOpBType = SmemCopyOpB;
290311
291312 auto smem_tiled_copy_A = make_tiled_copy_A (Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
292313 auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice (thread_idx);
293314 Tensor tCsA = smem_thr_copy_A.partition_S (sA );
294- Tensor tCrAi = make_fragment_like (tCsA);
295315 Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D (tCrAi);
296316 CUTE_STATIC_ASSERT_V (size<1 >(tCsA) == size<1 >(tCrAi_copy_view)); // CPY_M
297317 CUTE_STATIC_ASSERT_V (size<2 >(tCsA) == size<2 >(tCrAi_copy_view)); // CPY_K
298318
299319 auto smem_tiled_copy_B = make_tiled_copy_B (Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
300320 auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice (thread_idx);
301321 Tensor tCsB = smem_thr_copy_B.partition_S (sB );
302- Tensor tCrBi = make_fragment_like (tCsB);
303322 Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D (tCrBi);
304323 CUTE_STATIC_ASSERT_V (size<1 >(tCsB) == size<1 >(tCrBi_copy_view)); // CPY_N
305324 CUTE_STATIC_ASSERT_V (size<2 >(tCsB) == size<2 >(tCrBi_copy_view)); // CPY_K
@@ -346,7 +365,7 @@ template <class... Args,
346365 class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
347366 class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
348367 class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
349- class SmemCopyOpC = DefaultCopy>
368+ class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
350369CUTE_HOST_DEVICE
351370void
352371cooperative_gemm (uint32_t thread_idx,
@@ -356,13 +375,14 @@ cooperative_gemm(uint32_t thread_idx,
356375 Tensor<TB , BLayout> const & sB ,
357376 Beta const & beta,
358377 Tensor<TC , CLayout> & sC ,
359- ALoadTransformOp const & sA_load_op = {}, // transforms A values before use in GEMM
360- BLoadTransformOp const & sB_load_op = {}, // transforms B values before use in GEMM
361- CLoadTransformOp const & sC_load_op = {}, // transforms C values before use in GEMM
362- CStoreTransformOp const & sC_store_op = {}, // transforms results before they are stored to C
363- SmemCopyOpA const & sA_copy_op = {},
364- SmemCopyOpB const & sB_copy_op = {},
365- SmemCopyOpC const & sC_copy_op = {})
378+ ALoadTransformOp const & sA_load_op = {}, // transforms A values before use in GEMM
379+ BLoadTransformOp const & sB_load_op = {}, // transforms B values before use in GEMM
380+ CLoadTransformOp const & sC_load_op = {}, // transforms C values before use in GEMM
381+ CStoreTransformOp const & sC_store_op = {}, // transforms results before they are stored to C
382+ SmemCopyOpA const & sA_copy_op = {},
383+ SmemCopyOpB const & sB_copy_op = {},
384+ SmemCopyLdOpC const & sC_copy_ld_op = {},
385+ SmemCopyStOpC const & sC_copy_st_op = {})
366386{
367387 CUTE_STATIC_ASSERT_V (rank (sA ) == Int<2 >{});
368388 CUTE_STATIC_ASSERT_V (rank (sB ) == Int<2 >{});
@@ -394,7 +414,7 @@ cooperative_gemm(uint32_t thread_idx,
394414 thread_idx, thr_mma, sA , sB , tCrC, sA_load_op , sB_load_op , sA_copy_op , sB_copy_op
395415 );
396416 detail::epilogue_no_predication (
397- alpha, tCrC, beta, tCsC , sC_load_op , sC_store_op , sC_copy_op
417+ thread_idx, thr_mma, alpha, tCrC, beta, sC , sC_load_op , sC_store_op , sC_copy_ld_op , sC_copy_st_op
398418 );
399419 } else {
400420 detail::cooperative_gemm_predication (
@@ -466,7 +486,7 @@ template <class... Args,
466486 class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
467487 class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
468488 class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
469- class SmemCopyOpC = DefaultCopy>
489+ class SmemCopyLdOpC = DefaultCopy, class SmemCopyStOpC = DefaultCopy>
470490CUTE_HOST_DEVICE
471491void
472492cooperative_gemm (uint32_t thread_idx,
@@ -476,17 +496,18 @@ cooperative_gemm(uint32_t thread_idx,
476496 Tensor<TB , BLayout> const & sB ,
477497 Beta const & beta,
478498 Tensor<TC , CLayout> && sC ,
479- ALoadTransformOp const & sA_load_op = {}, // transforms A values before use in GEMM
480- BLoadTransformOp const & sB_load_op = {}, // transforms B values before use in GEMM
481- CLoadTransformOp const & sC_load_op = {}, // transforms C values before use in GEMM
482- CStoreTransformOp const & sC_store_op = {}, // transforms results before they are stored to C
483- SmemCopyOpA const & sA_copy_op = {},
484- SmemCopyOpB const & sB_copy_op = {},
485- SmemCopyOpC const & sC_copy_op = {})
499+ ALoadTransformOp const & sA_load_op = {}, // transforms A values before use in GEMM
500+ BLoadTransformOp const & sB_load_op = {}, // transforms B values before use in GEMM
501+ CLoadTransformOp const & sC_load_op = {}, // transforms C values before use in GEMM
502+ CStoreTransformOp const & sC_store_op = {}, // transforms results before they are stored to C
503+ SmemCopyOpA const & sA_copy_op = {},
504+ SmemCopyOpB const & sB_copy_op = {},
505+ SmemCopyLdOpC const & sC_copy_ld_op = {},
506+ SmemCopyStOpC const & sC_copy_st_op = {})
486507{
487508 cooperative_gemm (thread_idx, tiled_mma, alpha, sA , sB , beta, sC ,
488509 sA_load_op , sB_load_op , sC_load_op , sC_store_op ,
489- sA_copy_op , sB_copy_op , sC_copy_op );
510+ sA_copy_op , sB_copy_op , sC_copy_ld_op , sC_copy_st_op );
490511}
491512
492513// Legacy overload of cute::gemm for backwards-compatibility
0 commit comments