Skip to content

Commit c0797c1

Browse files
jakpiase and bartekxk authored
[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)
* fix for splitk if splitk < grid
* add different splitk implementation
* minor bugfix for streamk gemm
* Add test

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
1 parent e1381d6 commit c0797c1

3 files changed

Lines changed: 80 additions & 13 deletions

File tree

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -323,22 +323,38 @@ struct UniversalGemmKernel
323323

324324
struct SplitKBatchOffset
325325
{
326-
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
326+
// This structure distributes work evenly among splitk workgroups
327+
// It's based on a principle that if there is enough work to fill all workgroups,
328+
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
329+
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1 iterations
330+
// and leave the potential tail for the last workgroup (the one with index splitk - 1).
331+
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
327332
{
328-
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
329-
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
330-
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
333+
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
334+
const index_t num_all = amd_wave_read_first_lane(
335+
kargs.K / K1); // num of all loops not including potential tail
336+
index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
337+
num_full = num_full == 0 ? kargs.k_batch : num_full;
338+
339+
const index_t num_full_iters =
340+
amd_wave_read_first_lane(std::max(integer_divide_ceil(num_all, kargs.k_batch), 1));
341+
const index_t full_k_read = num_full_iters * K1;
342+
const index_t partial_k_read = (num_full_iters - 1) * K1;
331343

332344
static_for<0, NumATensor, 1>{}([&](auto index) {
333345
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
334346
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
335347
{
336-
as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
348+
as_k_split_offset[index] =
349+
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
350+
std::max(k_id - num_full, 0) * partial_k_read);
337351
}
338352
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
339353
{
340354
as_k_split_offset[index] =
341-
amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
355+
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
356+
std::max(k_id - num_full, 0) * partial_k_read) *
357+
kargs.stride_As[index]);
342358
}
343359
});
344360

@@ -347,21 +363,30 @@ struct UniversalGemmKernel
347363
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
348364
{
349365
bs_k_split_offset[index] =
350-
amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
366+
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
367+
std::max(k_id - num_full, 0) * partial_k_read) *
368+
kargs.stride_Bs[index]);
351369
}
352370
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
353371
{
354-
bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
372+
bs_k_split_offset[index] =
373+
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
374+
std::max(k_id - num_full, 0) * partial_k_read);
355375
}
356376
});
357377

358-
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
378+
if(k_id == kargs.k_batch - 1)
379+
{
380+
splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
381+
std::max(k_id - num_full, 0) * partial_k_read;
382+
}
383+
else if(k_id < num_full)
359384
{
360-
splitted_k = amd_wave_read_first_lane(KRead);
385+
splitted_k = full_k_read;
361386
}
362387
else
363388
{
364-
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
389+
splitted_k = partial_k_read;
365390
}
366391
}
367392

@@ -385,6 +410,15 @@ struct UniversalGemmKernel
385410
}
386411
}
387412

413+
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
414+
{
415+
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
416+
{
417+
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
418+
}
419+
return false;
420+
}
421+
388422
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
389423
: GemmPipeline::template GetVectorSizeA<false>();
390424
bool AsTesnorIsValid = {true};

include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,15 @@ struct GroupedConvolutionBackwardWeightKernel
568568
}
569569
}
570570

571+
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
572+
{
573+
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
574+
{
575+
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
576+
}
577+
return false;
578+
}
579+
571580
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
572581
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
573582

test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
173173
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
174174
}
175175

176+
static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
177+
{
178+
return create_2d_host_args(2, 2, 8, 8, 3, 3, 70, 70, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
179+
}
180+
176181
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
177182
{
178183
};
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
227232
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
228233
}
229234

235+
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
236+
{
237+
using Kernel = typename BuildKernel<half_t,
238+
TestConvConfig,
239+
tensor_layout::convolution::NHWGC,
240+
tensor_layout::convolution::GKYXC,
241+
tensor_layout::convolution::NHWGK>::type;
242+
243+
// k_batch = 6 should pass
244+
auto host_args_kbatch_6 = create_2d_host_args(6);
245+
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
246+
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
247+
248+
// k_batch = 7 should fail (K / k_batch limitation)
249+
auto host_args_kbatch_7 = create_2d_host_args(7);
250+
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
251+
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
252+
}
253+
230254
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
231255
{
232256
using Kernel = typename BuildKernel<half_t,
@@ -236,13 +260,13 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
236260
tensor_layout::convolution::NHWGK>::type;
237261

238262
// k_batch = 128 should pass
239-
auto host_args_kbatch_128 = create_2d_host_args(128);
263+
auto host_args_kbatch_128 = create_large_2d_host_args(128);
240264
auto kargs_128 =
241265
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
242266
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
243267

244268
// k_batch = 129 should fail for half_t output
245-
auto host_args_kbatch_129 = create_2d_host_args(129);
269+
auto host_args_kbatch_129 = create_large_2d_host_args(129);
246270
auto kargs_129 =
247271
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
248272
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));

0 commit comments

Comments
 (0)