Skip to content

Commit 5348b57

Browse files
[rocm-libraries] ROCm/rocm-libraries#5863 (commit 31d9247)
[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863) ## Motivation The PermuteN epilogue was previously embedded within cshuffle_epilogue.hpp, despite having fundamentally different behaviour. Coupling these two independent strategies in one file introduced unnecessary complexity, SFINAE guards, and a dual operator() overload selected at compile time via the TiledMMAPermuteN_ template parameter. This PR separates PermuteN into its own standalone file (permuten_epilogue.hpp), simplifying both implementations and making the codebase easier to maintain and extend independently. ## Technical Details **New file — permuten_epilogue.hpp:** contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from the permuteN code path in cshuffle_epilogue.hpp. **Cleanup of cshuffle_epilogue.hpp:** - Removed the TiledMMAPermuteN_ template parameter from `CShuffleEpilogueProblem` - Removed the SFINAE-guarded permuteN operator() overload - Removed the EnablePermuateN_ SFINAE alias - CShuffle now contains only CShuffle logic; EightWave support (an independent feature) is retained **Consumer migration:** All consumer files now use compile-time epilogue selection via `std::conditional_t`: `using GemmEpilogue = std::conditional_t< TiledMMAPermuteN, PermuteNEpilogue<PermuteNEpilogueProblem<...>>, CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;` **Files modified:** - flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp, mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples - run_gemm_quant_example.inc — block-scale GEMM example - gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker - test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp, test_gemm_pipeline_util.hpp — test utilities - universal_gemm_invoker.hpp — universal GEMM invoker - epilogue.hpp — header updated to include permuten_epilogue.hpp ## Submission Checklist - [x] 
Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 5d2fce8 commit 5348b57

14 files changed

Lines changed: 728 additions & 333 deletions

File tree

example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,45 @@ struct WeightPreshuffleInvoker
5858
using GemmPipeline = typename PipelineTypeTraits<
5959
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
6060

61-
using GemmEpilogue = ck_tile::CShuffleEpilogue<
62-
ck_tile::CShuffleEpilogueProblem<ADataType,
63-
BDataType,
64-
DsDataType,
65-
AccDataType,
66-
CDataType,
67-
DsLayout,
68-
ELayout,
69-
CDEElementWise,
70-
TilePartitioner::MPerBlock,
71-
TilePartitioner::NPerBlock,
72-
GemmConfig::M_Warp,
73-
GemmConfig::N_Warp,
74-
GemmConfig::M_Warp_Tile,
75-
GemmConfig::N_Warp_Tile,
76-
GemmConfig::K_Warp_Tile,
77-
UniversalGemmProblem::TransposeC,
78-
GemmConfig::NumWaveGroups,
79-
false,
80-
1,
81-
GemmConfig::TiledMMAPermuteN>>;
61+
using GemmEpilogue = std::conditional_t<
62+
GemmConfig::TiledMMAPermuteN,
63+
ck_tile::PermuteNEpilogue<
64+
ck_tile::PermuteNEpilogueProblem<ADataType,
65+
BDataType,
66+
DsDataType,
67+
AccDataType,
68+
CDataType,
69+
DsLayout,
70+
ELayout,
71+
CDEElementWise,
72+
TilePartitioner::MPerBlock,
73+
TilePartitioner::NPerBlock,
74+
GemmConfig::M_Warp,
75+
GemmConfig::N_Warp,
76+
GemmConfig::M_Warp_Tile,
77+
GemmConfig::N_Warp_Tile,
78+
GemmConfig::K_Warp_Tile,
79+
UniversalGemmProblem::TransposeC,
80+
false,
81+
1>>,
82+
ck_tile::CShuffleEpilogue<
83+
ck_tile::CShuffleEpilogueProblem<ADataType,
84+
BDataType,
85+
DsDataType,
86+
AccDataType,
87+
CDataType,
88+
DsLayout,
89+
ELayout,
90+
CDEElementWise,
91+
TilePartitioner::MPerBlock,
92+
TilePartitioner::NPerBlock,
93+
GemmConfig::M_Warp,
94+
GemmConfig::N_Warp,
95+
GemmConfig::M_Warp_Tile,
96+
GemmConfig::N_Warp_Tile,
97+
GemmConfig::K_Warp_Tile,
98+
UniversalGemmProblem::TransposeC,
99+
GemmConfig::NumWaveGroups>>>;
82100
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
83101
auto kargs = Kernel::MakeKernelArgs(args);
84102

example/ck_tile/03_gemm/universal_gemm_invoker.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ struct UniversalInvoker
8484
GemmConfig::NumWaveGroups,
8585
false, /*FixedVectorSize_*/
8686
1, /*VectorSizeC_*/
87-
false, /*TiledMMAPermuteN_*/
8887
1, /*BlockedXDLN_PerWarp_*/
8988
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
9089

@@ -228,7 +227,6 @@ struct UniversalInvoker
228227
GemmConfig::NumWaveGroups,
229228
false, /*FixedVectorSize_*/
230229
1, /*VectorSizeC_*/
231-
false, /*TiledMMAPermuteN_*/
232230
1, /*BlockedXDLN_PerWarp_*/
233231
GemmConfig::DoubleSmemBuffer>>;
234232

example/ck_tile/18_flatmm/flatmm_basic.cpp

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
188188
using CodegenFlatmmPipeline =
189189
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
190190

191-
using GemmEpilogue = ck_tile::CShuffleEpilogue<
192-
ck_tile::CShuffleEpilogueProblem<ADataType,
193-
BDataType,
194-
DsDatatype,
195-
AccDataType,
196-
CDataType,
197-
DsLayout,
198-
ELayout,
199-
CDEElementWise,
200-
TilePartitioner::MPerBlock,
201-
TilePartitioner::NPerBlock,
202-
FlatmmConfig::M_Warp,
203-
FlatmmConfig::N_Warp,
204-
FlatmmConfig::M_Warp_Tile,
205-
FlatmmConfig::N_Warp_Tile,
206-
FlatmmConfig::K_Warp_Tile,
207-
CodegenPipelineProblem::TransposeC,
208-
FlatmmConfig::NumWaveGroups,
209-
false,
210-
1,
211-
FlatmmConfig::TiledMMAPermuteN>>;
191+
using GemmEpilogue = std::conditional_t<
192+
FlatmmConfig::TiledMMAPermuteN,
193+
ck_tile::PermuteNEpilogue<
194+
ck_tile::PermuteNEpilogueProblem<ADataType,
195+
BDataType,
196+
DsDatatype,
197+
AccDataType,
198+
CDataType,
199+
DsLayout,
200+
ELayout,
201+
CDEElementWise,
202+
TilePartitioner::MPerBlock,
203+
TilePartitioner::NPerBlock,
204+
FlatmmConfig::M_Warp,
205+
FlatmmConfig::N_Warp,
206+
FlatmmConfig::M_Warp_Tile,
207+
FlatmmConfig::N_Warp_Tile,
208+
FlatmmConfig::K_Warp_Tile,
209+
CodegenPipelineProblem::TransposeC,
210+
false,
211+
1>>,
212+
ck_tile::CShuffleEpilogue<
213+
ck_tile::CShuffleEpilogueProblem<ADataType,
214+
BDataType,
215+
DsDatatype,
216+
AccDataType,
217+
CDataType,
218+
DsLayout,
219+
ELayout,
220+
CDEElementWise,
221+
TilePartitioner::MPerBlock,
222+
TilePartitioner::NPerBlock,
223+
FlatmmConfig::M_Warp,
224+
FlatmmConfig::N_Warp,
225+
FlatmmConfig::M_Warp_Tile,
226+
FlatmmConfig::N_Warp_Tile,
227+
FlatmmConfig::K_Warp_Tile,
228+
CodegenPipelineProblem::TransposeC,
229+
FlatmmConfig::NumWaveGroups>>>;
212230

213231
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
214232
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
230248
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
231249
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
232250
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
251+
<< "epilogue: " << GemmEpilogue::GetName() << "\n"
233252
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
234253
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
235254
<< std::endl;

example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,48 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
139139

140140
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
141141

142-
using GemmEpilogue = ck_tile::CShuffleEpilogue<
143-
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
144-
ComputeDataType,
145-
DsDatatype,
146-
AccDataType,
147-
CDataType,
148-
DsLayout,
149-
ELayout,
150-
CDEElementWise,
151-
TilePartitioner::MPerBlock,
152-
TilePartitioner::NPerBlock,
153-
FlatmmConfig::M_Warp,
154-
FlatmmConfig::N_Warp,
155-
FlatmmConfig::M_Warp_Tile,
156-
FlatmmConfig::N_Warp_Tile,
157-
FlatmmConfig::K_Warp_Tile,
158-
CodegenPipelineProblem::TransposeC,
159-
FlatmmConfig::NumWaveGroups,
160-
false,
161-
1,
162-
FlatmmConfig::TiledMMAPermuteN,
163-
BlockedXDLN_PerWarp>>;
142+
using GemmEpilogue = std::conditional_t<
143+
FlatmmConfig::TiledMMAPermuteN,
144+
ck_tile::PermuteNEpilogue<
145+
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
146+
ComputeDataType,
147+
DsDatatype,
148+
AccDataType,
149+
CDataType,
150+
DsLayout,
151+
ELayout,
152+
CDEElementWise,
153+
TilePartitioner::MPerBlock,
154+
TilePartitioner::NPerBlock,
155+
FlatmmConfig::M_Warp,
156+
FlatmmConfig::N_Warp,
157+
FlatmmConfig::M_Warp_Tile,
158+
FlatmmConfig::N_Warp_Tile,
159+
FlatmmConfig::K_Warp_Tile,
160+
CodegenPipelineProblem::TransposeC,
161+
false,
162+
1>>,
163+
ck_tile::CShuffleEpilogue<
164+
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
165+
ComputeDataType,
166+
DsDatatype,
167+
AccDataType,
168+
CDataType,
169+
DsLayout,
170+
ELayout,
171+
CDEElementWise,
172+
TilePartitioner::MPerBlock,
173+
TilePartitioner::NPerBlock,
174+
FlatmmConfig::M_Warp,
175+
FlatmmConfig::N_Warp,
176+
FlatmmConfig::M_Warp_Tile,
177+
FlatmmConfig::N_Warp_Tile,
178+
FlatmmConfig::K_Warp_Tile,
179+
CodegenPipelineProblem::TransposeC,
180+
FlatmmConfig::NumWaveGroups,
181+
false,
182+
1,
183+
BlockedXDLN_PerWarp>>>;
164184

165185
using CodegenFlatmmPipeline = std::conditional_t<
166186
MXFP4_Pipeline,

example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,28 +108,48 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
108108
using CodegenFlatmmPipeline =
109109
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
110110

111-
using GemmEpilogue = ck_tile::CShuffleEpilogue<
112-
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
113-
ComputeDataType,
114-
DsDatatype,
115-
AccDataType,
116-
CDataType,
117-
DsLayout,
118-
ELayout,
119-
CDEElementWise,
120-
TilePartitioner::MPerBlock,
121-
TilePartitioner::NPerBlock,
122-
FlatmmConfig::M_Warp,
123-
FlatmmConfig::N_Warp,
124-
FlatmmConfig::M_Warp_Tile,
125-
FlatmmConfig::N_Warp_Tile,
126-
FlatmmConfig::K_Warp_Tile,
127-
CodegenPipelineProblem::TransposeC,
128-
FlatmmConfig::NumWaveGroups,
129-
false, // FixedVectorSize
130-
1, // VectorSizeC
131-
FlatmmConfig::TiledMMAPermuteN,
132-
BlockedXDLN_PerWarp>>;
111+
using GemmEpilogue = std::conditional_t<
112+
FlatmmConfig::TiledMMAPermuteN,
113+
ck_tile::PermuteNEpilogue<
114+
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
115+
ComputeDataType,
116+
DsDatatype,
117+
AccDataType,
118+
CDataType,
119+
DsLayout,
120+
ELayout,
121+
CDEElementWise,
122+
TilePartitioner::MPerBlock,
123+
TilePartitioner::NPerBlock,
124+
FlatmmConfig::M_Warp,
125+
FlatmmConfig::N_Warp,
126+
FlatmmConfig::M_Warp_Tile,
127+
FlatmmConfig::N_Warp_Tile,
128+
FlatmmConfig::K_Warp_Tile,
129+
CodegenPipelineProblem::TransposeC,
130+
false, // FixedVectorSize
131+
1>>, // VectorSizeC
132+
ck_tile::CShuffleEpilogue<
133+
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
134+
ComputeDataType,
135+
DsDatatype,
136+
AccDataType,
137+
CDataType,
138+
DsLayout,
139+
ELayout,
140+
CDEElementWise,
141+
TilePartitioner::MPerBlock,
142+
TilePartitioner::NPerBlock,
143+
FlatmmConfig::M_Warp,
144+
FlatmmConfig::N_Warp,
145+
FlatmmConfig::M_Warp_Tile,
146+
FlatmmConfig::N_Warp_Tile,
147+
FlatmmConfig::K_Warp_Tile,
148+
CodegenPipelineProblem::TransposeC,
149+
FlatmmConfig::NumWaveGroups,
150+
false, // FixedVectorSize
151+
1, // VectorSizeC
152+
BlockedXDLN_PerWarp>>>;
133153

134154
using Kernel =
135155
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

0 commit comments

Comments
 (0)