Skip to content

Commit ead81d1

Browse files
authored
[CK_TILE] Add splitk support to ck tile conv bwd data (#3353)
* add splitk support to ck tile conv bwd data * add reviewers' suggestions * minor fix * removed SplitKBatchOffset struct
1 parent 8b73633 commit ead81d1

2 files changed

Lines changed: 57 additions & 55 deletions

File tree

include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,6 @@ struct GroupedConvolutionBackwardDataKernel
542542
static constexpr index_t MaxGroupedGemmGroupsNum =
543543
GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;
544544

545-
// TODO: Enable this
546-
static constexpr bool IsSplitKSupported = false;
547-
548545
static constexpr auto I0 = number<0>();
549546
static constexpr auto I1 = number<1>();
550547
static constexpr auto I2 = number<2>();
@@ -623,9 +620,8 @@ struct GroupedConvolutionBackwardDataKernel
623620
CK_TILE_HOST static bool
624621
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
625622
{
626-
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
627-
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
628-
!IsSplitKSupported)
623+
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
624+
is_any_of<OutDataType, fp16_t, bf16_t>::value)
629625
{
630626
if(kargs.k_batch != 1)
631627
{
@@ -772,8 +768,8 @@ struct GroupedConvolutionBackwardDataKernel
772768
}();
773769

774770
const auto& c_tensor_view = [&]() {
775-
return make_tensor_view<address_space_enum::global>(c_ptr,
776-
kargs.c_grid_descs_m_n[group_id]);
771+
return make_tensor_view<address_space_enum::global, DstInMemOp>(
772+
c_ptr, kargs.c_grid_descs_m_n[group_id]);
777773
}();
778774

779775
const auto& ds_tensor_view = generate_tuple(
@@ -837,7 +833,7 @@ struct GroupedConvolutionBackwardDataKernel
837833
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
838834
const index_t i_m,
839835
const index_t i_n,
840-
const index_t i_k = 0)
836+
const index_t i_k)
841837
{
842838
const auto& a_pad_view = views.at(I0);
843839
const auto& b_pad_view = views.at(I1);
@@ -893,28 +889,32 @@ struct GroupedConvolutionBackwardDataKernel
893889
WeiDataType* c_ptr,
894890
void* smem_ptr_0,
895891
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
892+
const index_t splitted_k,
896893
const index_t block_idx_m,
897894
const index_t block_idx_n,
895+
const index_t block_idx_k,
898896
const index_t group_id)
899897
{
900898
// Create Gemm tensor views, pad views and tile windows
901899
const auto& gemm_tensor_views_tuple =
902900
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
903901
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
904-
905902
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
906-
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
907903

908-
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
909-
gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
904+
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
905+
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
906+
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
907+
908+
auto gemm_tile_windows =
909+
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
910910

911911
// Run GEMM cooperatively by whole workgroup.
912912
const auto& a_block_window = gemm_tile_windows.at(I0);
913913
const auto& b_block_window = gemm_tile_windows.at(I1);
914914
const auto& d_block_window = gemm_tile_windows.at(I2);
915915

916916
const auto& c_block_tile = GemmPipeline{}.template operator()(
917-
a_block_window, b_block_window, num_loop, smem_ptr_0);
917+
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
918918

919919
// Run Epilogue Pipeline
920920
auto& c_block_window = gemm_tile_windows.at(I3);
@@ -945,27 +945,36 @@ struct GroupedConvolutionBackwardDataKernel
945945
void* __restrict__ smem_ptr_0,
946946
void* __restrict__ smem_ptr_1,
947947
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
948+
const index_t splitted_k,
948949
const index_t block_idx_m,
949950
const index_t block_idx_n,
951+
const index_t block_idx_k,
950952
const index_t group_id)
951953
{
952954
// Create Gemm tensor views, pad views and tile windows
953955
const auto& gemm_tensor_views_tuple =
954956
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
955957
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
956958
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
957-
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
958959

959-
const index_t num_loop = amd_wave_read_first_lane(
960-
TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
960+
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
961+
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
962+
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
963+
auto gemm_tile_windows =
964+
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
961965

962966
// Run GEMM cooperatively by whole workgroup.
963967
const auto& a_block_window = gemm_tile_windows.at(I0);
964968
const auto& b_block_window = gemm_tile_windows.at(I1);
965969
const auto& d_block_window = gemm_tile_windows.at(I2);
966970

967-
const auto& c_block_tile = GemmPipeline{}.template operator()(
968-
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
971+
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
972+
b_block_window,
973+
num_loop,
974+
has_hot_loop,
975+
tail_num,
976+
smem_ptr_0,
977+
smem_ptr_1);
969978

970979
// Run Epilogue Pipeline
971980
auto& c_block_window = gemm_tile_windows.at(I3);
@@ -1031,9 +1040,17 @@ struct GroupedConvolutionBackwardDataKernel
10311040
static_cast<long_index_t>(kargs.input_batch_stride);
10321041

10331042
// SplitK
1034-
// TODO: Implement SplitK support
1035-
// const index_t split_k_idx =
1036-
// __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1043+
const index_t split_k_idx =
1044+
__builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1045+
1046+
const index_t gemm_k = kargs.a_grid_descs_m_k[group_id].get_length(I1);
1047+
1048+
constexpr auto K1 = TilePartitioner::KPerBlock;
1049+
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
1050+
const index_t KRead = amd_wave_read_first_lane((gemm_k + K_t - 1) / K_t * K1);
1051+
1052+
const index_t i_k = amd_wave_read_first_lane(split_k_idx * KRead);
1053+
const index_t splitted_k = amd_wave_read_first_lane(KRead);
10371054

10381055
// options
10391056
// conv_bwd_data = Out * Weight = In
@@ -1060,8 +1077,10 @@ struct GroupedConvolutionBackwardDataKernel
10601077
smem_ptr_0,
10611078
smem_ptr_1,
10621079
kargs,
1080+
splitted_k,
10631081
i_m,
10641082
i_n,
1083+
i_k,
10651084
group_id);
10661085
}
10671086
}
@@ -1071,7 +1090,17 @@ struct GroupedConvolutionBackwardDataKernel
10711090
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
10721091
is_any_of<OutDataType, fp16_t, bf16_t>::value))
10731092
{
1074-
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
1093+
RunGemm(a_ptr,
1094+
b_ptr,
1095+
kargs.ds_ptr,
1096+
c_ptr,
1097+
smem_ptr_0,
1098+
kargs,
1099+
splitted_k,
1100+
i_m,
1101+
i_n,
1102+
i_k,
1103+
group_id);
10751104
}
10761105
}
10771106
}

include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -505,33 +505,6 @@ struct GroupedConvolutionBackwardWeightKernel
505505
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
506506
}
507507

508-
struct SplitKBatchOffset
509-
{
510-
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
511-
const std::size_t k_id = blockIdx.z)
512-
{
513-
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
514-
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
515-
const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
516-
517-
a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
518-
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
519-
520-
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
521-
{
522-
splitted_k = amd_wave_read_first_lane(KRead);
523-
}
524-
else
525-
{
526-
splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
527-
}
528-
}
529-
530-
index_t a_k_split_offset;
531-
index_t b_k_split_offset;
532-
index_t splitted_k;
533-
};
534-
535508
CK_TILE_HOST static bool
536509
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
537510
{
@@ -763,20 +736,20 @@ struct GroupedConvolutionBackwardWeightKernel
763736
}
764737

765738
template <typename TensorView>
766-
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
739+
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
767740
{
768741
const auto& a_pad_view = [&]() {
769742
const auto& a_tensor_view = views.at(I0);
770743
return pad_tensor_view(a_tensor_view,
771-
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
744+
make_tuple(number<TilePartitioner::KPerBlock>{},
772745
number<TilePartitioner::MPerBlock>{}),
773746
sequence<true, true>{});
774747
}();
775748

776749
const auto& b_pad_view = [&]() {
777750
const auto& b_tensor_view = views.at(I1);
778751
return pad_tensor_view(b_tensor_view,
779-
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
752+
make_tuple(number<TilePartitioner::KPerBlock>{},
780753
number<TilePartitioner::NPerBlock>{}),
781754
sequence<true, true>{});
782755
}();
@@ -882,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
882855
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
883856
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
884857

885-
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
858+
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
886859
auto gemm_tile_windows =
887860
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
888861

@@ -932,7 +905,7 @@ struct GroupedConvolutionBackwardWeightKernel
932905
const auto& gemm_tensor_views_tuple =
933906
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
934907
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
935-
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
908+
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
936909
auto gemm_tile_windows =
937910
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
938911

0 commit comments

Comments
 (0)