@@ -2788,47 +2788,19 @@ struct GridwiseAttentionAccelRewritePattern
27882788 accelEmitterPtrGemm0->computeOutputConversion (
27892789 rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll);
27902790
2791- // ================================================================
2792- // V PREFETCH: Issue global reads for V tile 0 before softmax.
2793- // ================================================================
2794- // By issuing V global reads here (before softmax computation),
2795- // we overlap the ~120+ instructions of softmax work with the
2796- // global memory access latency for V, matching CK's approach.
2797- //
2798- // The flow is:
2799- // 1. Issue V global reads -> register buffer [HERE, before softmax]
2800- // 2. Softmax computation [hides load latency]
2801- // 3. Write V from registers -> LDS [after softmax]
2802- // 4. GEMM1 first iteration uses V from LDS [peeled iteration]
2803- // 5. Remaining GEMM1 iters: normal load+MMA [pipelineable loop]
2804- //
2805- // The split is implemented using two new GemmLoadTileType values:
2806- // - GlobalReadOnly: emits only the GlobalRead stage
2807- // (ThreadwiseReadIntoOp: global -> register buffer, no LDS write)
2808- // - LDSWriteFromRegs: emits only the LDSWrite stage
2809- // (ThreadwiseCopyOp + ThreadwiseWriteAllOp: regs -> LDS,
2810- // no global read)
2811- // Both phases share the same flat register buffer (vPrefetchRegs).
2791+	      // V prefetch: Issue global reads for V tile 0 before softmax
2792+ // to overlap softmax computation with V's global memory latency.
2793+ // Uses GlobalReadOnly (global -> regs) and LDSWriteFromRegs
2794+ // (regs -> LDS) to split the load across the softmax boundary.
28122795 Value ldsByteBufferV;
28132796 Value vPrefetchRegs;
28142797 layout::GridCoordinates gridCoordsGemm1;
28152798 bool prefetchFirstVTile = op.getEnableSoftmax () && !directToLDS;
28162799
2817- // Decide whether to hoist Phase 2 (V regs -> LDS write) before the
2818- // sum reduction. Hoisting saves one LDS barrier by piggybacking on
2819- // the sum reduction's internal barrier, but it makes V's LDS live
2820- // range overlap with the sum-reduction workspace, preventing
2821- // ReuseLDS from aliasing them.
2822- //
2823- // ReuseLDS uses greedy graph coloring that packs non-interfering
2824- // buffers (like K and V) into merged color groups. When V interferes
2825- // with sum_ws (due to hoisting), V gets displaced within the merged
2826- // group by sum_ws's size, growing the group by exactly sumWSBytes.
2827- // So: hoisted_total ≈ non_hoisted_peak + sumWSBytes.
2828- //
2829- // The non-hoisted peak is the max concurrent LDS from GEMM0
2830- // (Q+K buffers) or GEMM1 (V+gemm1_B buffers). We check if adding
2831- // sumWSBytes would exceed the hardware LDS limit.
2800+ // Decide whether to hoist V regs->LDS write before the sum reduction.
2801+ // Hoisting saves one LDS barrier but extends V's LDS live range to
2802+ // overlap with the sum-reduction workspace, which may increase peak
2803+ // LDS usage. Only hoist if the resulting peak fits in hardware LDS.
28322804 bool hoistVPhase2 = false ;
28332805 if (prefetchFirstVTile) {
28342806 int64_t maxLDS = archInfo.maxSharedMemPerWG ;
@@ -2840,6 +2812,7 @@ struct GridwiseAttentionAccelRewritePattern
28402812 int64_t gemm1PeakBytes =
28412813 getPackedByteSize (gemm1KPerBlock * gemm1MPerBlock, elemTypeV) +
28422814 getPackedByteSize (gemm1LDSByteBufferBSize, elemTypeV);
2815+
28432816 // The base peak without hoisting is determined by the larger of
28442817 // GEMM0 and GEMM1 concurrent buffer sets.
28452818 int64_t nonHoistedPeak = std::max (gemm0PeakBytes, gemm1PeakBytes);
@@ -3050,21 +3023,9 @@ struct GridwiseAttentionAccelRewritePattern
30503023 gemm0MNExpThreadwiseView,
30513024 gemm0MNMaxThreadwiseView, maxRowBuffer);
30523025
3053- // ================================================================
3054- // V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS
3055- // before the sum reduction so that the sum reduction's internal
3056- // LDS barrier also synchronises the V tile writes. This
3057- // eliminates the dedicated V-tile LDS barrier that was
3058- // previously required after the sum reduction, saving one
3059- // s_barrier per iteration.
3060- //
3061- // Safety: AnnotateLiveness + ReuseLDS will see that V's live
3062- // range (write here -> read during GEMM1) overlaps with the sum
3063- // workspace's live range, so they will NOT be aliased. The
3064- // max-reduction workspace is already dead, so it CAN be
3065- // aliased with V. The LDS increase is small and does not
3066- // affect occupancy (VGPR-limited, not LDS-limited).
3067- // ================================================================
3026+ // V prefetch phase 2 (hoisted): Write V data from regs to LDS
3027+ // before the sum reduction. The sum reduction's internal LDS
3028+ // barrier synchronises the V tile writes, saving one barrier.
30683029 if (prefetchFirstVTile && hoistVPhase2) {
30693030 // Allocate V LDS buffer early (before the sum reduction) so that
30703031 // Phase 2 can write the prefetched V data from registers into LDS.
@@ -3076,9 +3037,6 @@ struct GridwiseAttentionAccelRewritePattern
30763037 vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, " m" ,
30773038 blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
30783039 featuresAttr, matrixParamsV, matrixParamsKxQ);
3079- // No LDSBarrierOp here — the barrier inside the sum
3080- // BlockwiseBroadcastReduceOp (below) will synchronise both
3081- // the V LDS writes and the softmax partial-sum LDS writes.
30823040 }
30833041
30843042 // Softmax sum reduction
@@ -3107,20 +3065,10 @@ struct GridwiseAttentionAccelRewritePattern
31073065 gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer,
31083066 expMaxDiffRowBuffer);
31093067
3110- // ================================================================
3111- // V PREFETCH Phase 2 (deferred path): Write V data from regs to
3112- // LDS after the sum reduction. This avoids V's LDS live range
3113- // overlapping with the sum-reduction workspace, allowing
3114- // ReuseLDS to alias them and stay within the hardware LDS budget.
3115- // Costs one extra s_barrier vs the hoisted path.
3116- // Phase 1 (global reads -> regs, before softmax) still hides the
3117- // global memory latency across the entire softmax computation.
3118- // ================================================================
3068+ // V prefetch phase 2 (deferred path): Write V data from regs to
3069+ // LDS after the sum reduction to avoid overlapping with the
3070+ // sum-reduction workspace in LDS. Costs one extra barrier.
31193071 if (prefetchFirstVTile && !hoistVPhase2) {
3120- // Allocate V LDS buffer HERE (late) instead of before softmax.
3121- // This makes ldsByteBufferV's live range start after the
3122- // reduction, preventing ReuseLDS from aliasing it with
3123- // buffers that are still being read by slow wavefronts.
31243072 ldsByteBufferV = createLDSByteBuffer (
31253073 rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
31263074 loadAndStoreGemmInputTile (
@@ -3181,11 +3129,8 @@ struct GridwiseAttentionAccelRewritePattern
31813129 }
31823130 }
31833131
3184- // ================================================================
3185- // V load + GEMM1 loop: Two paths depending on V prefetch.
3186- // ================================================================
3187- // For non-prefetch path: allocate V LDS buffer and grid coords
3188- // (prefetch path already did this before softmax).
3132+ // V load + GEMM1 loop. For the non-prefetch path, allocate the
3133+ // V LDS buffer and grid coords here (prefetch already did this).
31893134 if (!prefetchFirstVTile) {
31903135 ldsByteBufferV = createLDSByteBuffer (
31913136 rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
@@ -3194,12 +3139,7 @@ struct GridwiseAttentionAccelRewritePattern
31943139 numChiplets, splitKVConst);
31953140 }
31963141
3197- // ----------------------------------------------------------------
3198- // Helper lambda: Emit GEMM1 MMA + PostProcess for a single V tile.
3199- // Parameterized by V block index (g1MBlockIdx) to support both
3200- // the peeled first iteration and the remaining loop iterations.
3201- // This avoids duplicating ~100 lines of MMA + PostProcess code.
3202- // ----------------------------------------------------------------
3142+ // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile.
32033143 auto emitGemm1Compute =
32043144 [&](Value g1MBlockIdx, GemmLoadTileType vLoadType,
32053145 Value vRegBuf) -> LogicalResult {
@@ -3333,30 +3273,11 @@ struct GridwiseAttentionAccelRewritePattern
33333273 }; // end emitGemm1Compute lambda
33343274
33353275 if (prefetchFirstVTile) {
3336- // ============================================================
3337- // PREFETCH PATH: First V tile already loaded into LDS.
3338- // ============================================================
3339- // V data for tile 0 was prefetched before softmax (global read)
3340- // and written to LDS before the sum reduction (LDS write synced
3341- // by sum reduction's internal barrier).
3342- // The first GEMM1 iteration is peeled out of the loop so the
3343- // remaining iterations form a clean, pipelineable loop.
3344-
3345- // --- Peeled first iteration (g1m = 0) ---
3276+ // Prefetch path: V tile 0 is already in LDS. Peel the first
3277+ // GEMM1 iteration and loop over the remaining tiles.
33463278 gridCoordsGemm1.m_block = zero;
3347- // Use Default load type for the peeled iteration because the V
3348- // data was written to LDS by the LDSWriteFromRegs phase. There is
3349- // no BlockwiseLoadTileOp here to create an LDSRead stage, so the
3350- // GEMM must read V directly from LDS.
3351- //
3352- // When double-buffering is active, preAccelRegBufferV is rank-2
3353- // (e.g. memref<3x2xvector<4xf16>>) because it was allocated with
3354- // repeats=mRepeats. However, the Default load path in
3355- // BlockwiseGemmAccelOp reads from LDS into the buffer WITHOUT
3356- // slicing by the m-repeat loop variable. The downstream
3357- // generateThreadwiseViewBufferA then creates a rank-1 view,
3358- // leading to a memref.load rank mismatch. Fix: create a separate
3359- // rank-1 register buffer for the peeled iteration.
3279+ // When double-buffering, preAccelRegBufferV is rank-2; the
3280+ // Default load path expects rank-1, so allocate a separate buf.
33603281 Value peeledVRegBuf = preAccelRegBufferV;
33613282 if (doubleBuffering) {
33623283 auto [peeledVForLoad, peeledVBuf] =
@@ -3366,22 +3287,14 @@ struct GridwiseAttentionAccelRewritePattern
33663287 /* repeats=*/ 1 , directToLDS);
33673288 peeledVRegBuf = peeledVBuf;
33683289 }
3369- // Barrier: ensure all threads have finished writing the softmax
3370- // exp values to LDS (storeGemmInputTile above) before GEMM1
3371- // reads from them. Only needed when the softmax exp actually
3372- // goes through LDS (!doBypassLDSSecondGemm). When LDS is
3373- // bypassed, softmax exp stays in registers and V is already
3374- // synced by either the sum reduction's internal barrier
3375- // (hoisted path) or the deferred V Phase 2 barrier.
33763290 if (!doBypassLDSSecondGemm)
33773291 LDSBarrierOp::create (rewriter, loc);
33783292
33793293 if (failed (emitGemm1Compute (zero, GemmLoadTileType::Default,
33803294 peeledVRegBuf)))
33813295 return failure ();
33823296
3383- // --- Remaining iterations (g1m = 1..gemm1MBlocks-1) ---
3384- // These form a standard pipelineable loop with V loads.
3297+ // Remaining iterations (g1m = 1..gemm1MBlocks-1).
33853298 if (gemm1MBlocks > 1 ) {
33863299 LDSBarrierOp::create (rewriter, loc);
33873300
@@ -3393,13 +3306,8 @@ struct GridwiseAttentionAccelRewritePattern
33933306 rewriter.createOrFold <arith::ConstantIndexOp>(loc, 1 );
33943307 scf::ForOp g1MLoopOp = scf::ForOp::create (
33953308 rewriter, loc, startG1M, endG1MLoop, oneVal);
3396- // Mark loop for pipelining — but only when the remaining loop
3397- // has more than 1 iteration. Pipelining a 1-iteration loop
3398- // (gemm1MBlocks == 2 → loop from 1 to 2) provides no overlap
3399- // benefit and the RockPipelinePass currently drops the
3400- // inter-stage LDS barriers from the epilogue, causing a data
3401- // race between the V LDS write (prologue) and the GEMM1 V LDS
3402- // read (epilogue).
3309+ // Only pipeline when >1 iteration remains; pipelining a
3310+ // single iteration causes barrier mismatches.
34033311 if (gemm1MBlocks > 2 ) {
34043312 bool g1DoubleBuffering =
34053313 loadType == GemmLoadTileType::DoubleBuffer ||
@@ -3437,9 +3345,7 @@ struct GridwiseAttentionAccelRewritePattern
34373345 }
34383346 }
34393347 } else {
3440- // ============================================================
3441- // ORIGINAL PATH: No V prefetch (softmax disabled).
3442- // ============================================================
3348+ // Non-prefetch path (softmax disabled).
34433349 Value endG1MLoop =
34443350 rewriter.createOrFold <ConstantIndexOp>(loc, gemm1MBlocks);
34453351 scf::ForOp g1MLoopOp =
0 commit comments