@@ -542,9 +542,6 @@ struct GroupedConvolutionBackwardDataKernel
542542 static constexpr index_t MaxGroupedGemmGroupsNum =
543543 GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;
544544
545- // TODO: Enable this
546- static constexpr bool IsSplitKSupported = false ;
547-
548545 static constexpr auto I0 = number<0 >();
549546 static constexpr auto I1 = number<1 >();
550547 static constexpr auto I2 = number<2 >();
@@ -623,9 +620,8 @@ struct GroupedConvolutionBackwardDataKernel
623620 CK_TILE_HOST static bool
624621 IsSupportedArgument (const GroupedConvBwdDataKernelArgsSpecialized& kargs)
625622 {
626- if constexpr ((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
627- is_any_of<OutDataType, fp16_t , bf16_t >::value) ||
628- !IsSplitKSupported)
623+ if constexpr (GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
624+ is_any_of<OutDataType, fp16_t , bf16_t >::value)
629625 {
630626 if (kargs.k_batch != 1 )
631627 {
@@ -772,8 +768,8 @@ struct GroupedConvolutionBackwardDataKernel
772768 }();
773769
774770 const auto & c_tensor_view = [&]() {
775- return make_tensor_view<address_space_enum::global>(c_ptr,
776- kargs.c_grid_descs_m_n [group_id]);
771+ return make_tensor_view<address_space_enum::global, DstInMemOp>(
772+ c_ptr, kargs.c_grid_descs_m_n [group_id]);
777773 }();
778774
779775 const auto & ds_tensor_view = generate_tuple (
@@ -837,7 +833,7 @@ struct GroupedConvolutionBackwardDataKernel
837833 CK_TILE_DEVICE static auto MakeGemmTileWindows (const PadView& views,
838834 const index_t i_m,
839835 const index_t i_n,
840- const index_t i_k = 0 )
836+ const index_t i_k)
841837 {
842838 const auto & a_pad_view = views.at (I0 );
843839 const auto & b_pad_view = views.at (I1 );
@@ -893,28 +889,32 @@ struct GroupedConvolutionBackwardDataKernel
893889 WeiDataType* c_ptr,
894890 void * smem_ptr_0,
895891 const GroupedConvBwdDataKernelArgsSpecialized& kargs,
892+ const index_t splitted_k,
896893 const index_t block_idx_m,
897894 const index_t block_idx_n,
895+ const index_t block_idx_k,
898896 const index_t group_id)
899897 {
900898 // Create Gemm tensor views, pad views and tile windows
901899 const auto & gemm_tensor_views_tuple =
902900 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
903901 a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
904-
905902 const auto & gemm_pad_views = MakeGemmPadViews (gemm_tensor_views_tuple);
906- auto gemm_tile_windows = MakeGemmTileWindows (gemm_pad_views, block_idx_m, block_idx_n);
907903
908- const index_t num_loop = amd_wave_read_first_lane (TilePartitioner::GetLoopNum (
909- gemm_pad_views.at (I0 ).get_tensor_descriptor ().get_length (I1 )));
904+ const index_t num_loop = amd_wave_read_first_lane (TilePartitioner::GetLoopNum (splitted_k));
905+ const bool has_hot_loop = GemmPipeline::BlockHasHotloop (num_loop);
906+ const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum (num_loop);
907+
908+ auto gemm_tile_windows =
909+ MakeGemmTileWindows (gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
910910
911911 // Run GEMM cooperatively by whole workgroup.
912912 const auto & a_block_window = gemm_tile_windows.at (I0 );
913913 const auto & b_block_window = gemm_tile_windows.at (I1 );
914914 const auto & d_block_window = gemm_tile_windows.at (I2 );
915915
916916 const auto & c_block_tile = GemmPipeline{}.template operator ()(
917- a_block_window, b_block_window, num_loop, smem_ptr_0);
917+ a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
918918
919919 // Run Epilogue Pipeline
920920 auto & c_block_window = gemm_tile_windows.at (I3 );
@@ -945,27 +945,36 @@ struct GroupedConvolutionBackwardDataKernel
945945 void * __restrict__ smem_ptr_0,
946946 void * __restrict__ smem_ptr_1,
947947 const GroupedConvBwdDataKernelArgsSpecialized& kargs,
948+ const index_t splitted_k,
948949 const index_t block_idx_m,
949950 const index_t block_idx_n,
951+ const index_t block_idx_k,
950952 const index_t group_id)
951953 {
952954 // Create Gemm tensor views, pad views and tile windows
953955 const auto & gemm_tensor_views_tuple =
954956 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
955957 a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
956958 const auto & gemm_pad_views = MakeGemmPadViews (gemm_tensor_views_tuple);
957- auto gemm_tile_windows = MakeGemmTileWindows (gemm_pad_views, block_idx_m, block_idx_n);
958959
959- const index_t num_loop = amd_wave_read_first_lane (
960- TilePartitioner::GetLoopNum (gemm_tile_windows.at (I0 ).get_length (I1 )));
960+ const index_t num_loop = amd_wave_read_first_lane (TilePartitioner::GetLoopNum (splitted_k));
961+ const bool has_hot_loop = GemmPipeline::BlockHasHotloop (num_loop);
962+ const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum (num_loop);
963+ auto gemm_tile_windows =
964+ MakeGemmTileWindows (gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
961965
962966 // Run GEMM cooperatively by whole workgroup.
963967 const auto & a_block_window = gemm_tile_windows.at (I0 );
964968 const auto & b_block_window = gemm_tile_windows.at (I1 );
965969 const auto & d_block_window = gemm_tile_windows.at (I2 );
966970
967- const auto & c_block_tile = GemmPipeline{}.template operator ()(
968- a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
971+ const auto & c_block_tile = GemmPipeline{}.template operator ()(a_block_window,
972+ b_block_window,
973+ num_loop,
974+ has_hot_loop,
975+ tail_num,
976+ smem_ptr_0,
977+ smem_ptr_1);
969978
970979 // Run Epilogue Pipeline
971980 auto & c_block_window = gemm_tile_windows.at (I3 );
@@ -1031,9 +1040,17 @@ struct GroupedConvolutionBackwardDataKernel
10311040 static_cast <long_index_t >(kargs.input_batch_stride );
10321041
10331042 // SplitK
1034- // TODO: Implement SplitK support
1035- // const index_t split_k_idx =
1036- // __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1043+ const index_t split_k_idx =
1044+ __builtin_amdgcn_readfirstlane (blockIdZ - split_n_idx * kargs.k_batch );
1045+
1046+ const index_t gemm_k = kargs.a_grid_descs_m_k [group_id].get_length (I1 );
1047+
1048+ constexpr auto K1 = TilePartitioner::KPerBlock;
1049+ const index_t K_t = amd_wave_read_first_lane (kargs.k_batch * K1 );
1050+ const index_t KRead = amd_wave_read_first_lane ((gemm_k + K_t - 1 ) / K_t * K1 );
1051+
1052+ const index_t i_k = amd_wave_read_first_lane (split_k_idx * KRead);
1053+ const index_t splitted_k = amd_wave_read_first_lane (KRead);
10371054
10381055 // options
10391056 // conv_bwd_data = Out * Weight = In
@@ -1060,8 +1077,10 @@ struct GroupedConvolutionBackwardDataKernel
10601077 smem_ptr_0,
10611078 smem_ptr_1,
10621079 kargs,
1080+ splitted_k,
10631081 i_m,
10641082 i_n,
1083+ i_k,
10651084 group_id);
10661085 }
10671086 }
@@ -1071,7 +1090,17 @@ struct GroupedConvolutionBackwardDataKernel
10711090 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
10721091 is_any_of<OutDataType, fp16_t , bf16_t >::value))
10731092 {
1074- RunGemm (a_ptr, b_ptr, kargs.ds_ptr , c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
1093+ RunGemm (a_ptr,
1094+ b_ptr,
1095+ kargs.ds_ptr ,
1096+ c_ptr,
1097+ smem_ptr_0,
1098+ kargs,
1099+ splitted_k,
1100+ i_m,
1101+ i_n,
1102+ i_k,
1103+ group_id);
10751104 }
10761105 }
10771106 }
0 commit comments