11// SPDX-License-Identifier: MIT
2- // Copyright (c) 2024- 2025, Advanced Micro Devices, Inc. All rights reserved.
2+ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
33
44#include < hip/hip_runtime.h>
55
1616#include " ck_tile/host.hpp"
1717#include " grouped_gemm.hpp"
1818
19- template <typename ADataType,
20- typename BDataType,
21- typename DsDataType,
22- typename AccDataType,
23- typename CDataType,
24- typename ALayout,
25- typename BLayout,
26- typename DsLayout,
27- typename CLayout,
28- typename CDEElementWise = ck_tile::element_wise::PassThrough>
29- float grouped_gemm (const std::vector<grouped_gemm_kargs>& gemm_descs,
30- const ck_tile::stream_config& s,
31- void * kargs_ptr)
19+ template <typename ALayout, typename BLayout, typename CLayout>
20+ float grouped_gemm_tileloop (const ck_tile::stream_config& s,
21+ const ck_tile::index_t num_groups,
22+ void * kargs_ptr,
23+ bool splitk)
3224{
3325#if (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
3426 // Memory friendly for Interwave scheduler
@@ -83,8 +75,6 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
8375 constexpr bool kPadN = false ;
8476 constexpr bool kPadK = false ;
8577
86- constexpr bool TransposeC = false ;
87-
8878 constexpr int kBlockPerCu = 1 ;
8979 constexpr ck_tile::index_t TileParitionerGroupNum = 8 ;
9080 constexpr ck_tile::index_t TileParitionerM01 = 4 ;
@@ -97,54 +87,41 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
9787 GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
9888
9989 using Traits = ck_tile::TileGemmTraits<kPadM , kPadN , kPadK , ALayout, BLayout, CLayout>;
100- using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM ,
101- kPadN ,
102- kPadK ,
103- DoubleSmemBuffer,
104- ALayout,
105- BLayout,
106- CLayout,
107- TransposeC>;
90+ using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<kPadM ,
91+ kPadN ,
92+ kPadK ,
93+ DoubleSmemBuffer,
94+ ALayout,
95+ BLayout,
96+ CLayout>;
10897 using GemmPipelineProblem =
10998 ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
11099
111- using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE <GemmPipelineProblem>;
112-
113- const ck_tile::index_t k_grain = gemm_descs[0 ].k_batch * K_Tile;
114- const ck_tile::index_t K_split = (gemm_descs[0 ].K + k_grain - 1 ) / k_grain * K_Tile;
115- const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum (K_split);
116- const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop (num_loop);
117- const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum (num_loop);
118-
119100 float ave_time{0 };
120101
121- const auto Run = [&](const auto has_hot_loop_,
122- const auto tail_number_,
123- const auto memory_operation_) {
124- constexpr bool has_hot_loop_v = has_hot_loop_.value ;
125- constexpr auto tail_number_v = tail_number_.value ;
102+ const auto Run = [&](const auto memory_operation_) {
126103 constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER ;
127104 constexpr auto memory_operation = memory_operation_.value ;
128105
106+ // We create the GEMM pipeline without specifying hotloop or tailnumber.
107+ // These are automatically run inside the kernel based on the given input data.
129108 using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
130109 BDataType,
131110 AccDataType,
132111 GemmShape,
133112 GemmUniversalTraits,
134- scheduler,
135- has_hot_loop_v,
136- tail_number_v>;
113+ scheduler>;
137114
138115 using GemmPipeline = GEMM_PIPELINE <UniversalGemmProblem>;
139116 using GemmEpilogue = ck_tile::CShuffleEpilogue<
140117 ck_tile::CShuffleEpilogueProblem<ADataType,
141118 BDataType,
142- DsDataType ,
119+ ck_tile::tuple<> ,
143120 AccDataType,
144121 CDataType,
145- DsLayout ,
122+ ck_tile::tuple<> ,
146123 CLayout,
147- CDEElementWise ,
124+ ck_tile::element_wise::PassThrough ,
148125 GemmPipelineProblem::kBlockSize ,
149126 TilePartitioner::MPerBlock,
150127 TilePartitioner::NPerBlock,
@@ -156,20 +133,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
156133 UniversalGemmProblem::TransposeC,
157134 memory_operation>>;
158135 using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
159- auto kargs = Kernel::MakeKargs (gemm_descs);
160- if (!Kernel::IsSupportedArgument (kargs))
161- {
162- throw std::runtime_error (" Kernel arguments not supported!" );
163- }
164-
165136 constexpr dim3 blocks = Kernel::BlockSize ();
166- const dim3 grids = Kernel::GridSize (gemm_descs);
167-
168- HIP_CHECK_ERROR (hipMemcpyWithStream (kargs_ptr,
169- kargs.data (),
170- get_workspace_size (gemm_descs),
171- hipMemcpyHostToDevice,
172- s.stream_id_ ));
137+ const dim3 grids = Kernel::MaxOccupancyGridSize (s);
173138
174139 if (s.log_level_ > 0 )
175140 {
@@ -186,45 +151,26 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
186151 blocks,
187152 0 ,
188153 ck_tile::cast_pointer_to_constant_address_space (kargs_ptr),
189- gemm_descs. size () ));
154+ num_groups ));
190155
191156 return ave_time;
192157 };
193158
194- const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
195- if (gemm_descs[0 ].k_batch == 1 )
196- {
197- Run (has_hot_loop_,
198- tail_number_,
199- ck_tile::integral_constant<ck_tile::memory_operation_enum,
200- ck_tile::memory_operation_enum::set>{});
201- }
202- else
203- {
204- Run (has_hot_loop_,
205- tail_number_,
206- ck_tile::integral_constant<ck_tile::memory_operation_enum,
207- ck_tile::memory_operation_enum::atomic_add>{});
208- }
209- };
210-
211- BaseGemmPipeline::TailHandler (RunSplitk, has_hot_loop, tail_num);
159+ if (!splitk)
160+ {
161+ Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
162+ ck_tile::memory_operation_enum::set>{});
163+ }
164+ else
165+ {
166+ Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
167+ ck_tile::memory_operation_enum::atomic_add>{});
168+ }
212169
213170 return ave_time;
214171}
215172
216173#include " run_grouped_gemm_example.inc"
217174
218- constexpr bool Persistent = false ;
219- int main (int argc, char * argv[])
220- {
221- try
222- {
223- return !run_grouped_gemm_example<Persistent>(argc, argv);
224- }
225- catch (const std::runtime_error& e)
226- {
227- std::cerr << " Runtime error: " << e.what () << ' \n ' ;
228- return EXIT_FAILURE ;
229- }
230- }
175+ constexpr bool Persistent = true ;
176+ int main (int argc, char * argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
0 commit comments