Skip to content

Commit 5ba2dea

Browse files
Add rdnaWaves occupancy filter to greedy tuning phases 1 and 2 (#2311)
Greedy tuning for attention kernels on RDNA was missing the rdnaWaves filter that exhaustive tuning already has, causing a massive search space explosion (~8k configs/problem vs ~300-600 in exhaustive). This made greedy ~13x slower than exhaustive on RDNA targets.
1 parent b235766 commit 5ba2dea

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ static void createGemmGemmTuningRangeGreedyPhase1(
269269
std::mt19937 rng(seed);
270270
int64_t waveSize =
271271
rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize;
272+
bool isWmma = archInfo.isWmma(gemmGemmOp);
273+
int64_t numEUPerCU = archInfo.numEUPerCU;
272274

273275
int64_t outputSwizzle{2}, wavesPerEU{0};
274276
for (uint32_t gemm0MPerBlock : params[0]) {
@@ -302,6 +304,13 @@ static void createGemmGemmTuningRangeGreedyPhase1(
302304
uint32_t splitKFactor =
303305
optimalSplitKFactors[rng() % optimalSplitKFactors.size()];
304306

307+
if (isWmma) {
308+
int64_t rdnaWaves = (gemm0MPerBlock / gemmMPerWave) *
309+
(gemm0NPerBlock / gemmNPerWave);
310+
if (rdnaWaves < numEUPerCU)
311+
continue;
312+
}
313+
305314
auto gemmGemmParams = GemmGemmParamsAttr::get(
306315
gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock,
307316
gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave,
@@ -335,6 +344,7 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
335344
bool isWmma = archInfo.isWmma(gemmGemmOp);
336345
int64_t waveSize =
337346
rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize;
347+
int64_t numEUPerCU = archInfo.numEUPerCU;
338348
int64_t outputSwizzle{2}, wavesPerEU{0};
339349
OpBuilder b(gemmGemmOp.getContext());
340350

@@ -356,6 +366,12 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
356366
for (uint32_t gemmKPerBlock : validRangeGemmGemmParams[2]) {
357367
for (uint32_t gemmMPerWave : mPerWaveRange) {
358368
for (uint32_t gemmNPerWave : nPerWaveRange) {
369+
if (isWmma) {
370+
int64_t rdnaWaves =
371+
(gemm0MPerBlock / gemmMPerWave) * (gemm0NPerBlock / gemmNPerWave);
372+
if (rdnaWaves < numEUPerCU)
373+
continue;
374+
}
359375
for (uint32_t gemmMnPerXdl : validRangeGemmGemmParams[3]) {
360376
for (uint32_t gemmKPack : validRangeGemmGemmParams[4]) {
361377
for (int64_t splitKFactor : optimalSplitKFactors) {

0 commit comments

Comments
 (0)