|
3 | 3 |
|
4 | 4 | #pragma once |
5 | 5 |
|
| 6 | +#include <ios> |
6 | 7 | #include <iostream> |
7 | 8 | #include <sstream> |
8 | 9 | #include <tuple> |
@@ -677,8 +678,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage |
677 | 678 |
|
678 | 679 | all_have_kbatch_gt_one = arg.K_BATCH > 1; |
679 | 680 | all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( |
680 | | - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * |
681 | | - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); |
| 681 | + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)); |
682 | 682 | } |
683 | 683 |
|
684 | 684 | for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) |
@@ -709,8 +709,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage |
709 | 709 |
|
710 | 710 | bool not_all_have_main_k_block_loop_same = |
711 | 711 | all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( |
712 | | - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * |
713 | | - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); |
| 712 | + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)); |
714 | 713 | bool not_all_have_kbatch_value_same = |
715 | 714 | all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); |
716 | 715 |
|
@@ -848,21 +847,47 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage |
848 | 847 | return false; |
849 | 848 | } |
850 | 849 |
|
851 | | - // TODO: Fix this. |
852 | | - // Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0` |
853 | | - if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value && |
854 | | - std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value && |
855 | | - std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value && |
856 | | - getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2) |
| 850 | + // Check if all groups have compatible HasMainLoop values |
| 851 | + if(!arg.gemm_kernel_args_.empty()) |
857 | 852 | { |
858 | | - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 853 | + const auto& first_arg = arg.gemm_kernel_args_[0].karg_; |
| 854 | + const auto first_desc = |
| 855 | + GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(first_arg.M, |
| 856 | + first_arg.MPadded, |
| 857 | + first_arg.K, |
| 858 | + first_arg.StrideA, |
| 859 | + first_arg.k_batch, |
| 860 | + first_arg.K0Padded, |
| 861 | + first_arg.KPadded); |
| 862 | + const bool first_has_main_loop = |
| 863 | + GridwiseGemm64::CalculateHasMainK0BlockLoop(first_desc.GetLength(I1)); |
| 864 | + |
| 865 | + for(std::size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i) |
859 | 866 | { |
860 | | - std::cout |
861 | | - << "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not " |
862 | | - "supported for all possible shapes!" |
863 | | - << std::endl; |
| 867 | + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; |
| 868 | + const auto desc = |
| 869 | + GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, |
| 870 | + gemm_arg.MPadded, |
| 871 | + gemm_arg.K, |
| 872 | + gemm_arg.StrideA, |
| 873 | + gemm_arg.k_batch, |
| 874 | + gemm_arg.K0Padded, |
| 875 | + gemm_arg.KPadded); |
| 876 | + const bool has_main_loop = |
| 877 | + GridwiseGemm64::CalculateHasMainK0BlockLoop(desc.GetLength(I1)); |
| 878 | + |
| 879 | + if(first_has_main_loop != has_main_loop) |
| 880 | + { |
| 881 | + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) |
| 882 | + { |
| 883 | + std::cout << std::boolalpha |
| 884 | + << "Not all groups have compatible HasMainLoop values! " |
| 885 | + << "Group 0: " << first_has_main_loop << ", Group " << i << ": " |
| 886 | + << has_main_loop << std::endl; |
| 887 | + } |
| 888 | + return false; |
| 889 | + } |
864 | 890 | } |
865 | | - return false; |
866 | 891 | } |
867 | 892 |
|
868 | 893 | bool supported = true; |
|
0 commit comments