Skip to content

Commit d184eed

Browse files
authored
[CK-Tile] Refactor base pipeline usage (#3251)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder
1 parent d9d4c9c commit d184eed

37 files changed

Lines changed: 1044 additions & 1868 deletions

File tree

example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp

Lines changed: 21 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@ struct SplitKTwoStageInvoker
4646
GemmConfig::TileParitionerGroupNum,
4747
GemmConfig::TileParitionerM01>;
4848

49-
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
50-
GemmConfig::kPadN,
51-
GemmConfig::kPadK,
52-
ALayout,
53-
BLayout,
54-
ELayout,
55-
GemmConfig::NumWaveGroups>;
56-
5749
using GemmUniversalTraits =
5850
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
5951
GemmConfig::kPadN,
@@ -67,40 +59,21 @@ struct SplitKTwoStageInvoker
6759
Persistent,
6860
GemmConfig::NumWaveGroups,
6961
GemmConfig::Preshuffle>;
70-
using GemmPipelineProblem =
71-
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
72-
73-
using BaseGemmPipeline = typename PipelineTypeTraits<
74-
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
75-
76-
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
77-
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
78-
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
79-
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
80-
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
81-
float ave_time{0};
82-
83-
const auto Run = [&](const auto has_hot_loop_,
84-
const auto tail_number_,
85-
const auto memory_operation_) {
86-
constexpr bool has_hot_loop_v = has_hot_loop_.value;
87-
constexpr auto tail_number_v = tail_number_.value;
88-
constexpr auto scheduler = GemmConfig::Scheduler;
89-
constexpr auto memory_operation = memory_operation_.value;
62+
constexpr auto scheduler = GemmConfig::Scheduler;
9063

91-
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
92-
BDataType,
93-
AccDataType,
94-
GemmShape,
95-
GemmUniversalTraits,
96-
scheduler,
97-
has_hot_loop_v,
98-
tail_number_v>;
64+
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
65+
BDataType,
66+
AccDataType,
67+
GemmShape,
68+
GemmUniversalTraits,
69+
scheduler>;
70+
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
9971

100-
using GemmPipeline = typename PipelineTypeTraits<
101-
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
72+
using GemmPipeline = typename PipelineTypeTraits<
73+
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
10274

103-
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
75+
const auto Run = [&](const auto memory_operation_) {
76+
constexpr auto memory_operation = memory_operation_.value;
10477

10578
using GemmEpilogue = ck_tile::CShuffleEpilogue<
10679
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -230,7 +203,7 @@ struct SplitKTwoStageInvoker
230203
preprocess = clear_gemm_output;
231204
}
232205

233-
ave_time = ck_tile::launch_kernel_time_mask(
206+
return ck_tile::launch_kernel_time_mask(
234207
s,
235208
preprocess,
236209
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
@@ -244,21 +217,15 @@ struct SplitKTwoStageInvoker
244217
ck_tile::make_tuple(args.N, 1), // Output Stride
245218
input_tensors,
246219
static_cast<CDataType*>(c_ptr)));
247-
248-
return ave_time;
249-
};
250-
251-
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
252-
if(args.k_batch == 1)
253-
{
254-
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
255-
}
256-
else
257-
{
258-
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
259-
}
260220
};
261221

262-
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
222+
if(args.k_batch == 1)
223+
{
224+
return Run(MemoryOpSet{});
225+
}
226+
else
227+
{
228+
return Run(MemoryOpAtomicAdd{});
229+
}
263230
}
264231
};

example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
133133
GemmConfig::TileParitionerGroupNum,
134134
GemmConfig::TileParitionerM01>;
135135

136-
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
137-
GemmConfig::kPadN,
138-
GemmConfig::kPadK,
139-
ALayout,
140-
BLayout,
141-
ELayout,
142-
GemmConfig::NumWaveGroups>;
143-
144136
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
145137
GemmConfig::kPadN,
146138
GemmConfig::kPadK,
@@ -154,19 +146,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
154146
GemmConfig::NumWaveGroups,
155147
GemmConfig::Preshuffle>;
156148

157-
using GemmPipelineProblem =
158-
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
159-
160-
using BaseGemmPipeline = typename PipelineTypeTraits<
161-
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
162-
163-
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
164-
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
165-
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
166-
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
167-
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
168-
float ave_time{0};
169-
170149
// Create base GEMM arguments pointing to workspace instead of final output
171150
// The workspace will store partial results from each K-split
172151
ck_tile::GemmHostArgs base_args(args.a_ptr,
@@ -179,23 +158,18 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
179158
args.stride_A,
180159
args.stride_B,
181160
args.stride_E);
161+
constexpr auto scheduler = GemmConfig::Scheduler;
182162

183-
const auto Run = [&](const auto has_hot_loop_,
184-
const auto tail_number_,
185-
const auto memory_operation_) {
186-
constexpr bool has_hot_loop_v = has_hot_loop_.value;
187-
constexpr auto tail_number_v = tail_number_.value;
188-
constexpr auto scheduler = GemmConfig::Scheduler;
189-
constexpr auto memory_operation = memory_operation_.value;
163+
const auto Run = [&]() {
164+
// use SET operation since each K-split writes to separate memory
165+
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
190166

191167
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
192168
BDataType,
193169
AccDataType,
194170
GemmShape,
195171
GemmUniversalTraits,
196-
scheduler,
197-
has_hot_loop_v,
198-
tail_number_v>;
172+
scheduler>;
199173

200174
using GemmPipeline = typename PipelineTypeTraits<
201175
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -276,29 +250,20 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
276250
hipGetErrorString(hipMemsetAsync(
277251
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
278252
};
279-
return ave_time = ck_tile::launch_kernel_time_mask(
280-
s,
281-
run_flush_cache,
282-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
283-
Kernel{}, grids, blocks, 0, kargs));
253+
return ck_tile::launch_kernel_time_mask(
254+
s,
255+
run_flush_cache,
256+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
284257
}
285258
else
286259
{
287-
return ave_time = ck_tile::launch_kernel(s,
288-
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
289-
Kernel{}, grids, blocks, 0, kargs));
260+
return ck_tile::launch_kernel(
261+
s,
262+
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
290263
}
291264
};
292265

293-
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
294-
// For workspace mode, always use SET operation since each K-split writes to separate memory
295-
return Run(has_hot_loop_,
296-
tail_number_,
297-
ck_tile::integral_constant<ck_tile::memory_operation_enum,
298-
ck_tile::memory_operation_enum::set>{});
299-
};
300-
301-
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
266+
return Run();
302267
}
303268

304269
/**

example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp

Lines changed: 22 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,6 @@ struct WeightPreshuffleInvoker
3333
GemmConfig::TileParitionerGroupNum,
3434
GemmConfig::TileParitionerM01>;
3535

36-
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
37-
GemmConfig::kPadN,
38-
GemmConfig::kPadK,
39-
ALayout,
40-
BLayout,
41-
ELayout,
42-
GemmConfig::NumWaveGroups>;
43-
4436
using GemmUniversalTraits =
4537
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
4638
GemmConfig::kPadN,
@@ -54,39 +46,20 @@ struct WeightPreshuffleInvoker
5446
Persistent,
5547
GemmConfig::NumWaveGroups,
5648
GemmConfig::Preshuffle>;
57-
using GemmPipelineProblem =
58-
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
59-
60-
using BaseGemmPipeline = typename PipelineTypeTraits<
61-
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
62-
63-
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
64-
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
65-
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
66-
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
67-
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
68-
float ave_time{0};
69-
70-
const auto Run = [&](const auto has_hot_loop_,
71-
const auto tail_number_,
72-
const auto memory_operation_) {
73-
constexpr bool has_hot_loop_v = has_hot_loop_.value;
74-
constexpr auto tail_number_v = tail_number_.value;
75-
constexpr auto scheduler = GemmConfig::Scheduler;
49+
constexpr auto scheduler = GemmConfig::Scheduler;
50+
51+
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
52+
BDataType,
53+
AccDataType,
54+
GemmShape,
55+
GemmUniversalTraits,
56+
scheduler>;
57+
58+
using GemmPipeline = typename PipelineTypeTraits<
59+
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
60+
const auto Run = [&](const auto memory_operation_) {
7661
constexpr auto memory_operation = memory_operation_.value;
7762

78-
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
79-
BDataType,
80-
AccDataType,
81-
GemmShape,
82-
GemmUniversalTraits,
83-
scheduler,
84-
has_hot_loop_v,
85-
tail_number_v>;
86-
87-
using GemmPipeline = typename PipelineTypeTraits<
88-
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
89-
9063
using GemmEpilogue = ck_tile::CShuffleEpilogue<
9164
ck_tile::CShuffleEpilogueProblem<ADataType,
9265
BDataType,
@@ -139,6 +112,7 @@ struct WeightPreshuffleInvoker
139112
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
140113
<< std::endl;
141114
}
115+
float ave_time = 0.f;
142116
if(s.flush_cache_)
143117
{
144118
std::cout << "Flushing cache..." << std::endl;
@@ -183,21 +157,14 @@ struct WeightPreshuffleInvoker
183157
return ave_time;
184158
};
185159

186-
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
187-
if(args.k_batch == 1)
188-
{
189-
Run(has_hot_loop_,
190-
tail_number_,
191-
ck_tile::integral_constant<ck_tile::memory_operation_enum,
192-
ck_tile::memory_operation_enum::set>{});
193-
}
194-
else
195-
{
196-
throw std::runtime_error("split-k is not supported yet!");
197-
}
198-
};
199-
200-
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
201-
return ave_time;
160+
if(args.k_batch == 1)
161+
{
162+
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
163+
ck_tile::memory_operation_enum::set>{});
164+
}
165+
else
166+
{
167+
throw std::runtime_error("split-k is not supported yet!");
168+
}
202169
}
203170
};

example/ck_tile/03_gemm/run_gemm_example.inc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,17 @@ void permute_tensor_b(Tensor& tensor)
6363
GemmConfig::TransposeC,
6464
GemmConfig::UseStructuredSparsity>;
6565

66-
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
67-
BDataType,
68-
AccDataType,
69-
GemmShape,
70-
GemmUniversalTraits,
71-
GemmConfig::Scheduler,
72-
true,
73-
ck_tile::TailNumber::Full>;
66+
using UniversalGemmProblem =
67+
ck_tile::UniversalGemmPipelineProblem<ADataType,
68+
BDataType,
69+
AccDataType,
70+
GemmShape,
71+
GemmUniversalTraits,
72+
GemmConfig::Scheduler,
73+
ck_tile::element_wise::PassThrough,
74+
ck_tile::element_wise::PassThrough,
75+
ADataType,
76+
true>;
7477

7578
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
7679
UniversalGemmProblem>;

0 commit comments

Comments
 (0)