Skip to content

Commit 693ff3b

Browse files
authored
Add support for direct store in epilogue and padding support for wave transfer without transpose (#3465)
- Add support for direct store in the epilogue instead of cshuffle
- Add padding support for wave transfer without transpose
- Add wave transfer with interleaved layout to support direct store
- Enable the new functionality on GEMMs
- Add optional support for the new functionality in grouped convolution forward
- Add some fast instances for grouped convolution forward with the new functionality (proper tuning still needed)
1 parent 5102747 commit 693ff3b

20 files changed

Lines changed: 950 additions & 157 deletions

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
6060
const long_index_t c_batch_offset =
6161
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
6262

63-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
64-
typename GridwiseGemm::EpilogueCShuffle>();
63+
using EpilogueType =
64+
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
65+
GridwiseGemm::UseDirectStore,
66+
typename GridwiseGemm::EpilogueDirectStore,
67+
typename GridwiseGemm::EpilogueCShuffle>::type;
68+
69+
constexpr index_t LDS_size =
70+
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
6571
__shared__ char p_shared[LDS_size];
6672

6773
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8490
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
8591
});
8692

87-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
93+
auto epilogue_args = EpilogueType{};
8894

8995
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9096
p_as_grid_shift,

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4646
std::is_same_v<c_data_type, ck::bhalf_t>)))
4747
{
4848
#endif
49-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
50-
typename GridwiseGemm::EpilogueCShuffle>();
49+
using EpilogueType =
50+
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
51+
GridwiseGemm::UseDirectStore,
52+
typename GridwiseGemm::EpilogueDirectStore,
53+
typename GridwiseGemm::EpilogueCShuffle>::type;
54+
55+
constexpr index_t LDS_size =
56+
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
5157
// The normal approach to batching would be to increase the grid size by just stretching out
5258
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
5359
// functions not directly using the Z dimension for other calculations. As it turns out, k
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8692
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
8793
});
8894

89-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
95+
auto epilogue_args = EpilogueType{};
9096

9197
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9298
p_as_grid_shift,

include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
188188
ComputeTypeA,
189189
ComputeTypeB,
190190
PermuteA,
191-
PermuteB>;
191+
PermuteB,
192+
false, // IsBPreShuffled
193+
false, // ForceThreadTileTransfer
194+
true>; // IsFusedKernel
192195

193196
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
194197
ReducePtrsGlobal,

include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
273273
ComputeTypeA,
274274
ComputeTypeB,
275275
PermuteA,
276-
PermuteB>;
276+
PermuteB,
277+
false,
278+
false,
279+
true>;
277280

278281
// Welford 2nd part kernel
279282
template <typename DoPads, index_t MPerTile, index_t NPerTile>

include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
187187
ComputeTypeA,
188188
ComputeTypeB,
189189
PermuteA,
190-
PermuteB>;
190+
PermuteB,
191+
false,
192+
false,
193+
true>;
191194

192195
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
193196
ReducePtrsGlobal,

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp

Lines changed: 78 additions & 83 deletions
Large diffs are not rendered by default.

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
6666
const CDEElementwiseOperation cde_element_op)
6767
{
6868
#if(defined(__gfx11__) || defined(__gfx12__))
69-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
70-
typename GridwiseGemm::EpilogueCShuffle>();
69+
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
70+
GridwiseGemm::UseDirectStore,
71+
typename GridwiseGemm::EpilogueDirectStore,
72+
typename GridwiseGemm::EpilogueCShuffle>::type;
73+
74+
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
7175
__shared__ uint8_t p_shared[LDS_size];
7276

7377
const auto gemm_desc_ptr =
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
150154
gemm_desc_ptr[group_id].StrideE,
151155
1);
152156

153-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
157+
auto epilogue_args = EpilogueType{};
154158
constexpr TailNumber TailNum = TailNumber::Full;
155159

156160
if(has_main_k_block_loop)

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4141
const index_t group_count)
4242
{
4343
#if(defined(__gfx11__) || defined(__gfx12__))
44-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
45-
typename GridwiseGemm::EpilogueCShuffle>();
44+
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
45+
GridwiseGemm::UseDirectStore,
46+
typename GridwiseGemm::EpilogueDirectStore,
47+
typename GridwiseGemm::EpilogueCShuffle>::type;
48+
49+
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
4650
__shared__ char p_shared[LDS_size];
4751

4852
const index_t block_id = get_block_1d_id();
@@ -89,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8993

9094
auto splitk_batch_offset =
9195
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
92-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
96+
auto epilogue_args = EpilogueType{};
9397

9498
GridwiseGemm::template Run<HasMainKBlockLoop,
9599
CGlobalMemoryDataOperation,
96100
TailNum,
97101
Block2CTileMap,
98-
typename GridwiseGemm::EpilogueCShuffle,
102+
EpilogueType,
99103
1, // Block2CTileMap MBlock index
100104
2 // Block2CTileMap NBlock index
101105
>(static_cast<void*>(p_shared),

include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ struct EpilogueCShuffleBase
5959
1,
6060
CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>;
6161

62+
__device__ static constexpr bool IsLDSNeeded() { return true; }
63+
6264
// *Caution Here repeat is shuffle repeat
6365
__device__ static constexpr auto
6466
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

// Epilogue that stores the accumulator registers directly to the E (output)
// grid tensor, bypassing the LDS-based C-shuffle stage entirely.
// Requires an E layout compatible with the wave-interleaved register mapping
// produced by BlockwiseGemmPipe (see the unmerge transforms in Run()).
template <typename DsDataType,
          typename EDataType,
          typename AccDataType,
          index_t MRepeat,
          index_t NRepeat,
          typename CDEElementwiseOperation,
          typename BlockwiseGemmPipe>
struct EpilogueDirectStore
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    // Direct store never stages results through shared memory.
    __device__ static constexpr bool IsLDSNeeded() { return false; }

    // Writes c_thread_buf (per-thread accumulators) to p_e_grid, applying
    // cde_element_op on the way out. The Ds pointer/descriptor and the LDS
    // pointer parameters are accepted only for interface parity with the
    // CShuffle epilogue and are ignored here (unnamed parameters).
    template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
              typename CThreadBuf,
              typename DsGridPointer,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
    __device__ static void Run(CThreadBuf& c_thread_buf,
                               DsGridPointer,
                               EDataType* p_e_grid,
                               void*,
                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&,
                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               CDEElementwiseOperation& cde_element_op,
                               const index_t& block_m_id,
                               const index_t& block_n_id)
    {
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // Register layout of the C tile held by a single thread.
        constexpr auto thread_desc =
            BlockwiseGemmPipe::
                GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        // Block-level layout; used only to read out the per-block extents.
        constexpr auto block_desc =
            BlockwiseGemmPipe::
                GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

        constexpr auto MWave              = block_desc.GetLength(I1);
        constexpr auto MSubGroup          = block_desc.GetLength(I2);
        constexpr auto NWave              = block_desc.GetLength(I4);
        constexpr auto NThreadPerSubGroup = block_desc.GetLength(I5);
        constexpr auto MAccVgprs          = block_desc.GetLength(I6);

        // (m, n) origin of this thread's accumulators within the block tile.
        const auto c_thread_mtx_on_block =
            BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0);

        // Decompose the flat per-block m coordinate into
        // (MRepeat, MWave, MSubGroup, MAccVgprs).
        const auto m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));

        const auto m_idx =
            m_adaptor.CalculateBottomIndex(make_multi_index(c_thread_mtx_on_block[I0]));

        // Decompose the flat per-block n coordinate into
        // (NRepeat, NWave, NThreadPerSubGroup).
        const auto n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_idx =
            n_adaptor.CalculateBottomIndex(make_multi_index(c_thread_mtx_on_block[I1]));

        // View of the E grid tensor matching the register layout: freeze the
        // block indices and unmerge MPerBlock/NPerBlock. Note the N unmerge
        // places NRepeat innermost (output dims 4, 5, 3) — the interleaved
        // layout this epilogue relies on for contiguous stores.
        const auto e_grid_desc_direct = transform_tensor_descriptor(
            e_grid_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(make_freeze_transform(block_m_id),
                       make_unmerge_transform(make_tuple(Number<MRepeat>{},
                                                         Number<MWave>{},
                                                         Number<MSubGroup>{},
                                                         Number<MAccVgprs>{})),
                       make_freeze_transform(block_n_id),
                       make_unmerge_transform(make_tuple(
                           Number<NWave>{}, Number<NThreadPerSubGroup>{}, Number<NRepeat>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{}));

        // Thread-to-global copy: vectorized along dim 3 (NRepeat) with the
        // elementwise op fused into the store.
        auto direct_copy = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType,
            EDataType,
            decltype(thread_desc),
            decltype(e_grid_desc_direct),
            CDEElementwiseOperation,
            Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
            Sequence<0, 1, 2, 3, 4, 5, 6>,
            3,
            NRepeat, // VectorSize
            EGlobalMemoryDataOperation,
            1,
            false>{e_grid_desc_direct,
                   make_multi_index(m_idx[I0],
                                    m_idx[I1],
                                    m_idx[I2],
                                    n_idx[I0],
                                    n_idx[I1],
                                    n_idx[I2],
                                    m_idx[I3]),
                   cde_element_op};

        direct_copy.Run(thread_desc,
                        make_tuple(I0, I0, I0, I0, I0, I0, I0),
                        c_thread_buf,
                        e_grid_desc_direct,
                        e_grid_buf);
    }
};

} // namespace ck

0 commit comments

Comments (0)