@@ -70,99 +70,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
7070
7171 float ave_time{0 };
7272
73- const auto Run = [&](const auto has_hot_loop_,
74- const auto tail_number_,
75- const auto memory_operation_) {
76- constexpr bool has_hot_loop_v = has_hot_loop_.value ;
77- constexpr auto tail_number_v = tail_number_.value ;
78- constexpr auto scheduler = GemmConfig::Scheduler;
79- constexpr auto memory_operation = memory_operation_.value ;
80-
81- using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
82- BDataType,
83- AccDataType,
84- GemmShape,
85- GemmUniversalTraits,
86- scheduler,
87- has_hot_loop_v,
88- tail_number_v>;
89-
90- using GemmPipeline = typename PipelineTypeTraits<
91- GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
92- using GemmEpilogue = ck_tile::CShuffleEpilogue<
93- ck_tile::CShuffleEpilogueProblem<ADataType,
94- BDataType,
95- DsDataType,
96- AccDataType,
97- CDataType,
98- DsLayout,
99- CLayout,
100- CDEElementWise,
101- TilePartitioner::MPerBlock,
102- TilePartitioner::NPerBlock,
103- GemmConfig::M_Warp,
104- GemmConfig::N_Warp,
105- GemmConfig::M_Warp_Tile,
106- GemmConfig::N_Warp_Tile,
107- GemmConfig::K_Warp_Tile,
108- UniversalGemmProblem::TransposeC,
109- memory_operation>>;
110- using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
111- auto kargs = Kernel::MakeKargs (gemm_descs);
112- if (!Kernel::IsSupportedArgument (kargs))
113- {
114- throw std::runtime_error (" Kernel arguments not supported!" );
115- }
116-
117- const dim3 blocks = Kernel::BlockSize ();
118- const dim3 grids = Kernel::GridSize (gemm_descs);
119-
120- HIP_CHECK_ERROR (hipMemcpyWithStream (kargs_ptr,
121- kargs.data (),
122- get_workspace_size (gemm_descs),
123- hipMemcpyHostToDevice,
124- s.stream_id_ ));
125-
126- if (s.log_level_ > 0 )
127- {
128- std::cout << " Launching kernel: " << Kernel::GetName () << " with args:" << " grid: {"
129- << grids.x << " , " << grids.y << " , " << grids.z << " }" << " , blocks: {"
130- << blocks.x << " , " << blocks.y << " , " << blocks.z << " }" << std::endl;
131- }
132-
133- ave_time =
134- ck_tile::launch_kernel (s,
135- ck_tile::make_kernel<GemmConfig::kBlockPerCu >(
136- Kernel{},
137- grids,
138- blocks,
139- 0 ,
140- ck_tile::cast_pointer_to_constant_address_space (kargs_ptr),
141- gemm_descs.size ()));
142-
143- return ave_time;
144- };
73+ const auto Run =
74+ [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
75+ constexpr bool has_hot_loop_v = has_hot_loop_.value ;
76+ constexpr auto tail_number_v = tail_number_.value ;
77+ constexpr auto scheduler = GemmConfig::Scheduler;
78+ constexpr auto memory_operation = memory_operation_.value ;
79+
80+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
81+ BDataType,
82+ AccDataType,
83+ GemmShape,
84+ GemmUniversalTraits,
85+ scheduler,
86+ has_hot_loop_v,
87+ tail_number_v>;
88+
89+ using GemmPipeline = typename PipelineTypeTraits<
90+ GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
91+ using GemmEpilogue = ck_tile::CShuffleEpilogue<
92+ ck_tile::CShuffleEpilogueProblem<ADataType,
93+ BDataType,
94+ DsDataType,
95+ AccDataType,
96+ CDataType,
97+ DsLayout,
98+ CLayout,
99+ CDEElementWise,
100+ TilePartitioner::MPerBlock,
101+ TilePartitioner::NPerBlock,
102+ GemmConfig::M_Warp,
103+ GemmConfig::N_Warp,
104+ GemmConfig::M_Warp_Tile,
105+ GemmConfig::N_Warp_Tile,
106+ GemmConfig::K_Warp_Tile,
107+ UniversalGemmProblem::TransposeC,
108+ memory_operation>>;
109+ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
110+ auto kargs = Kernel::MakeKargs (gemm_descs);
111+ if (!Kernel::IsSupportedArgument (kargs))
112+ {
113+ throw std::runtime_error (" Kernel arguments not supported!" );
114+ }
115+
116+ const dim3 blocks = Kernel::BlockSize ();
117+ const dim3 grids = Kernel::GridSize (gemm_descs);
118+
119+ HIP_CHECK_ERROR (hipMemcpyWithStream (kargs_ptr,
120+ kargs.data (),
121+ get_workspace_size (gemm_descs),
122+ hipMemcpyHostToDevice,
123+ s.stream_id_ ));
124+
125+ if (s.log_level_ > 0 )
126+ {
127+ std::cout << " Launching kernel: " << Kernel::GetName ()
128+ << " with args:" << " grid: {" << grids.x << " , " << grids.y << " , "
129+ << grids.z << " }" << " , blocks: {" << blocks.x << " , " << blocks.y << " , "
130+ << blocks.z << " }" << std::endl;
131+ }
132+
133+ return ave_time = ck_tile::launch_kernel (
134+ s,
135+ ck_tile::make_kernel<GemmConfig::kBlockPerCu >(
136+ Kernel{},
137+ grids,
138+ blocks,
139+ 0 ,
140+ ck_tile::cast_pointer_to_constant_address_space (kargs_ptr),
141+ gemm_descs.size ()));
142+ };
145143
146144 const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
147145 if (gemm_descs[0 ].k_batch == 1 )
148146 {
149- Run (has_hot_loop_,
150- tail_number_,
151- ck_tile::integral_constant<ck_tile::memory_operation_enum,
152- ck_tile::memory_operation_enum::set>{});
147+ return Run (has_hot_loop_,
148+ tail_number_,
149+ ck_tile::integral_constant<ck_tile::memory_operation_enum,
150+ ck_tile::memory_operation_enum::set>{});
153151 }
154152 else
155153 {
156- Run (has_hot_loop_,
157- tail_number_,
158- ck_tile::integral_constant<ck_tile::memory_operation_enum,
159- ck_tile::memory_operation_enum::atomic_add>{});
154+ return Run (has_hot_loop_,
155+ tail_number_,
156+ ck_tile::integral_constant<ck_tile::memory_operation_enum,
157+ ck_tile::memory_operation_enum::atomic_add>{});
160158 }
161159 };
162160
163- BaseGemmPipeline::TailHandler (RunSplitk, has_hot_loop, tail_num);
164-
165- return ave_time;
161+ return ave_time = BaseGemmPipeline::TailHandler (RunSplitk, has_hot_loop, tail_num);
166162}
167163
168164template <typename GemmConfig,
@@ -243,31 +239,28 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
243239 << blocks.x << " , " << blocks.y << " , " << blocks.z << " }" << std::endl;
244240 }
245241
246- ave_time =
247- ck_tile::launch_kernel (s,
248- ck_tile::make_kernel<GemmConfig::kBlockPerCu >(
249- Kernel{},
250- grids,
251- blocks,
252- 0 ,
253- ck_tile::cast_pointer_to_constant_address_space (kargs_ptr),
254- num_groups));
255-
256- return ave_time;
242+ return ave_time = ck_tile::launch_kernel (
243+ s,
244+ ck_tile::make_kernel<GemmConfig::kBlockPerCu >(
245+ Kernel{},
246+ grids,
247+ blocks,
248+ 0 ,
249+ ck_tile::cast_pointer_to_constant_address_space (kargs_ptr),
250+ num_groups));
257251 };
258252
259253 if (!splitk)
260254 {
261- Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
262- ck_tile::memory_operation_enum::set>{});
255+ return ave_time = Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
256+ ck_tile::memory_operation_enum::set>{});
263257 }
264258 else
265259 {
266- Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
267- ck_tile::memory_operation_enum::atomic_add>{});
260+ return ave_time =
261+ Run (ck_tile::integral_constant<ck_tile::memory_operation_enum,
262+ ck_tile::memory_operation_enum::atomic_add>{});
268263 }
269-
270- return ave_time;
271264}
272265
273266#include " run_grouped_gemm_example.inc"
0 commit comments