Skip to content

Commit f799260

Browse files
bartekxkassistant-librarian[bot]
authored andcommitted
[rocm-libraries] ROCm/rocm-libraries#5555 (commit 1d2c4c8)
[CK][CK Tile] Fix kbatch check in grouped conv and gemm kernels (#5555) ## Motivation Fix kbatch check in grouped conv and gemm kernels, allow tails for kbatch. ## Technical Details Round up K / Kperxdl and divide it by Kbatch to allow tail for K. ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 6b69ac9 commit f799260

4 files changed

Lines changed: 10 additions & 7 deletions

File tree

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,8 @@ struct UniversalGemmKernel
418418
}
419419
}
420420

421-
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
421+
if(integer_divide_ceil(kargs.K, GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})) <
422+
kargs.k_batch)
422423
{
423424
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
424425
{

include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,9 @@ struct GroupedConvolutionBackwardWeightKernel
574574
}
575575
}
576576

577-
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
577+
if(integer_divide_ceil(kargs.GemmK,
578+
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
579+
kargs.k_batch)
578580
{
579581
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
580582
kargs.GemmK,

profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,11 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
178178
});
179179

180180
const bool valid = report.get_errors().empty();
181+
best_avg_time = std::min(best_avg_time, avg_time);
182+
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
183+
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
181184
if(valid)
182185
{
183-
best_avg_time = std::min(best_avg_time, avg_time);
184-
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
185-
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
186186
std::cout << "[Valid] Perf: " << std::setw(10) << avg_time << " ms," << " "
187187
<< op_name << ", SplitK " << k_batch << std::endl;
188188
}

test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,12 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
219219
tensor_layout::convolution::NHWGK>::type;
220220

221221
// k_batch = 128 should pass
222-
auto host_args_kbatch_6 = create_2d_host_args(6);
222+
auto host_args_kbatch_6 = create_2d_host_args(7);
223223
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
224224
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
225225

226226
// k_batch = 129 should fail for half_t output
227-
auto host_args_kbatch_7 = create_2d_host_args(7);
227+
auto host_args_kbatch_7 = create_2d_host_args(8);
228228
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
229229
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
230230
}

0 commit comments

Comments
 (0)