Skip to content

Commit e1381d6

Browse files
[CK grouped gemm] Fix grouped gemm two stage HasMainK0BlockLoop (#3466)
* Re-enable two stage kernel * Only disable on HasMainKBlockLoop mismatch * Address PR comments
1 parent 4ce7d4c commit e1381d6

1 file changed

Lines changed: 41 additions & 16 deletions

File tree

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <ios>
67
#include <iostream>
78
#include <sstream>
89
#include <tuple>
@@ -677,8 +678,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
677678

678679
all_have_kbatch_gt_one = arg.K_BATCH > 1;
679680
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
680-
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
681-
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
681+
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
682682
}
683683

684684
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
@@ -709,8 +709,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
709709

710710
bool not_all_have_main_k_block_loop_same =
711711
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
712-
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
713-
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
712+
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
714713
bool not_all_have_kbatch_value_same =
715714
all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
716715

@@ -848,21 +847,47 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
848847
return false;
849848
}
850849

851-
// TODO: Fix this.
852-
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
853-
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
854-
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
855-
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
856-
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
850+
// Check if all groups have compatible HasMainLoop values
851+
if(!arg.gemm_kernel_args_.empty())
857852
{
858-
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
853+
const auto& first_arg = arg.gemm_kernel_args_[0].karg_;
854+
const auto first_desc =
855+
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(first_arg.M,
856+
first_arg.MPadded,
857+
first_arg.K,
858+
first_arg.StrideA,
859+
first_arg.k_batch,
860+
first_arg.K0Padded,
861+
first_arg.KPadded);
862+
const bool first_has_main_loop =
863+
GridwiseGemm64::CalculateHasMainK0BlockLoop(first_desc.GetLength(I1));
864+
865+
for(std::size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
859866
{
860-
std::cout
861-
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
862-
"supported for all possible shapes!"
863-
<< std::endl;
867+
const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
868+
const auto desc =
869+
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
870+
gemm_arg.MPadded,
871+
gemm_arg.K,
872+
gemm_arg.StrideA,
873+
gemm_arg.k_batch,
874+
gemm_arg.K0Padded,
875+
gemm_arg.KPadded);
876+
const bool has_main_loop =
877+
GridwiseGemm64::CalculateHasMainK0BlockLoop(desc.GetLength(I1));
878+
879+
if(first_has_main_loop != has_main_loop)
880+
{
881+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
882+
{
883+
std::cout << std::boolalpha
884+
<< "Not all groups have compatible HasMainLoop values! "
885+
<< "Group 0: " << first_has_main_loop << ", Group " << i << ": "
886+
<< has_main_loop << std::endl;
887+
}
888+
return false;
889+
}
864890
}
865-
return false;
866891
}
867892

868893
bool supported = true;

0 commit comments

Comments
 (0)