Skip to content

Commit 9b206aa

Browse files
Merge branch 'develop' into dpp-refactor-blockwise-reduce
2 parents 7c20cba + 5ba2dea commit 9b206aa

3 files changed

Lines changed: 16 additions & 21 deletions

File tree

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -423,21 +423,6 @@ def MIGraphX_PoolingOp :
423423
}];
424424
}
425425

426-
def MIGraphX_FlattenOp :
427-
MIGraphX_Op<"flatten">,
428-
Arguments<(ins AnyMIXRShaped:$input,
429-
I64Attr:$axis
430-
)>,
431-
Results<(outs AnyMIXRShaped:$output)> {
432-
let summary = "Flatten tensor";
433-
let description = [{
434-
The `migraphx.flatten` op.
435-
}];
436-
let assemblyFormat = [{
437-
$input attr-dict `:` type($input) `->` type($output)
438-
}];
439-
}
440-
441426
def MIGraphX_TransposeOp :
442427
MIGraphX_Op<"transpose">,
443428
Arguments<(ins AnyMIXRShaped:$input,

mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,8 @@ static void createGemmGemmTuningRangeGreedyPhase1(
269269
std::mt19937 rng(seed);
270270
int64_t waveSize =
271271
rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize;
272+
bool isWmma = archInfo.isWmma(gemmGemmOp);
273+
int64_t numEUPerCU = archInfo.numEUPerCU;
272274

273275
int64_t outputSwizzle{2}, wavesPerEU{0};
274276
for (uint32_t gemm0MPerBlock : params[0]) {
@@ -302,6 +304,13 @@ static void createGemmGemmTuningRangeGreedyPhase1(
302304
uint32_t splitKFactor =
303305
optimalSplitKFactors[rng() % optimalSplitKFactors.size()];
304306

307+
if (isWmma) {
308+
int64_t rdnaWaves = (gemm0MPerBlock / gemmMPerWave) *
309+
(gemm0NPerBlock / gemmNPerWave);
310+
if (rdnaWaves < numEUPerCU)
311+
continue;
312+
}
313+
305314
auto gemmGemmParams = GemmGemmParamsAttr::get(
306315
gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock,
307316
gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave,
@@ -335,6 +344,7 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
335344
bool isWmma = archInfo.isWmma(gemmGemmOp);
336345
int64_t waveSize =
337346
rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize;
347+
int64_t numEUPerCU = archInfo.numEUPerCU;
338348
int64_t outputSwizzle{2}, wavesPerEU{0};
339349
OpBuilder b(gemmGemmOp.getContext());
340350

@@ -356,6 +366,12 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
356366
for (uint32_t gemmKPerBlock : validRangeGemmGemmParams[2]) {
357367
for (uint32_t gemmMPerWave : mPerWaveRange) {
358368
for (uint32_t gemmNPerWave : nPerWaveRange) {
369+
if (isWmma) {
370+
int64_t rdnaWaves =
371+
(gemm0MPerBlock / gemmMPerWave) * (gemm0NPerBlock / gemmNPerWave);
372+
if (rdnaWaves < numEUPerCU)
373+
continue;
374+
}
359375
for (uint32_t gemmMnPerXdl : validRangeGemmGemmParams[3]) {
360376
for (uint32_t gemmKPack : validRangeGemmGemmParams[4]) {
361377
for (int64_t splitKFactor : optimalSplitKFactors) {

mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -73,12 +73,6 @@ func.func @func_pooling(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.
7373
func.return
7474
}
7575

76-
func.func @func_flatten(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
77-
// expected-error @+1{{failed to legalize operation 'migraphx.flatten'}}
78-
migraphx.flatten %arg0 {axis = 0 : i64}: <1x1xf32, 1x1> -> <1xf32, 1>
79-
func.return
80-
}
81-
8276
func.func @func_slice(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
8377
// expected-error @+1{{failed to legalize operation 'migraphx.slice'}}
8478
migraphx.slice %arg0 {axes = [0], ends = [1], starts = [0]}: <1x1xf32, 1x1> -> <1x1xf32, 1x1>

0 commit comments

Comments
 (0)