@@ -269,6 +269,8 @@ static void createGemmGemmTuningRangeGreedyPhase1(
269269 std::mt19937 rng (seed);
270270 int64_t waveSize =
271271 rock::lookupArchInfo (rock::getArchValue (gemmGemmOp)).waveSize ;
272+ bool isWmma = archInfo.isWmma (gemmGemmOp);
273+ int64_t numEUPerCU = archInfo.numEUPerCU ;
272274
273275 int64_t outputSwizzle{2 }, wavesPerEU{0 };
274276 for (uint32_t gemm0MPerBlock : params[0 ]) {
@@ -302,6 +304,13 @@ static void createGemmGemmTuningRangeGreedyPhase1(
302304 uint32_t splitKFactor =
303305 optimalSplitKFactors[rng () % optimalSplitKFactors.size ()];
304306
307+ if (isWmma) {
308+ int64_t rdnaWaves = (gemm0MPerBlock / gemmMPerWave) *
309+ (gemm0NPerBlock / gemmNPerWave);
310+ if (rdnaWaves < numEUPerCU)
311+ continue ;
312+ }
313+
305314 auto gemmGemmParams = GemmGemmParamsAttr::get (
306315 gemmGemmOp.getContext (), gemm0MPerBlock, gemm1MPerBlock,
307316 gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave,
@@ -335,6 +344,7 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
335344 bool isWmma = archInfo.isWmma (gemmGemmOp);
336345 int64_t waveSize =
337346 rock::lookupArchInfo (rock::getArchValue (gemmGemmOp)).waveSize ;
347+ int64_t numEUPerCU = archInfo.numEUPerCU ;
338348 int64_t outputSwizzle{2 }, wavesPerEU{0 };
339349 OpBuilder b (gemmGemmOp.getContext ());
340350
@@ -356,6 +366,12 @@ createGemmGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace,
356366 for (uint32_t gemmKPerBlock : validRangeGemmGemmParams[2 ]) {
357367 for (uint32_t gemmMPerWave : mPerWaveRange ) {
358368 for (uint32_t gemmNPerWave : nPerWaveRange) {
369+ if (isWmma) {
370+ int64_t rdnaWaves =
371+ (gemm0MPerBlock / gemmMPerWave) * (gemm0NPerBlock / gemmNPerWave);
372+ if (rdnaWaves < numEUPerCU)
373+ continue ;
374+ }
359375 for (uint32_t gemmMnPerXdl : validRangeGemmGemmParams[3 ]) {
360376 for (uint32_t gemmKPack : validRangeGemmGemmParams[4 ]) {
361377 for (int64_t splitKFactor : optimalSplitKFactors) {
0 commit comments