@@ -323,22 +323,38 @@ struct UniversalGemmKernel
323323
324324 struct SplitKBatchOffset
325325 {
326- __device__ SplitKBatchOffset (const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
326+ // This structure distributes work evenly among the split-K workgroups.
327+ // It is based on the principle that, if there is enough work to fill all workgroups,
328+ // the (K / K1) parts can be distributed among the k_batch workgroups in such a way
329+ // that each workgroup processes ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1 parts,
330+ // leaving the potential tail to the last workgroup (index splitk - 1).
331+ __device__ SplitKBatchOffset (const KernelArgs& kargs, const index_t k_id = blockIdx.z)
327332 {
328- constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at (number<2 >{});
329- const index_t K_t = amd_wave_read_first_lane (kargs.k_batch * K1);
330- const index_t KRead = amd_wave_read_first_lane ((kargs.K + K_t - 1 ) / K_t * K1);
333+ constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at (number<2 >{});
334+ const index_t num_all = amd_wave_read_first_lane (
335+ kargs.K / K1); // num of all loops not including potential tail
336+ index_t num_full = amd_wave_read_first_lane (num_all % kargs.k_batch );
337+ num_full = num_full == 0 ? kargs.k_batch : num_full;
338+
339+ const index_t num_full_iters =
340+ amd_wave_read_first_lane (std::max (integer_divide_ceil (num_all, kargs.k_batch ), 1 ));
341+ const index_t full_k_read = num_full_iters * K1;
342+ const index_t partial_k_read = (num_full_iters - 1 ) * K1;
331343
332344 static_for<0 , NumATensor, 1 >{}([&](auto index) {
333345 using AiLayout = remove_cvref_t <std::tuple_element_t <index.value , AsLayout>>;
334346 if constexpr (std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
335347 {
336- as_k_split_offset[index] = amd_wave_read_first_lane (k_id * KRead);
348+ as_k_split_offset[index] =
349+ amd_wave_read_first_lane (std::min (k_id, num_full) * full_k_read +
350+ std::max (k_id - num_full, 0 ) * partial_k_read);
337351 }
338352 else if constexpr (std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
339353 {
340354 as_k_split_offset[index] =
341- amd_wave_read_first_lane (k_id * KRead * kargs.stride_As [index]);
355+ amd_wave_read_first_lane ((std::min (k_id, num_full) * full_k_read +
356+ std::max (k_id - num_full, 0 ) * partial_k_read) *
357+ kargs.stride_As [index]);
342358 }
343359 });
344360
@@ -347,21 +363,30 @@ struct UniversalGemmKernel
347363 if constexpr (std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
348364 {
349365 bs_k_split_offset[index] =
350- amd_wave_read_first_lane (k_id * KRead * kargs.stride_Bs [index]);
366+ amd_wave_read_first_lane ((std::min (k_id, num_full) * full_k_read +
367+ std::max (k_id - num_full, 0 ) * partial_k_read) *
368+ kargs.stride_Bs [index]);
351369 }
352370 else if constexpr (std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
353371 {
354- bs_k_split_offset[index] = amd_wave_read_first_lane (k_id * KRead);
372+ bs_k_split_offset[index] =
373+ amd_wave_read_first_lane (std::min (k_id, num_full) * full_k_read +
374+ std::max (k_id - num_full, 0 ) * partial_k_read);
355375 }
356376 });
357377
358- if (k_id < static_cast <uint32_t >(kargs.k_batch - 1 ))
378+ if (k_id == kargs.k_batch - 1 )
379+ {
380+ splitted_k = kargs.K - std::min (k_id, num_full) * full_k_read -
381+ std::max (k_id - num_full, 0 ) * partial_k_read;
382+ }
383+ else if (k_id < num_full)
359384 {
360- splitted_k = amd_wave_read_first_lane (KRead) ;
385+ splitted_k = full_k_read ;
361386 }
362387 else
363388 {
364- splitted_k = amd_wave_read_first_lane (kargs. K - KRead * (kargs. k_batch - 1 )) ;
389+ splitted_k = partial_k_read ;
365390 }
366391 }
367392
@@ -385,6 +410,15 @@ struct UniversalGemmKernel
385410 }
386411 }
387412
413+ if (kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at (number<2 >{}) * kargs.k_batch )
414+ {
415+ if (ck_tile::EnvIsEnabled (CK_TILE_ENV (CK_TILE_LOGGING)))
416+ {
417+ CK_TILE_ERROR (" KBatch is too large, part of GPU wouldn't be utilized!" );
418+ }
419+ return false ;
420+ }
421+
388422 const auto vectorSizeA = is_wave32 () ? GemmPipeline::template GetVectorSizeA<true >()
389423 : GemmPipeline::template GetVectorSizeA<false >();
390424 bool AsTesnorIsValid = {true };
0 commit comments