@@ -712,9 +712,7 @@ struct BlockwiseReduceRewritePattern
712712 return dimProduct;
713713 }
714714
715- // Extract per-wave thread counts from the tid slice view by looking for
716- // "m_tid" and "n_tid" named dimensions in the Merge transform that
717- // decomposes "tid". Works for both WMMA and MFMA architectures.
715+ // Extract m_tid and n_tid counts from the tid slice view's Merge transform.
718716 static std::pair<int64_t , int64_t >
719717 getPerWaveThreadCounts (ArrayAttr tidSliceView) {
720718 if (tidSliceView.empty ())
@@ -738,9 +736,8 @@ struct BlockwiseReduceRewritePattern
738736 return {0 , 0 };
739737 }
740738
741- // Register-only cross-half-wave reduction using v_permlanex16_var_b32.
742- // Each lane exchanges its value with the corresponding lane in the other
743- // half-wave (lane i <-> lane i+16) and reduces. Requires wave32 (RDNA).
739+ // Cross-half-wave reduction via v_permlanex16_var_b32 (wave32 only).
740+ // Lane i exchanges with lane i+16 and reduces.
744741 void permlaneX16VarReduce (ConversionPatternRewriter &rewriter, Location loc,
745742 Value partialReductionBuffer, Value tid,
746743 int64_t nrDimSize, int64_t waveSize,
@@ -1284,14 +1281,29 @@ struct BlockwiseReduceRewritePattern
12841281 StringAttr arch = rock::getArchValue (op);
12851282 int64_t waveSize = rock::lookupArchInfo (arch).waveSize ;
12861283
1287- // Permlane-reduce: register-only cross-half-wave reduction using
1288- // v_permlanex16_var_b32 (GFX12+). Avoids the initial LDS store+barrier
1289- // by performing reduction directly in registers before writing to LDS.
12901284 int64_t partialR = partialRegTensorShape[rDim];
1285+
1286+ // PR2-Permlane: register-only cross-half-wave reduction for partialR=2
1287+ // on wave32 when nTidPerWave=16 (lanes 0-15 <-> 16-31).
1288+ auto [mTidPerWave , nTidPerWave] =
1289+ getPerWaveThreadCounts (op.getTidSubTileSliceView ());
1290+ bool has2DThreadLayout = (mTidPerWave > 0 && nTidPerWave > 0 );
12911291 bool canUsePermlaneReduce =
1292- (waveSize == 32 && partialR == 2 );
1292+ (has2DThreadLayout && waveSize == 32 &&
1293+ partialR == 2 && nTidPerWave == 16 );
1294+
1295+ // SerialPermlane: XOR butterfly reduction via permlanex16_var for
1296+ // blockSize <= nrDimProd on wave32. Requires power-of-2 rDimSize == mTidPerWave.
1297+ bool canUseSerialPermlane = false ;
1298+ if (has2DThreadLayout && waveSize == 32 &&
1299+ blockSize <= nonReductionDimSizeProduct) {
1300+ int64_t rDimSize = partialR;
1301+ canUseSerialPermlane = (rDimSize >= 2 ) &&
1302+ llvm::isPowerOf2_64 (rDimSize) &&
1303+ (rDimSize == mTidPerWave );
1304+ }
12931305
1294- if (!canUsePermlaneReduce) {
1306+ if (!canUsePermlaneReduce && !canUseSerialPermlane ) {
12951307 storePartialReductionstoLDS (rewriter, loc, partialReductionBuffer,
12961308 workspaceLDSBuffer, inputBlockSubTile2dView,
12971309 inputThreadSubTile2dView, tidSubTileSliceView,
@@ -1301,7 +1313,9 @@ struct BlockwiseReduceRewritePattern
13011313 // Following RAII scope will create reduction loops.
13021314 {
13031315 if (blockSize <= nonReductionDimSizeProduct) {
1304- if (canUsePermlaneReduce) {
1316+ if (canUseSerialPermlane) {
1317+ // Butterfly reduction in registers via permlanex16_var.
1318+ // Uses LDS only for final broadcast.
13051319 int64_t nrDimSize = inputThreadSubTile2dShape[nrDim];
13061320 permlaneX16VarReduce (rewriter, loc, partialReductionBuffer, tid,
13071321 nrDimSize, waveSize, elemType, op);
@@ -1420,6 +1434,20 @@ struct BlockwiseReduceRewritePattern
14201434 /* withBarrier=*/ true );
14211435 } // end NR-Large-Tree else
14221436 } else {
1437+ if (canUsePermlaneReduce) {
1438+ // Register-only reduction for partialR=2 via permlanex16_var.
1439+ int64_t nrDimSize = inputThreadSubTile2dShape[nrDim];
1440+ permlaneX16VarReduce (rewriter, loc, partialReductionBuffer, tid,
1441+ nrDimSize, waveSize, elemType, op);
1442+ storePartialReductionstoLDS (
1443+ rewriter, loc, partialReductionBuffer, workspaceLDSBuffer,
1444+ inputBlockSubTile2dView, inputThreadSubTile2dView,
1445+ tidSubTileSliceView, toFlatLDSView);
1446+ readReducedResultsFromLDS (rewriter, loc, op, workspaceLDSBuffer,
1447+ outputReg, inputViewArrayAttr, axis,
1448+ partialRegTensorShape[rDim], tid,
1449+ /* withBarrier=*/ true );
1450+ } else {
14231451 // This means there are more threads than elements to be reduced.
14241452 ArrayAttr threadToTensorViewTrs =
14251453 createThreadViewforNRSmallerThanThreads (loc, partialRegTensorShape,
@@ -1437,27 +1465,15 @@ struct BlockwiseReduceRewritePattern
14371465 getMaxVectorization (threadToLDSViewed, rIterDim);
14381466 int64_t rIterVectorLen = rIterVectorRes.max ;
14391467
1440- // Use DPP-based subgroup reduction when all conditions are met:
1441- // 1. Power-of-2 reduction threads (required by SubgroupReduceOp)
1442- // 2. More than 1 reduction thread (at least 2 for cross-lane work)
1443- // 3. partial_r > 2 (DPP overhead not justified for partial_r=2)
1444- // 4. Reduction threads fit within a single wave
1445- // 5. Exact thread packing: blockSize == clusterSize *
1446- // nonReductionDimSizeProduct. This guarantees every thread maps to
1447- // a valid (nrtid, rtid) pair, so LDS coordinates derived from them
1448- // are in-bounds.
1449- // Otherwise, fall back to LDS-based tree reduction.
1468+ // DPP subgroup reduction: power-of-2 threads, partialR>2, fits in wave.
14501469 int64_t maxActiveReductionThreads = threadViewShape[rTidDim];
14511470 int64_t clusterSize = llvm::PowerOf2Ceil (maxActiveReductionThreads);
1452- int64_t partialR = partialRegTensorShape[rDim];
14531471 bool canUseDPP = llvm::isPowerOf2_64 (maxActiveReductionThreads) &&
14541472 (maxActiveReductionThreads > 1 ) && (partialR > 2 ) &&
14551473 (maxActiveReductionThreads <= waveSize) &&
14561474 (blockSize == maxActiveReductionThreads *
14571475 nonReductionDimSizeProduct);
1458- // DPP path: contiguous threads reduce together (rtid = tid % cluster).
1459- // Tree path: scattered layout (rtid = tid /
1460- // nonReductionDimSizeProduct).
1476+ // DPP: rtid = tid % cluster. Tree: rtid = tid / nrDimProd.
14611477 Value rtid, nrtid;
14621478 if (canUseDPP) {
14631479 assert (llvm::isPowerOf2_64 (clusterSize) &&
@@ -1528,8 +1544,6 @@ struct BlockwiseReduceRewritePattern
15281544 }
15291545 }
15301546
1531- // Cross-lane reduction: DPP path uses SubgroupReduceOp with
1532- // cluster_size, tree path uses iterative LDS load/reduce/store.
15331547 if (canUseDPP) {
15341548 SmallVector<Value, 4 > inits{nrtid, rtid, zeroConstantOp};
15351549 SmallVector<int64_t > bounds{1 , 1 , 1 };
@@ -1585,7 +1599,6 @@ struct BlockwiseReduceRewritePattern
15851599 LDSBarrierOp::create (rewriter, loc);
15861600
15871601 } else {
1588- // Tree reduction path: needs LDS for inter-thread communication
15891602 int64_t ceilPowerOf2 =
15901603 llvm::PowerOf2Ceil (maxActiveReductionThreads) / 2 ;
15911604 if (hasThreadwiseReduction) {
@@ -1666,6 +1679,7 @@ struct BlockwiseReduceRewritePattern
16661679 outputReg, inputViewArrayAttr, axis,
16671680 partialRegTensorShape[rDim], tid,
16681681 /* withBarrier=*/ false );
1682+ }
16691683 }
16701684 rewriter.eraseOp (op);
16711685 return success ();
0 commit comments