Skip to content

Commit 19415d0

Browse files
fix: nil performance results for gemm examples (#2950)
1 parent d4761d7 commit 19415d0

7 files changed

Lines changed: 220 additions & 238 deletions

File tree

example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -252,15 +252,14 @@ struct SplitKTwoStageInvoker
252252
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
253253
if(args.k_batch == 1)
254254
{
255-
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
255+
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
256256
}
257257
else
258258
{
259-
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
259+
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
260260
}
261261
};
262262

263-
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
264-
return ave_time;
263+
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
265264
}
266265
};

example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
275275
hipGetErrorString(hipMemsetAsync(
276276
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
277277
};
278-
ave_time = ck_tile::launch_kernel_time_mask(
279-
s,
280-
run_flush_cache,
281-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
278+
return ave_time = ck_tile::launch_kernel_time_mask(
279+
s,
280+
run_flush_cache,
281+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
282+
Kernel{}, grids, blocks, 0, kargs));
282283
}
283284
else
284285
{
285-
ave_time = ck_tile::launch_kernel(
286-
s,
287-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
286+
return ave_time = ck_tile::launch_kernel(s,
287+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
288+
Kernel{}, grids, blocks, 0, kargs));
288289
}
289-
return ave_time;
290290
};
291291

292292
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
293293
// For workspace mode, always use SET operation since each K-split writes to separate memory
294-
Run(has_hot_loop_,
295-
tail_number_,
296-
ck_tile::integral_constant<ck_tile::memory_operation_enum,
297-
ck_tile::memory_operation_enum::set>{});
294+
return Run(has_hot_loop_,
295+
tail_number_,
296+
ck_tile::integral_constant<ck_tile::memory_operation_enum,
297+
ck_tile::memory_operation_enum::set>{});
298298
};
299299

300-
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
301-
return ave_time;
300+
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
302301
}
303302

304303
/**

example/ck_tile/03_gemm/universal_gemm_invoker.hpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -185,15 +185,14 @@ struct UniversalInvoker
185185
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
186186
if(args.k_batch == 1)
187187
{
188-
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
188+
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
189189
}
190190
else
191191
{
192-
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
192+
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
193193
}
194194
};
195195

196-
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
197-
return ave_time;
196+
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
198197
}
199198
};

example/ck_tile/17_grouped_gemm/grouped_gemm.cpp

Lines changed: 93 additions & 100 deletions
Original file line number | Diff line number | Diff line change
@@ -70,99 +70,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
7070

7171
float ave_time{0};
7272

73-
const auto Run = [&](const auto has_hot_loop_,
74-
const auto tail_number_,
75-
const auto memory_operation_) {
76-
constexpr bool has_hot_loop_v = has_hot_loop_.value;
77-
constexpr auto tail_number_v = tail_number_.value;
78-
constexpr auto scheduler = GemmConfig::Scheduler;
79-
constexpr auto memory_operation = memory_operation_.value;
80-
81-
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
82-
BDataType,
83-
AccDataType,
84-
GemmShape,
85-
GemmUniversalTraits,
86-
scheduler,
87-
has_hot_loop_v,
88-
tail_number_v>;
89-
90-
using GemmPipeline = typename PipelineTypeTraits<
91-
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
92-
using GemmEpilogue = ck_tile::CShuffleEpilogue<
93-
ck_tile::CShuffleEpilogueProblem<ADataType,
94-
BDataType,
95-
DsDataType,
96-
AccDataType,
97-
CDataType,
98-
DsLayout,
99-
CLayout,
100-
CDEElementWise,
101-
TilePartitioner::MPerBlock,
102-
TilePartitioner::NPerBlock,
103-
GemmConfig::M_Warp,
104-
GemmConfig::N_Warp,
105-
GemmConfig::M_Warp_Tile,
106-
GemmConfig::N_Warp_Tile,
107-
GemmConfig::K_Warp_Tile,
108-
UniversalGemmProblem::TransposeC,
109-
memory_operation>>;
110-
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
111-
auto kargs = Kernel::MakeKargs(gemm_descs);
112-
if(!Kernel::IsSupportedArgument(kargs))
113-
{
114-
throw std::runtime_error("Kernel arguments not supported!");
115-
}
116-
117-
const dim3 blocks = Kernel::BlockSize();
118-
const dim3 grids = Kernel::GridSize(gemm_descs);
119-
120-
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
121-
kargs.data(),
122-
get_workspace_size(gemm_descs),
123-
hipMemcpyHostToDevice,
124-
s.stream_id_));
125-
126-
if(s.log_level_ > 0)
127-
{
128-
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
129-
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
130-
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
131-
}
132-
133-
ave_time =
134-
ck_tile::launch_kernel(s,
135-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
136-
Kernel{},
137-
grids,
138-
blocks,
139-
0,
140-
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
141-
gemm_descs.size()));
142-
143-
return ave_time;
144-
};
73+
const auto Run =
74+
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
75+
constexpr bool has_hot_loop_v = has_hot_loop_.value;
76+
constexpr auto tail_number_v = tail_number_.value;
77+
constexpr auto scheduler = GemmConfig::Scheduler;
78+
constexpr auto memory_operation = memory_operation_.value;
79+
80+
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
81+
BDataType,
82+
AccDataType,
83+
GemmShape,
84+
GemmUniversalTraits,
85+
scheduler,
86+
has_hot_loop_v,
87+
tail_number_v>;
88+
89+
using GemmPipeline = typename PipelineTypeTraits<
90+
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
91+
using GemmEpilogue = ck_tile::CShuffleEpilogue<
92+
ck_tile::CShuffleEpilogueProblem<ADataType,
93+
BDataType,
94+
DsDataType,
95+
AccDataType,
96+
CDataType,
97+
DsLayout,
98+
CLayout,
99+
CDEElementWise,
100+
TilePartitioner::MPerBlock,
101+
TilePartitioner::NPerBlock,
102+
GemmConfig::M_Warp,
103+
GemmConfig::N_Warp,
104+
GemmConfig::M_Warp_Tile,
105+
GemmConfig::N_Warp_Tile,
106+
GemmConfig::K_Warp_Tile,
107+
UniversalGemmProblem::TransposeC,
108+
memory_operation>>;
109+
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
110+
auto kargs = Kernel::MakeKargs(gemm_descs);
111+
if(!Kernel::IsSupportedArgument(kargs))
112+
{
113+
throw std::runtime_error("Kernel arguments not supported!");
114+
}
115+
116+
const dim3 blocks = Kernel::BlockSize();
117+
const dim3 grids = Kernel::GridSize(gemm_descs);
118+
119+
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
120+
kargs.data(),
121+
get_workspace_size(gemm_descs),
122+
hipMemcpyHostToDevice,
123+
s.stream_id_));
124+
125+
if(s.log_level_ > 0)
126+
{
127+
std::cout << "Launching kernel: " << Kernel::GetName()
128+
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
129+
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
130+
<< blocks.z << "}" << std::endl;
131+
}
132+
133+
return ave_time = ck_tile::launch_kernel(
134+
s,
135+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
136+
Kernel{},
137+
grids,
138+
blocks,
139+
0,
140+
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
141+
gemm_descs.size()));
142+
};
145143

146144
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
147145
if(gemm_descs[0].k_batch == 1)
148146
{
149-
Run(has_hot_loop_,
150-
tail_number_,
151-
ck_tile::integral_constant<ck_tile::memory_operation_enum,
152-
ck_tile::memory_operation_enum::set>{});
147+
return Run(has_hot_loop_,
148+
tail_number_,
149+
ck_tile::integral_constant<ck_tile::memory_operation_enum,
150+
ck_tile::memory_operation_enum::set>{});
153151
}
154152
else
155153
{
156-
Run(has_hot_loop_,
157-
tail_number_,
158-
ck_tile::integral_constant<ck_tile::memory_operation_enum,
159-
ck_tile::memory_operation_enum::atomic_add>{});
154+
return Run(has_hot_loop_,
155+
tail_number_,
156+
ck_tile::integral_constant<ck_tile::memory_operation_enum,
157+
ck_tile::memory_operation_enum::atomic_add>{});
160158
}
161159
};
162160

163-
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
164-
165-
return ave_time;
161+
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
166162
}
167163

168164
template <typename GemmConfig,
@@ -243,31 +239,28 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
243239
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
244240
}
245241

246-
ave_time =
247-
ck_tile::launch_kernel(s,
248-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
249-
Kernel{},
250-
grids,
251-
blocks,
252-
0,
253-
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
254-
num_groups));
255-
256-
return ave_time;
242+
return ave_time = ck_tile::launch_kernel(
243+
s,
244+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
245+
Kernel{},
246+
grids,
247+
blocks,
248+
0,
249+
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
250+
num_groups));
257251
};
258252

259253
if(!splitk)
260254
{
261-
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
262-
ck_tile::memory_operation_enum::set>{});
255+
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
256+
ck_tile::memory_operation_enum::set>{});
263257
}
264258
else
265259
{
266-
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
267-
ck_tile::memory_operation_enum::atomic_add>{});
260+
return ave_time =
261+
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
262+
ck_tile::memory_operation_enum::atomic_add>{});
268263
}
269-
270-
return ave_time;
271264
}
272265

273266
#include "run_grouped_gemm_example.inc"

0 commit comments

Comments
 (0)