Skip to content

Commit 98c5c85

Browse files
committed
More cleanup
1 parent 288db18 commit 98c5c85

4 files changed

Lines changed: 32 additions & 141 deletions

File tree

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,15 +1418,13 @@ struct BlockwiseReduceRewritePattern
14181418
// Branchless reduction: each thread reads all rTidDim partial
14191419
// values from LDS and reduces locally in registers. This avoids
14201420
// creating conditional branches (scf.if) that split softmax into
1421-
// multiple basic blocks. Without branches, the LLVM backend
1422-
// scheduler can keep V global loads (issued before softmax) in
1423-
// the same basic block, enabling sched_barrier to prevent them
1424-
// from being sunk past softmax computation.
1425-
//
1421+
// multiple basic blocks.
14261422
// Trade-off: every thread does rTidCount LDS reads (instead of
14271423
// log2(rTidCount) conditional reads in the tree reduction). For
14281424
// typical attention configs where rTidCount is small (e.g., 4),
14291425
// this is negligible overhead.
1426+
// TODO: We may have to use a heuristic to determine whether or not to
1427+
// use this depending on the size of rTidCount.
14301428
{
14311429
int64_t rTidCount = threadViewShape[rTidDim];
14321430

@@ -1477,13 +1475,6 @@ struct BlockwiseReduceRewritePattern
14771475
// Write the fully reduced value back to LDS at [nrtid, 0].
14781476
// All threads with the same nrtid compute the same value,
14791477
// so concurrent writes to the same location are safe.
1480-
//
1481-
// NOTE: We cannot use a FillOp shortcut here (even when
1482-
// inputThreadSubTile2dShape[nrDim] == 1) because nrtid
1483-
// (= tid % nonReduceMergeDimSize) does NOT necessarily
1484-
// correspond to the thread's actual non-reduction position
1485-
// in the MFMA layout. The ThreadwiseReadIntoOp uses the
1486-
// correct layout-aware view to read each thread's result.
14871478
{
14881479
Value reducedVal = InBoundsLoadOp::create(
14891480
rewriter, loc, elemType, accReg, zeroConstantOp);

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,6 @@ class LoweringBlockwiseLoadTileOp final
286286
if (isa<LoopLikeOpInterface>(parentOp))
287287
b.setInsertionPoint(op);
288288

289-
// ---- GlobalRead stage ----
290-
// Emit for all types EXCEPT LDSWriteFromRegs (which only does the write).
291289
if (!ldsWriteFromRegs) {
292290
// Use distinct stage name for split-phase V prefetch to avoid
293291
// conflicting with K/Q GlobalRead stages in the same parent scope.
@@ -336,10 +334,6 @@ class LoweringBlockwiseLoadTileOp final
336334
Value one = b.createOrFold<arith::ConstantIndexOp>(loc, 1);
337335
indicesNext[0] =
338336
arith::AddIOp::create(b, loc, indicesNext[0], one).getResult();
339-
340-
// it's acceptable if the indices are out of bounds because we use
341-
// GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch
342-
// documentation in AMDGPUUsage.rst
343337
rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource,
344338
/*extraViews=*/b.getArrayAttr({}),
345339
/*extraIndices=*/indicesNext,
@@ -350,7 +344,7 @@ class LoweringBlockwiseLoadTileOp final
350344
}
351345
}
352346

353-
// For GlobalReadOnly, we're done - skip all write stages.
347+
// For GlobalReadOnly there's nothing further to do.
354348
if (globalReadOnly) {
355349
b.eraseOp(op);
356350
return success();

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 26 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,47 +2788,19 @@ struct GridwiseAttentionAccelRewritePattern
27882788
accelEmitterPtrGemm0->computeOutputConversion(
27892789
rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll);
27902790

2791-
// ================================================================
2792-
// V PREFETCH: Issue global reads for V tile 0 before softmax.
2793-
// ================================================================
2794-
// By issuing V global reads here (before softmax computation),
2795-
// we overlap the ~120+ instructions of softmax work with the
2796-
// global memory access latency for V, matching CK's approach.
2797-
//
2798-
// The flow is:
2799-
// 1. Issue V global reads -> register buffer [HERE, before softmax]
2800-
// 2. Softmax computation [hides load latency]
2801-
// 3. Write V from registers -> LDS [after softmax]
2802-
// 4. GEMM1 first iteration uses V from LDS [peeled iteration]
2803-
// 5. Remaining GEMM1 iters: normal load+MMA [pipelineable loop]
2804-
//
2805-
// The split is implemented using two new GemmLoadTileType values:
2806-
// - GlobalReadOnly: emits only the GlobalRead stage
2807-
// (ThreadwiseReadIntoOp: global -> register buffer, no LDS write)
2808-
// - LDSWriteFromRegs: emits only the LDSWrite stage
2809-
// (ThreadwiseCopyOp + ThreadwiseWriteAllOp: regs -> LDS,
2810-
// no global read)
2811-
// Both phases share the same flat register buffer (vPrefetchRegs).
2791+
// V prefetch: Issue global reads for V tile 0 before softmax
2792+
// to overlap softmax computation with V's global memory latency.
2793+
// Uses GlobalReadOnly (global -> regs) and LDSWriteFromRegs
2794+
// (regs -> LDS) to split the load across the softmax boundary.
28122795
Value ldsByteBufferV;
28132796
Value vPrefetchRegs;
28142797
layout::GridCoordinates gridCoordsGemm1;
28152798
bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS;
28162799

2817-
// Decide whether to hoist Phase 2 (V regs -> LDS write) before the
2818-
// sum reduction. Hoisting saves one LDS barrier by piggybacking on
2819-
// the sum reduction's internal barrier, but it makes V's LDS live
2820-
// range overlap with the sum-reduction workspace, preventing
2821-
// ReuseLDS from aliasing them.
2822-
//
2823-
// ReuseLDS uses greedy graph coloring that packs non-interfering
2824-
// buffers (like K and V) into merged color groups. When V interferes
2825-
// with sum_ws (due to hoisting), V gets displaced within the merged
2826-
// group by sum_ws's size, growing the group by exactly sumWSBytes.
2827-
// So: hoisted_total ≈ non_hoisted_peak + sumWSBytes.
2828-
//
2829-
// The non-hoisted peak is the max concurrent LDS from GEMM0
2830-
// (Q+K buffers) or GEMM1 (V+gemm1_B buffers). We check if adding
2831-
// sumWSBytes would exceed the hardware LDS limit.
2800+
// Decide whether to hoist V regs->LDS write before the sum reduction.
2801+
// Hoisting saves one LDS barrier but extends V's LDS live range to
2802+
// overlap with the sum-reduction workspace, which may increase peak
2803+
// LDS usage. Only hoist if the resulting peak fits in hardware LDS.
28322804
bool hoistVPhase2 = false;
28332805
if (prefetchFirstVTile) {
28342806
int64_t maxLDS = archInfo.maxSharedMemPerWG;
@@ -2840,6 +2812,7 @@ struct GridwiseAttentionAccelRewritePattern
28402812
int64_t gemm1PeakBytes =
28412813
getPackedByteSize(gemm1KPerBlock * gemm1MPerBlock, elemTypeV) +
28422814
getPackedByteSize(gemm1LDSByteBufferBSize, elemTypeV);
2815+
28432816
// The base peak without hoisting is determined by the larger of
28442817
// GEMM0 and GEMM1 concurrent buffer sets.
28452818
int64_t nonHoistedPeak = std::max(gemm0PeakBytes, gemm1PeakBytes);
@@ -3050,21 +3023,9 @@ struct GridwiseAttentionAccelRewritePattern
30503023
gemm0MNExpThreadwiseView,
30513024
gemm0MNMaxThreadwiseView, maxRowBuffer);
30523025

3053-
// ================================================================
3054-
// V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS
3055-
// before the sum reduction so that the sum reduction's internal
3056-
// LDS barrier also synchronises the V tile writes. This
3057-
// eliminates the dedicated V-tile LDS barrier that was
3058-
// previously required after the sum reduction, saving one
3059-
// s_barrier per iteration.
3060-
//
3061-
// Safety: AnnotateLiveness + ReuseLDS will see that V's live
3062-
// range (write here -> read during GEMM1) overlaps with the sum
3063-
// workspace's live range, so they will NOT be aliased. The
3064-
// max-reduction workspace is already dead, so it CAN be
3065-
// aliased with V. The LDS increase is small and does not
3066-
// affect occupancy (VGPR-limited, not LDS-limited).
3067-
// ================================================================
3026+
// V prefetch phase 2 (hoisted): Write V data from regs to LDS
3027+
// before the sum reduction. The sum reduction's internal LDS
3028+
// barrier synchronises the V tile writes, saving one barrier.
30683029
if (prefetchFirstVTile && hoistVPhase2) {
30693030
// Allocate V LDS buffer early (before the sum reduction) so that
30703031
// Phase 2 can write the prefetched V data from registers into LDS.
@@ -3076,9 +3037,6 @@ struct GridwiseAttentionAccelRewritePattern
30763037
vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m",
30773038
blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
30783039
featuresAttr, matrixParamsV, matrixParamsKxQ);
3079-
// No LDSBarrierOp here — the barrier inside the sum
3080-
// BlockwiseBroadcastReduceOp (below) will synchronise both
3081-
// the V LDS writes and the softmax partial-sum LDS writes.
30823040
}
30833041

30843042
// Softmax sum reduction
@@ -3107,20 +3065,10 @@ struct GridwiseAttentionAccelRewritePattern
31073065
gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer,
31083066
expMaxDiffRowBuffer);
31093067

3110-
// ================================================================
3111-
// V PREFETCH Phase 2 (deferred path): Write V data from regs to
3112-
// LDS after the sum reduction. This avoids V's LDS live range
3113-
// overlapping with the sum-reduction workspace, allowing
3114-
// ReuseLDS to alias them and stay within the hardware LDS budget.
3115-
// Costs one extra s_barrier vs the hoisted path.
3116-
// Phase 1 (global reads -> regs, before softmax) still hides the
3117-
// global memory latency across the entire softmax computation.
3118-
// ================================================================
3068+
// V prefetch phase 2 (deferred path): Write V data from regs to
3069+
// LDS after the sum reduction to avoid overlapping with the
3070+
// sum-reduction workspace in LDS. Costs one extra barrier.
31193071
if (prefetchFirstVTile && !hoistVPhase2) {
3120-
// Allocate V LDS buffer HERE (late) instead of before softmax.
3121-
// This makes ldsByteBufferV's live range start after the
3122-
// reduction, preventing ReuseLDS from aliasing it with
3123-
// buffers that are still being read by slow wavefronts.
31243072
ldsByteBufferV = createLDSByteBuffer(
31253073
rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
31263074
loadAndStoreGemmInputTile(
@@ -3181,11 +3129,8 @@ struct GridwiseAttentionAccelRewritePattern
31813129
}
31823130
}
31833131

3184-
// ================================================================
3185-
// V load + GEMM1 loop: Two paths depending on V prefetch.
3186-
// ================================================================
3187-
// For non-prefetch path: allocate V LDS buffer and grid coords
3188-
// (prefetch path already did this before softmax).
3132+
// V load + GEMM1 loop. For the non-prefetch path, allocate the
3133+
// V LDS buffer and grid coords here (prefetch already did this).
31893134
if (!prefetchFirstVTile) {
31903135
ldsByteBufferV = createLDSByteBuffer(
31913136
rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
@@ -3194,12 +3139,7 @@ struct GridwiseAttentionAccelRewritePattern
31943139
numChiplets, splitKVConst);
31953140
}
31963141

3197-
// ----------------------------------------------------------------
3198-
// Helper lambda: Emit GEMM1 MMA + PostProcess for a single V tile.
3199-
// Parameterized by V block index (g1MBlockIdx) to support both
3200-
// the peeled first iteration and the remaining loop iterations.
3201-
// This avoids duplicating ~100 lines of MMA + PostProcess code.
3202-
// ----------------------------------------------------------------
3142+
// Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile.
32033143
auto emitGemm1Compute =
32043144
[&](Value g1MBlockIdx, GemmLoadTileType vLoadType,
32053145
Value vRegBuf) -> LogicalResult {
@@ -3333,30 +3273,11 @@ struct GridwiseAttentionAccelRewritePattern
33333273
}; // end emitGemm1Compute lambda
33343274

33353275
if (prefetchFirstVTile) {
3336-
// ============================================================
3337-
// PREFETCH PATH: First V tile already loaded into LDS.
3338-
// ============================================================
3339-
// V data for tile 0 was prefetched before softmax (global read)
3340-
// and written to LDS before the sum reduction (LDS write synced
3341-
// by sum reduction's internal barrier).
3342-
// The first GEMM1 iteration is peeled out of the loop so the
3343-
// remaining iterations form a clean, pipelineable loop.
3344-
3345-
// --- Peeled first iteration (g1m = 0) ---
3276+
// Prefetch path: V tile 0 is already in LDS. Peel the first
3277+
// GEMM1 iteration and loop over the remaining tiles.
33463278
gridCoordsGemm1.m_block = zero;
3347-
// Use Default load type for the peeled iteration because the V
3348-
// data was written to LDS by the LDSWriteFromRegs phase. There is
3349-
// no BlockwiseLoadTileOp here to create an LDSRead stage, so the
3350-
// GEMM must read V directly from LDS.
3351-
//
3352-
// When double-buffering is active, preAccelRegBufferV is rank-2
3353-
// (e.g. memref<3x2xvector<4xf16>>) because it was allocated with
3354-
// repeats=mRepeats. However, the Default load path in
3355-
// BlockwiseGemmAccelOp reads from LDS into the buffer WITHOUT
3356-
// slicing by the m-repeat loop variable. The downstream
3357-
// generateThreadwiseViewBufferA then creates a rank-1 view,
3358-
// leading to a memref.load rank mismatch. Fix: create a separate
3359-
// rank-1 register buffer for the peeled iteration.
3279+
// When double-buffering, preAccelRegBufferV is rank-2; the
3280+
// Default load path expects rank-1, so allocate a separate buf.
33603281
Value peeledVRegBuf = preAccelRegBufferV;
33613282
if (doubleBuffering) {
33623283
auto [peeledVForLoad, peeledVBuf] =
@@ -3366,22 +3287,14 @@ struct GridwiseAttentionAccelRewritePattern
33663287
/*repeats=*/1, directToLDS);
33673288
peeledVRegBuf = peeledVBuf;
33683289
}
3369-
// Barrier: ensure all threads have finished writing the softmax
3370-
// exp values to LDS (storeGemmInputTile above) before GEMM1
3371-
// reads from them. Only needed when the softmax exp actually
3372-
// goes through LDS (!doBypassLDSSecondGemm). When LDS is
3373-
// bypassed, softmax exp stays in registers and V is already
3374-
// synced by either the sum reduction's internal barrier
3375-
// (hoisted path) or the deferred V Phase 2 barrier.
33763290
if (!doBypassLDSSecondGemm)
33773291
LDSBarrierOp::create(rewriter, loc);
33783292

33793293
if (failed(emitGemm1Compute(zero, GemmLoadTileType::Default,
33803294
peeledVRegBuf)))
33813295
return failure();
33823296

3383-
// --- Remaining iterations (g1m = 1..gemm1MBlocks-1) ---
3384-
// These form a standard pipelineable loop with V loads.
3297+
// Remaining iterations (g1m = 1..gemm1MBlocks-1).
33853298
if (gemm1MBlocks > 1) {
33863299
LDSBarrierOp::create(rewriter, loc);
33873300

@@ -3393,13 +3306,8 @@ struct GridwiseAttentionAccelRewritePattern
33933306
rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
33943307
scf::ForOp g1MLoopOp = scf::ForOp::create(
33953308
rewriter, loc, startG1M, endG1MLoop, oneVal);
3396-
// Mark loop for pipelining — but only when the remaining loop
3397-
// has more than 1 iteration. Pipelining a 1-iteration loop
3398-
// (gemm1MBlocks == 2 → loop from 1 to 2) provides no overlap
3399-
// benefit and the RockPipelinePass currently drops the
3400-
// inter-stage LDS barriers from the epilogue, causing a data
3401-
// race between the V LDS write (prologue) and the GEMM1 V LDS
3402-
// read (epilogue).
3309+
// Only pipeline when >1 iteration remains; pipelining a
3310+
// single iteration causes barrier mismatches.
34033311
if (gemm1MBlocks > 2) {
34043312
bool g1DoubleBuffering =
34053313
loadType == GemmLoadTileType::DoubleBuffer ||
@@ -3437,9 +3345,7 @@ struct GridwiseAttentionAccelRewritePattern
34373345
}
34383346
}
34393347
} else {
3440-
// ============================================================
3441-
// ORIGINAL PATH: No V prefetch (softmax disabled).
3442-
// ============================================================
3348+
// Non-prefetch path (softmax disabled).
34433349
Value endG1MLoop =
34443350
rewriter.createOrFold<ConstantIndexOp>(loc, gemm1MBlocks);
34453351
scf::ForOp g1MLoopOp =

mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ struct PushBarrierDownRewritePattern
166166
if (!nextOp->getNextNode())
167167
return failure();
168168

169-
// Don't push past another barrier; RemoveBackToBack handles that.
169+
// Don't push past another barrier, RemoveBackToBack handles that.
170170
// Without this check, two adjacent barriers would swap endlessly.
171171
if (isa<rock::LDSBarrierOp>(nextOp))
172172
return failure();
@@ -178,7 +178,7 @@ struct PushBarrierDownRewritePattern
178178
bool moveDown = true;
179179
// Check if the operation accesses LDS.
180180
// We can move past LDS store-only operations because independent
181-
// writes don't need ordering between them the next barrier will
181+
// writes don't need ordering between them, the next barrier will
182182
// ensure all writes complete before any subsequent reads.
183183
// We must stop at LDS reads.
184184
// We recognize store ops both before SugarToLoops (InBoundsStoreOp)

0 commit comments

Comments
 (0)