Skip to content

Commit cbc8335

Browse files
authored
Improve XDL to WMMA porting for grouped conv fwd (#3456)
Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
1 parent 2d9c962 commit cbc8335

13 files changed

Lines changed: 226 additions & 133 deletions

experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
3333
constexpr auto FwdConvAlgorithm =
3434
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
3535
.with_thread_block(FwdThreadBlock_64_64x32x32)
36-
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
36+
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
3737
.with_transfer(FwdTransfer_4x16x1)
3838
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
3939
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
2828
constexpr auto FwdConvAlgorithm =
2929
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
3030
.with_thread_block(FwdThreadBlock_256_128x128x32)
31-
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
31+
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
3232
.with_transfer(FwdTransfer_4x64x1)
3333
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
3434
GemmSpecialization::MNKPadding)

experimental/builder/test/test_conv_description.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ struct DefaultAlgorithm
111111
.bk1 = 8,
112112
.m_per_xdl = 16,
113113
.n_per_xdl = 16,
114-
.m_xdl_per_wave = 4,
115-
.n_xdl_per_wave = 4};
114+
.m_xdl_per_wave = 8,
115+
.n_xdl_per_wave = 8};
116116

117117
ckb::test::TransferABC transfer{
118118
.a =
@@ -188,7 +188,7 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
188188
" ├─ Pipeline scheduler: INTRAWAVE\n"
189189
" ├─ Warp Gemm parameters: \n"
190190
" │ ├─ subtile size: 16×16\n"
191-
" │ └─ Number of warp gemm iterations: 4×4\n"
191+
" │ └─ Number of warp gemm iterations: 8×8\n"
192192
" └─ Memory access:\n"
193193
" ├─ A Tile transfer: \n"
194194
" │ ├─ Tile dimensions: 4×256×8×\n"

experimental/builder/test/utils/ckb_conv_test_configs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ constexpr TransferABC FwdTransfer_4x64x1{
6868
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
6969
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
7070
.n_per_wave_per_shuffle = 1,
71-
.scalar_per_vector = 8},
71+
.scalar_per_vector = 4},
7272
},
7373
};
7474

include/ck/tensor_operation/gpu/device/device_base.hpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ template <index_t BlockSize_,
6060
index_t NPerXDL_,
6161
index_t MXdlPerWave_,
6262
bool IsWave64>
63-
static constexpr auto GetNXdlPerWave2()
63+
static constexpr auto GetXdlPerWave2()
6464
{
6565
constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
6666
constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
@@ -84,17 +84,33 @@ static constexpr auto GetNXdlPerWave2()
8484
}
8585
}
8686

87-
#define GET_NXDL_PER_WAVE_IMPL \
88-
template <bool IsWave64> \
89-
static constexpr auto GetNXdlPerWave() \
90-
{ \
91-
return GetNXdlPerWave2<BlockSize, \
92-
MPerBlock, \
93-
NPerBlock, \
94-
MPerXDL, \
95-
NPerXDL, \
96-
MXdlPerWave, \
97-
IsWave64>(); \
87+
#define GET_NXDL_PER_WAVE_IMPL \
88+
template <bool IsWave64> \
89+
static constexpr auto GetNXdlPerWave() \
90+
{ \
91+
return GetXdlPerWave2<BlockSize, \
92+
MPerBlock, \
93+
NPerBlock, \
94+
MPerXDL, \
95+
NPerXDL, \
96+
MXdlPerWave, \
97+
IsWave64>(); \
98+
}
99+
100+
#define GET_MXDL_PER_WAVE_IMPL \
101+
template <bool IsWave64, \
102+
index_t MPerXDLAligned = MPerXDL, \
103+
index_t NPerXDLAligned = NPerXDL, \
104+
index_t NXdlPerWaveAligned = NXdlPerWave> \
105+
static constexpr auto GetMXdlPerWave() \
106+
{ \
107+
return GetXdlPerWave2<BlockSize, \
108+
NPerBlock, \
109+
MPerBlock, \
110+
NPerXDLAligned, \
111+
MPerXDLAligned, \
112+
NXdlPerWaveAligned, \
113+
IsWave64>(); \
98114
}
99115

100116
template <index_t BlockSize_,
@@ -114,14 +130,14 @@ static constexpr auto GetWarpTileConfig()
114130

115131
constexpr auto NXdlPerWave =
116132
IsWave64
117-
? GetNXdlPerWave2<BlockSize_,
118-
MPerBlock_,
119-
NPerBlock_,
120-
MPerXDL_,
121-
NPerXDL_,
122-
MXdlPerWave_,
123-
true>()
124-
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
133+
? GetXdlPerWave2<BlockSize_,
134+
MPerBlock_,
135+
NPerBlock_,
136+
MPerXDL_,
137+
NPerXDL_,
138+
MXdlPerWave_,
139+
true>()
140+
: GetXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
125141

126142
if constexpr(IsWave64 == false && NXdlPerWave != 0)
127143
{

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
190190
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
191191

192192
static constexpr auto MXdlPerWave64 =
193-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
193+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
194194
static constexpr auto MXdlPerWave32 =
195-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
195+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
196196
static constexpr auto I0 = Number<0>{};
197197
static constexpr auto I1 = Number<1>{};
198198
static constexpr auto I2 = Number<2>{};

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -235,20 +235,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
235235
{
236236
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
237237

238-
static constexpr auto Gemm0MXdlPerWave64 = GetNXdlPerWave2<BlockSize,
239-
Gemm0NPerBlock,
240-
Gemm0MPerBlock,
241-
Gemm0NPerXdl,
242-
Gemm0MPerXdl,
243-
Gemm0NXdlPerWave,
244-
true>();
245-
static constexpr auto Gemm0MXdlPerWave32 = GetNXdlPerWave2<BlockSize,
246-
Gemm0NPerBlock,
247-
Gemm0MPerBlock,
248-
Gemm0NPerXdl,
249-
Gemm0MPerXdl,
250-
Gemm0NXdlPerWave,
251-
false>();
238+
static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2<BlockSize,
239+
Gemm0NPerBlock,
240+
Gemm0MPerBlock,
241+
Gemm0NPerXdl,
242+
Gemm0MPerXdl,
243+
Gemm0NXdlPerWave,
244+
true>();
245+
static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2<BlockSize,
246+
Gemm0NPerBlock,
247+
Gemm0MPerBlock,
248+
Gemm0NPerXdl,
249+
Gemm0MPerXdl,
250+
Gemm0NXdlPerWave,
251+
false>();
252252

253253
static constexpr index_t NumD0Tensor = D0sDataType::Size();
254254
static constexpr index_t NumD1Tensor = D1sDataType::Size();

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
223223
MaskingSpec>
224224
{
225225
static constexpr auto MXdlPerWave64 =
226-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
226+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
227227
static constexpr auto MXdlPerWave32 =
228-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
228+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
229229

230230
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
231231
"Number of dimension must be greater than 0");

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
211211

212212
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
213213
static constexpr auto MXdlPerWave64 =
214-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
214+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
215215
static constexpr auto MXdlPerWave32 =
216-
GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
216+
GetXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
217217

218218
static constexpr auto I0 = Number<0>{};
219219
static constexpr auto I1 = Number<1>{};

0 commit comments

Comments
 (0)