Skip to content

Commit ba1a9dd

Browse files
Split permlane reduction into SerialPermlane and PR2-Permlane paths for Navi4x
Restructure the permlanex16_var reduction logic into two distinct paths, gated by awareness of the 2D per-wave thread layout (mTidPerWave/nTidPerWave): (1) SerialPermlane (blockSize <= nrDimProd, the product of the non-reduction dimension sizes): XOR butterfly reduction in registers for a power-of-2 rDimSize equal to mTidPerWave; uses LDS only for the final broadcast. (2) PR2-Permlane (blockSize > nrDimProd): register-only cross-half-wave reduction for partialR=2 when nTidPerWave=16 (lanes 0-15 exchange with lanes 16-31), avoiding the initial LDS store and barrier. Both paths now require has2DThreadLayout and wave32. The PR2-Permlane path is moved into the blockSize > nrDimProd branch alongside the DPP and LDS-Tree fallbacks. Comments are trimmed for brevity.
1 parent 7947246 commit ba1a9dd

1 file changed

Lines changed: 43 additions & 29 deletions

File tree

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -712,9 +712,7 @@ struct BlockwiseReduceRewritePattern
712712
return dimProduct;
713713
}
714714

715-
// Extract per-wave thread counts from the tid slice view by looking for
716-
// "m_tid" and "n_tid" named dimensions in the Merge transform that
717-
// decomposes "tid". Works for both WMMA and MFMA architectures.
715+
// Extract m_tid and n_tid counts from the tid slice view's Merge transform.
718716
static std::pair<int64_t, int64_t>
719717
getPerWaveThreadCounts(ArrayAttr tidSliceView) {
720718
if (tidSliceView.empty())
@@ -738,9 +736,8 @@ struct BlockwiseReduceRewritePattern
738736
return {0, 0};
739737
}
740738

741-
// Register-only cross-half-wave reduction using v_permlanex16_var_b32.
742-
// Each lane exchanges its value with the corresponding lane in the other
743-
// half-wave (lane i <-> lane i+16) and reduces. Requires wave32 (RDNA).
739+
// Cross-half-wave reduction via v_permlanex16_var_b32 (wave32 only).
740+
// Lane i exchanges with lane i+16 and reduces.
744741
void permlaneX16VarReduce(ConversionPatternRewriter &rewriter, Location loc,
745742
Value partialReductionBuffer, Value tid,
746743
int64_t nrDimSize, int64_t waveSize,
@@ -1284,14 +1281,29 @@ struct BlockwiseReduceRewritePattern
12841281
StringAttr arch = rock::getArchValue(op);
12851282
int64_t waveSize = rock::lookupArchInfo(arch).waveSize;
12861283

1287-
// Permlane-reduce: register-only cross-half-wave reduction using
1288-
// v_permlanex16_var_b32 (GFX12+). Avoids the initial LDS store+barrier
1289-
// by performing reduction directly in registers before writing to LDS.
12901284
int64_t partialR = partialRegTensorShape[rDim];
1285+
1286+
// PR2-Permlane: register-only cross-half-wave reduction for partialR=2
1287+
// on wave32 when nTidPerWave=16 (lanes 0-15 <-> 16-31).
1288+
auto [mTidPerWave, nTidPerWave] =
1289+
getPerWaveThreadCounts(op.getTidSubTileSliceView());
1290+
bool has2DThreadLayout = (mTidPerWave > 0 && nTidPerWave > 0);
12911291
bool canUsePermlaneReduce =
1292-
(waveSize == 32 && partialR == 2);
1292+
(has2DThreadLayout && waveSize == 32 &&
1293+
partialR == 2 && nTidPerWave == 16);
1294+
1295+
// SerialPermlane: XOR butterfly reduction via permlanex16_var for
1296+
// blockSize <= nrDimProd on wave32. Requires power-of-2 rDimSize == mTidPerWave.
1297+
bool canUseSerialPermlane = false;
1298+
if (has2DThreadLayout && waveSize == 32 &&
1299+
blockSize <= nonReductionDimSizeProduct) {
1300+
int64_t rDimSize = partialR;
1301+
canUseSerialPermlane = (rDimSize >= 2) &&
1302+
llvm::isPowerOf2_64(rDimSize) &&
1303+
(rDimSize == mTidPerWave);
1304+
}
12931305

1294-
if (!canUsePermlaneReduce) {
1306+
if (!canUsePermlaneReduce && !canUseSerialPermlane) {
12951307
storePartialReductionstoLDS(rewriter, loc, partialReductionBuffer,
12961308
workspaceLDSBuffer, inputBlockSubTile2dView,
12971309
inputThreadSubTile2dView, tidSubTileSliceView,
@@ -1301,7 +1313,9 @@ struct BlockwiseReduceRewritePattern
13011313
// Following RAII scope will create reduction loops.
13021314
{
13031315
if (blockSize <= nonReductionDimSizeProduct) {
1304-
if (canUsePermlaneReduce) {
1316+
if (canUseSerialPermlane) {
1317+
// Butterfly reduction in registers via permlanex16_var.
1318+
// Uses LDS only for final broadcast.
13051319
int64_t nrDimSize = inputThreadSubTile2dShape[nrDim];
13061320
permlaneX16VarReduce(rewriter, loc, partialReductionBuffer, tid,
13071321
nrDimSize, waveSize, elemType, op);
@@ -1420,6 +1434,20 @@ struct BlockwiseReduceRewritePattern
14201434
/*withBarrier=*/true);
14211435
} // end NR-Large-Tree else
14221436
} else {
1437+
if (canUsePermlaneReduce) {
1438+
// Register-only reduction for partialR=2 via permlanex16_var.
1439+
int64_t nrDimSize = inputThreadSubTile2dShape[nrDim];
1440+
permlaneX16VarReduce(rewriter, loc, partialReductionBuffer, tid,
1441+
nrDimSize, waveSize, elemType, op);
1442+
storePartialReductionstoLDS(
1443+
rewriter, loc, partialReductionBuffer, workspaceLDSBuffer,
1444+
inputBlockSubTile2dView, inputThreadSubTile2dView,
1445+
tidSubTileSliceView, toFlatLDSView);
1446+
readReducedResultsFromLDS(rewriter, loc, op, workspaceLDSBuffer,
1447+
outputReg, inputViewArrayAttr, axis,
1448+
partialRegTensorShape[rDim], tid,
1449+
/*withBarrier=*/true);
1450+
} else {
14231451
// This means there are more threads than elements to be reduced.
14241452
ArrayAttr threadToTensorViewTrs =
14251453
createThreadViewforNRSmallerThanThreads(loc, partialRegTensorShape,
@@ -1437,27 +1465,15 @@ struct BlockwiseReduceRewritePattern
14371465
getMaxVectorization(threadToLDSViewed, rIterDim);
14381466
int64_t rIterVectorLen = rIterVectorRes.max;
14391467

1440-
// Use DPP-based subgroup reduction when all conditions are met:
1441-
// 1. Power-of-2 reduction threads (required by SubgroupReduceOp)
1442-
// 2. More than 1 reduction thread (at least 2 for cross-lane work)
1443-
// 3. partial_r > 2 (DPP overhead not justified for partial_r=2)
1444-
// 4. Reduction threads fit within a single wave
1445-
// 5. Exact thread packing: blockSize == clusterSize *
1446-
// nonReductionDimSizeProduct. This guarantees every thread maps to
1447-
// a valid (nrtid, rtid) pair, so LDS coordinates derived from them
1448-
// are in-bounds.
1449-
// Otherwise, fall back to LDS-based tree reduction.
1468+
// DPP subgroup reduction: power-of-2 threads, partialR>2, fits in wave.
14501469
int64_t maxActiveReductionThreads = threadViewShape[rTidDim];
14511470
int64_t clusterSize = llvm::PowerOf2Ceil(maxActiveReductionThreads);
1452-
int64_t partialR = partialRegTensorShape[rDim];
14531471
bool canUseDPP = llvm::isPowerOf2_64(maxActiveReductionThreads) &&
14541472
(maxActiveReductionThreads > 1) && (partialR > 2) &&
14551473
(maxActiveReductionThreads <= waveSize) &&
14561474
(blockSize == maxActiveReductionThreads *
14571475
nonReductionDimSizeProduct);
1458-
// DPP path: contiguous threads reduce together (rtid = tid % cluster).
1459-
// Tree path: scattered layout (rtid = tid /
1460-
// nonReductionDimSizeProduct).
1476+
// DPP: rtid = tid % cluster. Tree: rtid = tid / nrDimProd.
14611477
Value rtid, nrtid;
14621478
if (canUseDPP) {
14631479
assert(llvm::isPowerOf2_64(clusterSize) &&
@@ -1528,8 +1544,6 @@ struct BlockwiseReduceRewritePattern
15281544
}
15291545
}
15301546

1531-
// Cross-lane reduction: DPP path uses SubgroupReduceOp with
1532-
// cluster_size, tree path uses iterative LDS load/reduce/store.
15331547
if (canUseDPP) {
15341548
SmallVector<Value, 4> inits{nrtid, rtid, zeroConstantOp};
15351549
SmallVector<int64_t> bounds{1, 1, 1};
@@ -1585,7 +1599,6 @@ struct BlockwiseReduceRewritePattern
15851599
LDSBarrierOp::create(rewriter, loc);
15861600

15871601
} else {
1588-
// Tree reduction path: needs LDS for inter-thread communication
15891602
int64_t ceilPowerOf2 =
15901603
llvm::PowerOf2Ceil(maxActiveReductionThreads) / 2;
15911604
if (hasThreadwiseReduction) {
@@ -1666,6 +1679,7 @@ struct BlockwiseReduceRewritePattern
16661679
outputReg, inputViewArrayAttr, axis,
16671680
partialRegTensorShape[rDim], tid,
16681681
/*withBarrier=*/false);
1682+
}
16691683
}
16701684
rewriter.eraseOp(op);
16711685
return success();

0 commit comments

Comments
 (0)