From 0ceb2027dc40c57f4a2a177e9d364ec6eac8a420 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Sat, 14 Feb 2026 18:04:45 +0000 Subject: [PATCH 01/18] Test scheduling changes --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 12 +- mlir/include/mlir/Dialect/Rock/Passes.td | 2 +- .../Transforms/BlockwiseGemmToThreadwise.cpp | 127 ++++---- .../BlockwiseLoadTileToThreadwise.cpp | 141 +++++---- .../Transforms/GridwiseGemmToBlockwise.cpp | 281 +++++++++++++++--- .../Transforms/ThreadwiseGemmLowering.cpp | 35 ++- 6 files changed, 440 insertions(+), 158 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 796e6d6dae9c..1f9e0e37a7f4 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -464,13 +464,23 @@ def Rock_GemmLoadTileDirectToLDSDefault : I32EnumAttrCase<"DirectToLDSDefault", 3>; def Rock_GemmLoadTileDirectToLDSDoubleBuffer : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>; +// Split-phase load types for V prefetch in attention kernels. +// GlobalReadOnly: Only emit the global read stage (global -> register buffer). +// LDSWriteFromRegs: Only emit the LDS write stage (register buffer -> LDS). +// Both phases share a register buffer passed via destRegisters. +def Rock_GemmLoadTileGlobalReadOnly + : I32EnumAttrCase<"GlobalReadOnly", 5>; +def Rock_GemmLoadTileLDSWriteFromRegs + : I32EnumAttrCase<"LDSWriteFromRegs", 6>; def Rock_GemmLoadTileType : Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types", [Rock_GemmLoadTileBypassLDS, Rock_GemmLoadTileDefault, Rock_GemmLoadTileDoubleBuffer, Rock_GemmLoadTileDirectToLDSDefault, - Rock_GemmLoadTileDirectToLDSDoubleBuffer]> { + Rock_GemmLoadTileDirectToLDSDoubleBuffer, + Rock_GemmLoadTileGlobalReadOnly, + Rock_GemmLoadTileLDSWriteFromRegs]> { let cppNamespace = "::mlir::rock"; let genSpecializedAttr = 0; } diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 1041ab12abdf..18408e58b6ba 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -107,7 +107,7 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> { def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> { let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy"; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; } def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> { diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index a9d701e7c458..3a49c09fcb9b 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -1410,64 +1410,83 @@ struct BlockwiseReduceRewritePattern } } - // This RAII scope would do the following : - // LDS[rtid] = reduce(LDS[rtid], LDS[rtid + offset]) - // where offset is a power of 2. - // Initial it starts with power = ceil(|rtid|, power of 2) / 2 - // Then keep on reducing the power. 
+ // Branchless reduction: each thread reads ALL rTidDim partial + // values from LDS and reduces locally in registers. This avoids + // creating conditional branches (scf.if) that split softmax into + // multiple basic blocks. Without branches, the LLVM backend + // scheduler can keep V global loads (issued before softmax) in + // the same basic block, enabling sched_barrier to prevent them + // from being sunk past softmax computation. + // + // Trade-off: every thread does rTidCount LDS reads (instead of + // log2(rTidCount) conditional reads in the tree reduction). For + // typical attention configs where rTidCount is small (e.g., 4), + // this is negligible overhead. { - int64_t ceilPowerOf2 = - llvm::PowerOf2Ceil(threadViewShape[rTidDim]) / 2; - int64_t maxActiveReductionThreads = threadViewShape[rTidDim]; - for (int64_t offset = ceilPowerOf2; offset >= 1; - offset = offset >> 1) { - Value offsetVal = - arith::ConstantIndexOp::create(rewriter, loc, offset); - Value rtidPlusOffsetVal = - arith::AddIOp::create(rewriter, loc, rtid, offsetVal); - Value maxActiveReductionThreadsVal = arith::ConstantIndexOp::create( - rewriter, loc, maxActiveReductionThreads); - maxActiveReductionThreads = - llvm::PowerOf2Ceil(maxActiveReductionThreads) >> 1; - Value isValid = arith::CmpIOp::create( - rewriter, loc, arith::CmpIPredicate::slt, rtidPlusOffsetVal, - maxActiveReductionThreadsVal); - scf::IfOp ifb = scf::IfOp::create(rewriter, loc, isValid, - /*withElseRegion=*/false); + int64_t rTidCount = threadViewShape[rTidDim]; + + // Accumulator for the full reduction. + auto accRegType = MemRefType::get( + {1}, elemType, AffineMap{}, privateMemoryAddressSpace); + Value accReg = GpuAllocOp::create(rewriter, loc, accRegType); + FillOp::create(rewriter, loc, accReg, initVal); + + // Read all rTidCount partial values from LDS and reduce. + // Every thread with the same nrtid computes the identical + // fully-reduced value. 
+ for (int64_t i = 0; i < rTidCount; i++) { + Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i); + SmallVector readInits{nrtid, iVal, zeroConstantOp}; + SmallVector bounds{1, 1, 1}; + SmallVector strides{1, 1, 1}; + + TransformingForOp readLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{readInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(bounds), ArrayRef(strides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { - OpBuilder thenb = ifb.getThenBodyBuilder(); - SmallVector firstInits{nrtid, rtid, zeroConstantOp}; - SmallVector secondInits{nrtid, rtidPlusOffsetVal, - zeroConstantOp}; - SmallVector bounds{1, 1, 1}; - SmallVector strides{1, 1, 1}; - - TransformingForOp reductionLoop = TransformingForOp::create( - thenb, loc, ArrayRef{firstInits, secondInits}, - ArrayRef{threadToLDSViewTrs, threadToLDSViewTrs}, - ArrayRef(bounds), ArrayRef(strides), - /*forceUnroll=*/true, /*useIndexDiffs=*/true); - { - PatternRewriter::InsertionGuard guard(thenb); - thenb.setInsertionPointToStart(reductionLoop.getBody()); - Block::BlockArgListType firstLDSLoadCoords = - reductionLoop.getLowerCoords(/*domain=*/0); - Value firstLoadVal = InBoundsLoadOp::create( - thenb, loc, elemType, workspaceLDSBuffer, - firstLDSLoadCoords); - Block::BlockArgListType secondLDSLoadCoords = - reductionLoop.getLowerCoords(/*domain=*/1); - Value secondLoadVal = InBoundsLoadOp::create( - thenb, loc, elemType, workspaceLDSBuffer, - secondLDSLoadCoords); - Value reduced = - createReducingOp(op, firstLoadVal, secondLoadVal, thenb); - InBoundsStoreOp::create(thenb, loc, reduced, workspaceLDSBuffer, - firstLDSLoadCoords); - } + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(readLoop.getBody()); + Block::BlockArgListType ldsCoords = + readLoop.getLowerCoords(/*domain=*/0); + Value ldVal = InBoundsLoadOp::create( + rewriter, loc, elemType, workspaceLDSBuffer, ldsCoords); + Value accVal = InBoundsLoadOp::create( + rewriter, loc, elemType, accReg, zeroConstantOp); + Value reduced = createReducingOp(op, ldVal, accVal, rewriter); + InBoundsStoreOp::create(rewriter, loc, reduced, accReg, + zeroConstantOp); + } + } + + // Write the fully reduced value back to LDS at [nrtid, 0]. + // All threads with the same nrtid compute the same value, + // so concurrent writes to the same location are safe. 
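+            // The racing stores below are benign: every writer stores the
+            // bit-identical value to [nrtid, 0], and the LDSBarrierOp that
+            // follows orders these writes before any cross-thread reads.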
+ { + Value reducedVal = InBoundsLoadOp::create( + rewriter, loc, elemType, accReg, zeroConstantOp); + SmallVector writeInits{nrtid, zeroConstantOp, + zeroConstantOp}; + SmallVector writeBounds{1, 1, 1}; + SmallVector writeStrides{1, 1, 1}; + + TransformingForOp writeLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{writeInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(writeBounds), ArrayRef(writeStrides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(writeLoop.getBody()); + Block::BlockArgListType ldsCoords = + writeLoop.getLowerCoords(/*domain=*/0); + InBoundsStoreOp::create(rewriter, loc, reducedVal, + workspaceLDSBuffer, ldsCoords); } - LDSBarrierOp::create(rewriter, loc); } + + LDSBarrierOp::create(rewriter, loc); ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( loc, rewriter, inputViewArrayAttr, axis, /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 0c309db0a2c2..c2e70e99bb90 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -244,8 +244,24 @@ class LoweringBlockwiseLoadTileOp final else b.setInsertionPoint(op); + bool globalReadOnly = + loadType == GemmLoadTileType::GlobalReadOnly; + bool ldsWriteFromRegs = + loadType == GemmLoadTileType::LDSWriteFromRegs; + Value loadBuffer, storeBuffer; - if (loadType == GemmLoadTileType::BypassLDS) { + if (globalReadOnly || ldsWriteFromRegs) { + // Split-phase load: use the externally-allocated destRegisters buffer + // as the shared loadBuffer between the GlobalReadOnly and + // LDSWriteFromRegs phases. + assert(destRegisters && + "destRegisters must be set for split-phase load types"); + loadBuffer = destRegisters; + if (ldsWriteFromRegs) { + storeBuffer = gpuAlloc(b, loc, copyPerThread, elementType, + AddressSpace::Private); + } + } else if (loadType == GemmLoadTileType::BypassLDS) { auto privateMemoryAddressSpace = b.getAttr( gpu::GPUDialect::getPrivateAddressSpace()); auto accelParams = accelEmitterPtr->getParams(); @@ -270,57 +286,74 @@ class LoweringBlockwiseLoadTileOp final if (isa(parentOp)) b.setInsertionPoint(op); - auto [stageGlobalRead, stageGlobalReadNew] = - createOrGetStage(b, loc, "GlobalRead", parentOp); - { - PatternRewriter::InsertionGuard guard(b); - b.setInsertionPointToStart(&stageGlobalRead.getRegion().back()); - - FailureOr maybeBufferViews; - if (loadType == GemmLoadTileType::BypassLDS) { - // Check if the other operand uses LDS transpose load - bool otherOperandUsesLdsTranspose = - isA ? 
matrixParamsB.getLdsTransposeEnabled() - : matrixParamsA.getLdsTransposeEnabled(); - maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( - b, loc, kIters, bidGridLengths, blockSize, vecDimInfo.inDPerThread, - dName, isKContiguousDim, false, - /*doSplitKAcrossThreadsFirst=*/false, otherOperandUsesLdsTranspose); - } else { - maybeBufferViews = getLoadRegsAsTileViews( - b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize, - kPerBlock, dPerBlock, vecDimInfo.inKPerThread, - vecDimInfo.inDPerThread, isKContiguousDim, directToLDS); - } - if (failed(maybeBufferViews)) - return failure(); - - Value wrappedSource = transform(b, source, maybeBufferViews->gridSubTile); - - ThreadwiseReadIntoOp::create(b, loc, vectorOfBoolShapedLike(loadBuffer), - wrappedSource, loadBuffer, - /*dynamicValidities=*/ValueRange{}, - /*extraViews=*/b.getArrayAttr({}), - /*extraIndices=*/indices, forceUnroll, true, - /*ldsTransposeConfig=*/nullptr); - - if (rock::isGlobalPrefetchSupported(arch)) { - // add one to k_loop to prefetch next iteration - SmallVector indicesNext(indices.begin(), indices.end()); - Value one = b.createOrFold(loc, 1); - indicesNext[0] = - arith::AddIOp::create(b, loc, indicesNext[0], one).getResult(); - - // it's acceptable if the indices are out of bounds because we use - // GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch - // documentation in AMDGPUUsage.rst - rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource, - /*extraViews=*/b.getArrayAttr({}), - /*extraIndices=*/indicesNext, - forceUnroll, true); + // ---- GlobalRead stage ---- + // Emit for all types EXCEPT LDSWriteFromRegs (which only does the write). + if (!ldsWriteFromRegs) { + // Use distinct stage name for split-phase V prefetch to avoid + // conflicting with K/Q GlobalRead stages in the same parent scope. + StringRef globalReadStageName = + globalReadOnly ? "VGlobalRead" : "GlobalRead"; + auto [stageGlobalRead, stageGlobalReadNew] = + createOrGetStage(b, loc, globalReadStageName, parentOp); + { + PatternRewriter::InsertionGuard guard(b); + b.setInsertionPointToStart(&stageGlobalRead.getRegion().back()); + + FailureOr maybeBufferViews; + if (loadType == GemmLoadTileType::BypassLDS) { + // Check if the other operand uses LDS transpose load + bool otherOperandUsesLdsTranspose = + isA ? 
matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); + maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( + b, loc, kIters, bidGridLengths, blockSize, + vecDimInfo.inDPerThread, dName, isKContiguousDim, false, + /*doSplitKAcrossThreadsFirst=*/false, + otherOperandUsesLdsTranspose); + } else { + maybeBufferViews = getLoadRegsAsTileViews( + b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize, + kPerBlock, dPerBlock, vecDimInfo.inKPerThread, + vecDimInfo.inDPerThread, isKContiguousDim, directToLDS); + } + if (failed(maybeBufferViews)) + return failure(); + + Value wrappedSource = + transform(b, source, maybeBufferViews->gridSubTile); + + ThreadwiseReadIntoOp::create( + b, loc, vectorOfBoolShapedLike(loadBuffer), wrappedSource, + loadBuffer, + /*dynamicValidities=*/ValueRange{}, + /*extraViews=*/b.getArrayAttr({}), + /*extraIndices=*/indices, forceUnroll, true, + /*ldsTransposeConfig=*/nullptr); + + if (!globalReadOnly && rock::isGlobalPrefetchSupported(arch)) { + // add one to k_loop to prefetch next iteration + SmallVector indicesNext(indices.begin(), indices.end()); + Value one = b.createOrFold(loc, 1); + indicesNext[0] = + arith::AddIOp::create(b, loc, indicesNext[0], one).getResult(); + + // it's acceptable if the indices are out of bounds because we use + // GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch + // documentation in AMDGPUUsage.rst + rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource, + /*extraViews=*/b.getArrayAttr({}), + /*extraIndices=*/indicesNext, + forceUnroll, true); + } + if (stageGlobalReadNew) + rock::YieldOp::create(b, loc); } - if (stageGlobalReadNew) - rock::YieldOp::create(b, loc); + } + + // For GlobalReadOnly, we're done - skip all write stages. + if (globalReadOnly) { + b.eraseOp(op); + return success(); } if (loadType == GemmLoadTileType::BypassLDS) { @@ -378,8 +411,12 @@ class LoweringBlockwiseLoadTileOp final } } else { if (!directToLDS) { + // Use distinct stage name for split-phase V write to avoid + // conflicting with K/Q LDSWrite stages in nested loops. + StringRef ldsWriteStageName = + ldsWriteFromRegs ? "VLDSWrite" : "LDSWrite"; auto [stageLDSWrite, stageLDSWriteNew] = - createOrGetStage(b, loc, "LDSWrite", parentOp); + createOrGetStage(b, loc, ldsWriteStageName, parentOp); { PatternRewriter::InsertionGuard guard(b); b.setInsertionPointToStart(&stageLDSWrite.getRegion().back()); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index cf5a0ce439de..aac0e36cbf94 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -34,6 +34,7 @@ #include "mlir/Dialect/Rock/utility/transformMapUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -2787,6 +2788,65 @@ struct GridwiseAttentionAccelRewritePattern accelEmitterPtrGemm0->computeOutputConversion( rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll); + // ================================================================ + // V PREFETCH: Issue global reads for V tile 0 before softmax. 
+ // ================================================================ + // By issuing V global reads here (before softmax computation), + // we overlap the ~120+ instructions of softmax work with the + // global memory access latency for V, matching CK's approach. + // + // The flow is: + // 1. Issue V global reads -> register buffer [HERE, before softmax] + // 2. Softmax computation [hides load latency] + // 3. Write V from registers -> LDS [after softmax] + // 4. GEMM1 first iteration uses V from LDS [peeled iteration] + // 5. Remaining GEMM1 iters: normal load+MMA [pipelineable loop] + // + // The split is implemented using two new GemmLoadTileType values: + // - GlobalReadOnly: emits only the GlobalRead stage + // (ThreadwiseReadIntoOp: global -> register buffer, no LDS write) + // - LDSWriteFromRegs: emits only the LDSWrite stage + // (ThreadwiseCopyOp + ThreadwiseWriteAllOp: regs -> LDS, + // no global read) + // Both phases share the same flat register buffer (vPrefetchRegs). + Value ldsByteBufferV; + Value vPrefetchRegs; + layout::GridCoordinates gridCoordsGemm1; + bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS; + if (prefetchFirstVTile) { + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + gridCoordsGemm1 = layout::makeGxNGridLayout( + rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, + numChiplets, splitKVConst); + gridCoordsGemm1.m_block = zero; // First V tile (block index 0) + + // Allocate a flat register buffer shared between the GlobalReadOnly + // and LDSWriteFromRegs phases. Size must match what the lowering + // computes: copyPerThread = (kPerBlock * dPerBlock) / blockSize. + int64_t vCopyPerThread = + (gemm1KPerBlock * gemm1MPerBlock) / blockSize; + vPrefetchRegs = gpuAlloc(rewriter, loc, vCopyPerThread, elemTypeV, + gpu::AddressSpace::Private); + + // Phase 1: Issue global reads for V tile 0 into register buffer. + // Only the GlobalRead stage is emitted; LDS write is deferred. + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); + + // Insert a scheduling barrier to prevent the LLVM backend scheduler + // from sinking the V global loads past the softmax computation. + // Without this barrier, the scheduler moves the V loads to after + // softmax, defeating the latency hiding optimization. + // mask = none (0x0): full barrier, no instructions may cross. + amdgpu::SchedBarrierOp::create( + rewriter, loc, amdgpu::sched_barrier_opt_enum::none); + } + int64_t prePadG0M = gemm0M; if (op.getPrePadG0M().has_value()) { prePadG0M = op.getPrePadG0M().value().getSExtValue(); @@ -2975,6 +3035,33 @@ struct GridwiseAttentionAccelRewritePattern expMaxDiffRowBuffer); } + // ================================================================ + // V PREFETCH: Complete LDS write for first V tile after softmax. + // ================================================================ + // The global reads issued before softmax should have completed + // (or be very close to completing) by now, since ~120+ instructions + // of softmax computation have executed in between. Write the V + // data from the register buffer to LDS so GEMM1 can consume it. 
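+        // Note: correctness does not depend on the loads having finished;
+        // the backend still inserts the required s_waitcnt before the V
+        // registers are first consumed by the LDS writes. The prefetch only
+        // changes how much of that latency is hidden behind softmax.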
+ if (prefetchFirstVTile) { + // No scheduling barrier here — we intentionally let the scheduler + // move the V LDS writes (and the preceding s_waitcnt) earlier into + // the tail of softmax. The V global loads were issued before softmax + // and should have completed by this point, so the s_waitcnt is + // essentially free and the ds_writes overlap with remaining softmax + // work, giving us even more latency hiding. + + // Phase 2: Write V data from register buffer to LDS. + // Only the LDSWrite stage is emitted; global read was already done + // before softmax in Phase 1 (GlobalReadOnly). + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", + blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, + featuresAttr, matrixParamsV, matrixParamsKxQ); + LDSBarrierOp::create(rewriter, loc); + } + // Emit blockwise GEMM 1. { auto gemm0Out = @@ -3023,36 +3110,29 @@ struct GridwiseAttentionAccelRewritePattern } } - Value ldsByteBufferV = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); - Value endG1MLoop = - rewriter.createOrFold(loc, gemm1MBlocks); - - auto gridCoordsGemm1 = layout::makeGxNGridLayout( - rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, - splitKVConst); - scf::ForOp g1MLoopOp = - createMainLoop(rewriter, loc, endG1MLoop, loadType); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); - Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); - - gridCoordsGemm1.m_block = g1MLoopIndVar; + // ================================================================ + // V load + GEMM1 loop: Two paths depending on V prefetch. + // ================================================================ + // For non-prefetch path: allocate V LDS buffer and grid coords + // (prefetch path already did this before softmax). + if (!prefetchFirstVTile) { + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + gridCoordsGemm1 = layout::makeGxNGridLayout( + rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, + numChiplets, splitKVConst); + } - loadAndStoreGemmInputTile( - rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - preAccelRegBufferVForLoad, loadType, "m", blockSize, elemTypeV, - elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, - matrixParamsKxQ); - - // Conservative barrier: Ensure all LDS writes complete - // before MMA stage reads from LDS. RockPipelinePass will remove this - // and add optimized barriers when pipelining. - LDSBarrierOp::create(rewriter, loc); - - // Emit GEMM 1. + // ---------------------------------------------------------------- + // Helper lambda: Emit GEMM1 MMA + PostProcess for a single V tile. + // Parameterized by V block index (g1MBlockIdx) to support both + // the peeled first iteration and the remaining loop iterations. + // This avoids duplicating ~100 lines of MMA + PostProcess code. + // ---------------------------------------------------------------- + auto emitGemm1Compute = + [&](Value g1MBlockIdx, GemmLoadTileType vLoadType, + Value vRegBuf) -> LogicalResult { + // Emit GEMM 1 MMA. 
auto computeStage = StageOp::create(rewriter, loc, "MMA"); { PatternRewriter::InsertionGuard guard(rewriter); @@ -3065,7 +3145,7 @@ struct GridwiseAttentionAccelRewritePattern } else { if (gemm1MBlocks > 1) { matrixC = createSliceOfFirstDim(rewriter, loc, matrixC, - g1MLoopIndVar); + g1MBlockIdx); } } @@ -3116,7 +3196,7 @@ struct GridwiseAttentionAccelRewritePattern ? GemmLoadTileType::BypassLDS : GemmLoadTileType::Default; blockwiseGemmAccel( - rewriter, loc, loadType, loadTypeKxD, preAccelRegBufferV, + rewriter, loc, vLoadType, loadTypeKxD, vRegBuf, preAccelRegBufferQxK, matrixC, matrixParamsV, matrixParamsKxQ, ldsTileBufferV, gemm1LDSBufferB, /*scaleA=*/nullptr, /*scaleB=*/nullptr, @@ -3126,7 +3206,9 @@ struct GridwiseAttentionAccelRewritePattern rock::YieldOp::create(rewriter, loc); } - auto postProcessStage = StageOp::create(rewriter, loc, "PostProcess"); + // Emit GEMM 1 PostProcess. + auto postProcessStage = + StageOp::create(rewriter, loc, "PostProcess"); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart( @@ -3138,9 +3220,9 @@ struct GridwiseAttentionAccelRewritePattern Value matrixC = accRegBufferGemm1; if (!op.getEnableSoftmax() && gemm1MBlocks > 1) { gemm1OutBufferPerG1MBlock = createSliceOfFirstDim( - rewriter, loc, gemm1OutBuffer, g1MLoopIndVar); - matrixC = - createSliceOfFirstDim(rewriter, loc, matrixC, g1MLoopIndVar); + rewriter, loc, gemm1OutBuffer, g1MBlockIdx); + matrixC = createSliceOfFirstDim(rewriter, loc, matrixC, + g1MBlockIdx); } accelEmitterPtrGemm1->computeOutputConversion( @@ -3149,7 +3231,7 @@ struct GridwiseAttentionAccelRewritePattern Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer; if (gemm1MBlocks > 1) { attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim( - rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar); + rewriter, loc, attentionOutAccBuffer, g1MBlockIdx); } FailureOr maybeInvertedGemm1threadSubTileMaps = invertTransforms(rewriter, loc, @@ -3176,10 +3258,126 @@ struct GridwiseAttentionAccelRewritePattern rock::YieldOp::create(rewriter, loc); } - // Conservative barrier: Ensure all LDS reads complete before the next - // iteration writes to LDS. RockPipelinePass will remove this and add - // optimized barriers when pipelining. - LDSBarrierOp::create(rewriter, loc); + return success(); + }; // end emitGemm1Compute lambda + + if (prefetchFirstVTile) { + // ============================================================ + // PREFETCH PATH: First V tile already loaded into LDS. + // ============================================================ + // V data for tile 0 was prefetched before softmax (global read) + // and written to LDS after softmax (LDS write + barrier). + // The first GEMM1 iteration is peeled out of the loop so the + // remaining iterations form a clean, pipelineable loop. + + // --- Peeled first iteration (g1m = 0) --- + gridCoordsGemm1.m_block = zero; + // Use Default load type for the peeled iteration because the V + // data was written to LDS by the LDSWriteFromRegs phase. There is + // no BlockwiseLoadTileOp here to create an LDSRead stage, so the + // GEMM must read V directly from LDS. + // + // When double-buffering is active, preAccelRegBufferV is rank-2 + // (e.g. memref<3x2xvector<4xf16>>) because it was allocated with + // repeats=mRepeats. However, the Default load path in + // BlockwiseGemmAccelOp reads from LDS into the buffer WITHOUT + // slicing by the m-repeat loop variable. 
The downstream + // generateThreadwiseViewBufferA then creates a rank-1 view, + // leading to a memref.load rank mismatch. Fix: create a separate + // rank-1 register buffer for the peeled iteration. + Value peeledVRegBuf = preAccelRegBufferV; + if (doubleBuffering) { + auto [peeledVForLoad, peeledVBuf] = + createRegInterrimBufferForAccel( + rewriter, loc, accelParamsGemm1.argTypeA, + accelParamsGemm1.kBasePerThread, + /*repeats=*/1, directToLDS); + peeledVRegBuf = peeledVBuf; + } + if (failed(emitGemm1Compute(zero, GemmLoadTileType::Default, + peeledVRegBuf))) + return failure(); + + // --- Remaining iterations (g1m = 1..gemm1MBlocks-1) --- + // These form a standard pipelineable loop with V loads. + if (gemm1MBlocks > 1) { + LDSBarrierOp::create(rewriter, loc); + + Value startG1M = + rewriter.createOrFold(loc, 1); + Value endG1MLoop = + rewriter.createOrFold(loc, gemm1MBlocks); + Value oneVal = + rewriter.createOrFold(loc, 1); + scf::ForOp g1MLoopOp = scf::ForOp::create( + rewriter, loc, startG1M, endG1MLoop, oneVal); + // Mark loop for pipelining + bool g1DoubleBuffering = + loadType == GemmLoadTileType::DoubleBuffer || + loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2; + g1MLoopOp->setAttr( + PipelineAttr::getMnemonic(), + rock::PipelineAttr::get(rewriter.getContext(), + g1InitiationInterval)); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); + Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); + + gridCoordsGemm1.m_block = g1MLoopIndVar; + + // Normal V tile load (global -> regs -> LDS) + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, loadType, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); + + // Conservative barrier before MMA + LDSBarrierOp::create(rewriter, loc); + + if (failed(emitGemm1Compute(g1MLoopIndVar, loadType, + preAccelRegBufferV))) + return failure(); + + // Conservative barrier before next iteration's LDS writes + LDSBarrierOp::create(rewriter, loc); + } + } + } else { + // ============================================================ + // ORIGINAL PATH: No V prefetch (softmax disabled). 
+ // ============================================================ + Value endG1MLoop = + rewriter.createOrFold(loc, gemm1MBlocks); + scf::ForOp g1MLoopOp = + createMainLoop(rewriter, loc, endG1MLoop, loadType); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); + Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); + + gridCoordsGemm1.m_block = g1MLoopIndVar; + + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, loadType, "m", blockSize, elemTypeV, + elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, + matrixParamsKxQ); + + // Conservative barrier before MMA + LDSBarrierOp::create(rewriter, loc); + + if (failed(emitGemm1Compute(g1MLoopIndVar, loadType, + preAccelRegBufferV))) + return failure(); + + // Conservative barrier before next iteration's LDS writes + LDSBarrierOp::create(rewriter, loc); + } } } } @@ -3725,7 +3923,8 @@ void RockGridwiseGemmToBlockwisePass::runOnOperation() { target.addLegalDialect(); + scf::SCFDialect, math::MathDialect, + amdgpu::AMDGPUDialect>(); target.addLegalOp(); RewritePatternSet patterns(ctx); diff --git a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp index 5d17b2dc18b8..025bbef05334 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp @@ -818,6 +818,16 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value validity = loadLoop.getValidity(/*domain=*/0); Value destIndex = loadLoop.getLowerCoords(/*domain=*/1)[extraIdxCount]; + // Compute the full set of coordinates needed to index into `dest`. + // Domain 1 lower coords have (extraIdxCount + 1) elements, but `dest` + // may have fewer dimensions (dstRank). The last dstRank elements of the + // domain-1 coords correspond to the dest buffer dimensions. 
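+  // Schematic example (shapes hypothetical): if the domain-1 lower coords
+  // are (extra..., d0, ..., d_{r-1}) and `dest` has rank r, the leading
+  // extra coords are dropped and the stores below index dest[d0, ..., d_{r-1}].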
+ int64_t dstRank = dstBufferType.getRank(); + Block::BlockArgListType allDestCoords = loadLoop.getLowerCoords(/*domain=*/1); + size_t dropCount = allDestCoords.size() - dstRank; + SmallVector destCoords(allDestCoords.begin() + dropCount, + allDestCoords.end()); + for (Value dynamicValidity : adaptor.getDynamicValidities()) { Value validityHere = vector::ExtractOp::create( b, loc, b.getI1Type(), dynamicValidity, destIndex, @@ -845,7 +855,7 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value loaded = GlobalLoadOp::create(b, loc, loadType, buffer, validity, loadLoop.getLowerCoords(/*domain=*/0), needs64BitIdx); - InBoundsStoreOp::create(b, loc, loaded, dest, destIndex); + InBoundsStoreOp::create(b, loc, loaded, dest, destCoords); } else if (isGlobalToLDS) { int64_t loadTypeByteWidth = getByteWidth(loadType); if (loadTypeByteWidth != 16 && loadTypeByteWidth != 4) @@ -924,11 +934,14 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( } if (!isDstVectorBuffer && !isSrcVectorBuffer) { - InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destIndex); + InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destCoords); } else if (!isDstVectorBuffer && isSrcVectorBuffer) { - destIndex = arith::MulIOp::create( - b, loc, destIndex, ConstantIndexOp::create(b, loc, vectorSrcLen)); - InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destIndex); + SmallVector scaledDestCoords(destCoords); + scaledDestCoords.back() = arith::MulIOp::create( + b, loc, scaledDestCoords.back(), + ConstantIndexOp::create(b, loc, vectorSrcLen)); + InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, + scaledDestCoords); } else { // Destination is a vector buffer Value idx = loadLoop.getLowerCoords(/*domain=*/1)[extraIdxCount]; @@ -944,7 +957,10 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( idx = b.createOrFold(loc, idx, srcVecLenVal); } if (vectorSrcLen == vectorDstLen) { - memref::StoreOp::create(b, loc, ifb.getResult(0), dest, idx); + SmallVector vecDestCoords(destCoords); + vecDestCoords.back() = idx; + memref::StoreOp::create(b, loc, ifb.getResult(0), dest, + vecDestCoords); } else { // When the vector types differ, we need to find the gcd // to make it work for the both source and dest. 
@@ -973,15 +989,16 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value storeVecStart = b.createOrFold( loc, elementOffset, b.createOrFold(loc, vectorDstLen)); + SmallVector vecDestCoords(destCoords); + vecDestCoords.back() = storeVecStart; Value storeVec = memref::LoadOp::create(b, loc, dstVectorType, dest, - ValueRange{storeVecStart}); + vecDestCoords); Value storeSliceStart = b.createOrFold( loc, elementOffset, b.createOrFold(loc, vectorDstLen)); Value newStoreVec = InsertSliceOp::create( b, loc, dstVectorType, value, storeVec, storeSliceStart); - memref::StoreOp::create(b, loc, newStoreVec, dest, - ValueRange{storeVecStart}); + memref::StoreOp::create(b, loc, newStoreVec, dest, vecDestCoords); } } } From e42fe55578ff1d139faed4d17d802e71100c9a76 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 23 Feb 2026 14:57:10 +0000 Subject: [PATCH 02/18] Test interleaving --- .../include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 24 +++++++++++ .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 18 +++++++- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 8 ++++ .../Transforms/BlockwiseGemmToThreadwise.cpp | 32 +++++++++++--- .../Transforms/GridwiseGemmToBlockwise.cpp | 9 ++++ .../Dialect/Rock/Transforms/RockPipeline.cpp | 43 +++++++++++++++---- 6 files changed, 118 insertions(+), 16 deletions(-) diff --git a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 27009cd0961e..6662a0fadb4d 100644 --- a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -852,6 +852,30 @@ def AMDGPU_SchedBarrierOp : }]; } +def AMDGPU_IglpOptOp : + AMDGPU_Op<"iglp_opt">, + Arguments<(ins I32Attr:$variant)> + { + let summary = "Hint to the AMDGPU instruction scheduler to apply an IGLP strategy"; + let description = [{ + `amdgpu.iglp_opt` provides a hint to the LLVM AMDGPU backend's instruction + scheduler to apply a specific Instruction-Group-Level-Parallelism (IGLP) + scheduling strategy. The `variant` attribute selects which strategy to use: + + - 0: `MFMASmallGemmOpt` — interleaves DS_READ with MFMA for small GEMMs. + - 1: `MFMASmallGemmSingleWaveOpt` — single-wave GEMM with DS_READ, V_PERM, + DS_WRITE, and VMEM interleaving. + - 2: `MFMAExpInterleave` — interleaves MFMA with transcendental (EXP) + instructions, with complex DAG analysis. + - 3: `MFMAExpSimpleInterleave` — simple TRANS/MFMA interleaving. + + This lowers to `llvm.amdgcn.iglp_opt(i32 variant)`. 
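+
+    Example (requests the `MFMAExpInterleave` strategy; assembly follows the
+    declared `$variant attr-dict` format):
+
+    ```mlir
+    amdgpu.iglp_opt 2
+    ```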
+ }]; + let assemblyFormat = [{ + $variant attr-dict + }]; +} + def AMDGPU_MemoryCounterWaitOp : AMDGPU_Op<"memory_counter_wait">, Arguments<(ins diff --git a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index e9da09b9c8f6..47f3e2570224 100644 --- a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -613,6 +613,21 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern { } }; +struct IglpOptOpLowering : public ConvertOpToLLVMPattern { + IglpOptOpLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(IglpOptOp op, IglpOptOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, static_cast(op.getVariant())); + return success(); + } +}; + } // namespace /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL @@ -2219,7 +2234,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RawBufferOpLowering, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + SchedBarrierOpLowering, IglpOptOpLowering, MFMAOpLowering, + ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index d258bdc6e916..fed61f75e50c 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -240,6 +240,14 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(rock::createRockThreadwiseGemmLoweringPass()); funcPm.addPass(rock::createRockAnalyzeMemoryUsePass()); funcPm.addPass(rock::createRockSugarToLoopsPass()); + // Re-run the pipeline pass to coalesce LDS barriers inserted by + // ReuseLDS. This must run AFTER SugarToLoops so that + // TransformingForOps have been unrolled and the individual + // InBoundsStoreOps are visible. PushBarrierDown can then move + // barriers past LDS stores, and RemoveBackToBack removes adjacent + // barriers. This eliminates redundant s_waitcnt lgkmcnt(0) + // between independent LDS writes (e.g., in softmax reductions). + funcPm.addPass(rock::createRockPipelinePass()); funcPm.addPass(rock::createRockCleanMathPass()); math::MathExtendToSupportedTypesOptions extendToLLVMTypesOptions; extendToLLVMTypesOptions.extraTypeStrs = {"f16"}; diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 3a49c09fcb9b..3b0d8967e8bd 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -928,8 +928,13 @@ struct BlockwiseReduceRewritePattern } } else { if (rMethod == ReduceMethod::Sum) { + // Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the + // true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0 + // and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the + // redundant `v_add_f32 v, 0, v` that +0.0 generates via + // llvm.vector.reduce.fadd. 
return createConstantFloatOp(rewriter, op.getLoc(), elementType, - elementType, 0.0); + elementType, -0.0f); } else { // Op verifier gurantees this. assert(rMethod == ReduceMethod::Max); @@ -951,6 +956,10 @@ struct BlockwiseReduceRewritePattern if (!isa(acc.getType()) && isa(input.getType())) { // This means accumulator is a scalar type and input is a vector type, // therefore its a elementwise reduction between two operands. + // Pass `acc` as the accumulator to vector::ReductionOp so that the + // scalar accumulation is folded into the reduction intrinsic rather + // than emitting a separate arith::AddFOp / arith::MaximumFOp. + // This avoids redundant `fadd X, 0.0` when acc is the identity. vector::CombiningKind kind; if (rMethod == ReduceMethod::Sum) { kind = vector::CombiningKind::ADD; @@ -964,7 +973,7 @@ struct BlockwiseReduceRewritePattern kind = vector::CombiningKind::MAXNUMF; } } - input = vector::ReductionOp::create(builder, loc, kind, input); + return vector::ReductionOp::create(builder, loc, kind, input, acc); } if (rMethod == ReduceMethod::Sum) { @@ -1452,11 +1461,20 @@ struct BlockwiseReduceRewritePattern readLoop.getLowerCoords(/*domain=*/0); Value ldVal = InBoundsLoadOp::create( rewriter, loc, elemType, workspaceLDSBuffer, ldsCoords); - Value accVal = InBoundsLoadOp::create( - rewriter, loc, elemType, accReg, zeroConstantOp); - Value reduced = createReducingOp(op, ldVal, accVal, rewriter); - InBoundsStoreOp::create(rewriter, loc, reduced, accReg, - zeroConstantOp); + if (i == 0) { + // First iteration: store the loaded value directly to the + // accumulator. This avoids a redundant reduction with the + // identity element (e.g., `0.0 + x` for sum, `max(-inf, x)` + // for max). + InBoundsStoreOp::create(rewriter, loc, ldVal, accReg, + zeroConstantOp); + } else { + Value accVal = InBoundsLoadOp::create( + rewriter, loc, elemType, accReg, zeroConstantOp); + Value reduced = createReducingOp(op, ldVal, accVal, rewriter); + InBoundsStoreOp::create(rewriter, loc, reduced, accReg, + zeroConstantOp); + } } } diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index aac0e36cbf94..bea905794d07 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2845,6 +2845,15 @@ struct GridwiseAttentionAccelRewritePattern // mask = none (0x0): full barrier, no instructions may cross. amdgpu::SchedBarrierOp::create( rewriter, loc, amdgpu::sched_barrier_opt_enum::none); + + // Enable IGLP (Instruction-Group-Level Parallelism) scheduling. + // The softmax section produces v_exp_f32 (transcendental unit) and + // the subsequent S*V GEMM produces v_mfma (matrix core unit). + // These two execution units can operate in parallel. Variant 2 + // (MFMAExpInterleave) analyzes the dependency graph and creates + // scheduling groups that interleave TRANS and MFMA instructions, + // hiding transcendental latency behind matrix computation. 
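+          // Rough shape of the intended schedule (illustrative only):
+          //   without iglp_opt:  v_exp_f32 ... v_exp_f32, v_mfma ... v_mfma
+          //   with variant 2:    v_exp_f32, v_mfma, v_exp_f32, v_mfma, ...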
+ amdgpu::IglpOptOp::create(rewriter, loc, /*variant=*/2); } int64_t prePadG0M = gemm0M; diff --git a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp index 17e5e2534056..c2e224cac449 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Rock/Passes.h" #include "mlir/Dialect/Rock/Transforms/RockMultibuffer.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" @@ -165,17 +166,32 @@ struct PushBarrierDownRewritePattern if (!nextOp->getNextNode()) return failure(); + // Don't push past another barrier — RemoveBackToBack handles that. + // Without this check, two adjacent barriers would swap endlessly. + if (isa(nextOp)) + return failure(); + // We assume that operations that have a body may modify LDS if (nextOp->getNumRegions() > 0 && !dyn_cast(nextOp)) return failure(); bool moveDown = true; - // Make sure that the "nextOp" doesn't modify LDS + // Check if the operation accesses LDS. + // We can move past LDS store-only operations because independent + // writes don't need ordering between them — the next barrier will + // ensure all writes complete before any subsequent reads. + // We must stop at LDS reads. + // We recognize store ops both before SugarToLoops (InBoundsStoreOp) + // and after (memref::StoreOp, vector::TransferWriteOp, vector::StoreOp). for (Value operand : nextOp->getOperands()) { auto maybeAlloc = rock::findGpuAlloc(operand); if (succeeded(maybeAlloc) && - getAddressSpace(*maybeAlloc) == AddressSpace::Workgroup) - moveDown = false; + getAddressSpace(*maybeAlloc) == AddressSpace::Workgroup) { + // This operation touches LDS. Check if it's a write-only op. + if (!isa(nextOp)) + moveDown = false; + } } if (moveDown) { @@ -815,12 +831,23 @@ void RockPipeline::runOnOperation() { if (failed( applyPatternsGreedily(func, std::move(patternsRemoveStages)))) return signalPassFailure(); - - RewritePatternSet patternsBackToBack(&getContext()); - patternsBackToBack.add(ctx); - if (failed(applyPatternsGreedily(func, std::move(patternsBackToBack)))) - return signalPassFailure(); } } } + + // Always run barrier coalescing, even when there are no loops to pipeline. + // This handles barriers inserted by other passes (e.g., softmax reductions + // in BlockwiseGemmToThreadwise) that are only exposed after SugarToLoops + // unrolls TransformingForOps into individual memref.store/load ops. 
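+  // Illustrative effect of the two pattern sets applied below:
+  //   ds_write A ; s_barrier ; ds_write B ; s_barrier
+  // PushBarrierDown moves the first barrier past the independent store,
+  // then RemoveBackToBack drops the duplicate, leaving:
+  //   ds_write A ; ds_write B ; s_barrier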
+ { + RewritePatternSet patternsPushBarrier(&getContext()); + patternsPushBarrier.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patternsPushBarrier)))) + return signalPassFailure(); + + RewritePatternSet patternsBackToBack(&getContext()); + patternsBackToBack.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patternsBackToBack)))) + return signalPassFailure(); + } } From d963949eee48c557eabcd8c47d24a73918ff7cfd Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 23 Feb 2026 23:32:24 +0000 Subject: [PATCH 03/18] Remove additional barriers --- .../Transforms/BlockwiseGemmToThreadwise.cpp | 91 +++++++++++-------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 3b0d8967e8bd..397639e6ad00 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -1478,48 +1478,67 @@ struct BlockwiseReduceRewritePattern } } - // Write the fully reduced value back to LDS at [nrtid, 0]. - // All threads with the same nrtid compute the same value, - // so concurrent writes to the same location are safe. - { + // After the branchless reduction, every thread with the same + // nrtid holds the identical fully-reduced value in accReg. + // + // When each thread owns exactly 1 non-reduction position + // (inputThreadSubTile2dShape[nrDim] == 1), the output register + // only needs this single reduced value broadcast to all its + // elements. We can skip the LDS write-back + barrier + + // broadcast-read and instead fill the output register directly + // from the register, eliminating one s_barrier per reduction. + if (inputThreadSubTile2dShape[nrDim] == 1) { Value reducedVal = InBoundsLoadOp::create( rewriter, loc, elemType, accReg, zeroConstantOp); - SmallVector writeInits{nrtid, zeroConstantOp, - zeroConstantOp}; - SmallVector writeBounds{1, 1, 1}; - SmallVector writeStrides{1, 1, 1}; - - TransformingForOp writeLoop = TransformingForOp::create( - rewriter, loc, ArrayRef{writeInits}, - ArrayRef{threadToLDSViewTrs}, - ArrayRef(writeBounds), ArrayRef(writeStrides), - /*forceUnroll=*/true, /*useIndexDiffs=*/true); + FillOp::create(rewriter, loc, outputReg, reducedVal); + if (op.getExtraOutViewAttr()) { + FillOp::create(rewriter, loc, op.getExtraOut(), reducedVal); + } + } else { + // General case: thread has multiple non-reduction rows. + // Only nrtid's row was reduced; other rows' results live in + // LDS (written by other threads). Fall back to LDS round-trip. 
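+              // (Reached when a thread owns more than one non-reduction row,
+              //  e.g. inputThreadSubTile2dShape[nrDim] == 2: only the nrtid
+              //  row was reduced in registers, so the other rows must still
+              //  be read back from LDS.)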
{ - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(writeLoop.getBody()); - Block::BlockArgListType ldsCoords = - writeLoop.getLowerCoords(/*domain=*/0); - InBoundsStoreOp::create(rewriter, loc, reducedVal, - workspaceLDSBuffer, ldsCoords); + Value reducedVal = InBoundsLoadOp::create( + rewriter, loc, elemType, accReg, zeroConstantOp); + SmallVector writeInits{nrtid, zeroConstantOp, + zeroConstantOp}; + SmallVector writeBounds{1, 1, 1}; + SmallVector writeStrides{1, 1, 1}; + + TransformingForOp writeLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{writeInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(writeBounds), + ArrayRef(writeStrides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(writeLoop.getBody()); + Block::BlockArgListType ldsCoords = + writeLoop.getLowerCoords(/*domain=*/0); + InBoundsStoreOp::create(rewriter, loc, reducedVal, + workspaceLDSBuffer, ldsCoords); + } } - } - LDSBarrierOp::create(rewriter, loc); - ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( - loc, rewriter, inputViewArrayAttr, axis, /*makeRDimZero-*/ true, - partialRegTensorShape[rDim]); - ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer, - outputReg, reducedldsViewArrayAttr, - /*extraIndices=*/ValueRange{tid}, true, - false); - if (ArrayAttr outputViewArrayAttr = op.getExtraOutViewAttr()) { - ArrayAttr reducedldsViewArrayAttr2 = createLDSWorkspaceView( - loc, rewriter, outputViewArrayAttr, axis, + LDSBarrierOp::create(rewriter, loc); + ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( + loc, rewriter, inputViewArrayAttr, axis, /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); - ThreadwiseReadIntoOp::create( - rewriter, loc, workspaceLDSBuffer, op.getExtraOut(), - reducedldsViewArrayAttr2, - /*extraIndices=*/ValueRange{tid}, true, false); + ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer, + outputReg, reducedldsViewArrayAttr, + /*extraIndices=*/ValueRange{tid}, true, + false); + if (ArrayAttr outputViewArrayAttr = op.getExtraOutViewAttr()) { + ArrayAttr reducedldsViewArrayAttr2 = createLDSWorkspaceView( + loc, rewriter, outputViewArrayAttr, axis, + /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); + ThreadwiseReadIntoOp::create( + rewriter, loc, workspaceLDSBuffer, op.getExtraOut(), + reducedldsViewArrayAttr2, + /*extraIndices=*/ValueRange{tid}, true, false); + } } } } From 004dfb3c0c51e9debccf6a6dd85eb50741b6737d Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 23 Feb 2026 23:50:44 +0000 Subject: [PATCH 04/18] Hoist V loads --- .../Transforms/GridwiseGemmToBlockwise.cpp | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index bea905794d07..42828fe5167b 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3017,6 +3017,33 @@ struct GridwiseAttentionAccelRewritePattern gemm0MNExpThreadwiseView, gemm0MNMaxThreadwiseView, maxRowBuffer); + // ================================================================ + // V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS + // before the sum reduction so that the sum reduction's internal + // LDS barrier also synchronises the V tile writes. 
This + // eliminates the dedicated V-tile LDS barrier that was + // previously required after the sum reduction, saving one + // s_barrier per iteration. + // + // Safety: AnnotateLiveness + ReuseLDS will see that V's live + // range (write here → read during GEMM1) overlaps with the sum + // workspace's live range, so they will NOT be aliased. The + // max-reduction workspace is already dead, so it CAN be + // aliased with V. The LDS increase is small and does not + // affect occupancy (VGPR-limited, not LDS-limited). + // ================================================================ + if (prefetchFirstVTile) { + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", + blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, + featuresAttr, matrixParamsV, matrixParamsKxQ); + // No LDSBarrierOp here — the barrier inside the sum + // BlockwiseBroadcastReduceOp (below) will synchronise both + // the V LDS writes and the softmax partial-sum LDS writes. + } + // Softmax sum reduction Value ldsReductionWorkspaceByteSecondBuffer = createLDSByteBuffer( rewriter, loc, reductionWorkspaceSize, elemTypeSoftmax); @@ -3044,33 +3071,6 @@ struct GridwiseAttentionAccelRewritePattern expMaxDiffRowBuffer); } - // ================================================================ - // V PREFETCH: Complete LDS write for first V tile after softmax. - // ================================================================ - // The global reads issued before softmax should have completed - // (or be very close to completing) by now, since ~120+ instructions - // of softmax computation have executed in between. Write the V - // data from the register buffer to LDS so GEMM1 can consume it. - if (prefetchFirstVTile) { - // No scheduling barrier here — we intentionally let the scheduler - // move the V LDS writes (and the preceding s_waitcnt) earlier into - // the tail of softmax. The V global loads were issued before softmax - // and should have completed by this point, so the s_waitcnt is - // essentially free and the ds_writes overlap with remaining softmax - // work, giving us even more latency hiding. - - // Phase 2: Write V data from register buffer to LDS. - // Only the LDSWrite stage is emitted; global read was already done - // before softmax in Phase 1 (GlobalReadOnly). - loadAndStoreGemmInputTile( - rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", - blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, - featuresAttr, matrixParamsV, matrixParamsKxQ); - LDSBarrierOp::create(rewriter, loc); - } - // Emit blockwise GEMM 1. 
{ auto gemm0Out = From 6eab4b9a14ce3f36e3db3d3184444c2597721277 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 27 Feb 2026 14:38:35 +0000 Subject: [PATCH 05/18] Fix LDS usage error --- .../Transforms/GridwiseGemmToBlockwise.cpp | 82 +++++++++++++++---- 1 file changed, 68 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 42828fe5167b..ad7dbffb0c96 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2813,6 +2813,48 @@ struct GridwiseAttentionAccelRewritePattern Value vPrefetchRegs; layout::GridCoordinates gridCoordsGemm1; bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS; + + // Decide whether to hoist Phase 2 (V regs -> LDS write) before the + // sum reduction. Hoisting saves one LDS barrier by piggybacking on + // the sum reduction's internal barrier, but it makes V's LDS live + // range overlap with the sum-reduction workspace, preventing + // ReuseLDS from aliasing them. + // + // ReuseLDS uses greedy graph coloring that packs non-interfering + // buffers (like K and V) into merged color groups. When V interferes + // with sum_ws (due to hoisting), V gets displaced within the merged + // group by sum_ws's size, growing the group by exactly sumWSBytes. + // So: hoisted_total ≈ non_hoisted_peak + sumWSBytes. + // + // The non-hoisted peak is the max concurrent LDS from GEMM0 + // (Q+K buffers) or GEMM1 (V+gemm1_B buffers). We check if adding + // sumWSBytes would exceed the hardware LDS limit. + bool hoistVPhase2 = false; + if (prefetchFirstVTile) { + int64_t maxLDS = archInfo.maxSharedMemPerWG; + int64_t sumWSBytes = + getPackedByteSize(reductionWorkspaceSize, elemTypeSoftmax); + int64_t gemm0PeakBytes = + getPackedByteSize(ldsByteBufferQSize, elemTypeQ) + + getPackedByteSize(gemm0KPerBlock * gemm0MPerBlock, elemTypeK); + int64_t gemm1PeakBytes = + getPackedByteSize(gemm1KPerBlock * gemm1MPerBlock, elemTypeV) + + getPackedByteSize(gemm1LDSByteBufferBSize, elemTypeV); + // The base peak without hoisting is determined by the larger of + // GEMM0 and GEMM1 concurrent buffer sets. + int64_t nonHoistedPeak = std::max(gemm0PeakBytes, gemm1PeakBytes); + // Hoisting adds sumWSBytes on top (V displaced in merged color). + int64_t hoistedTotal = nonHoistedPeak + sumWSBytes; + hoistVPhase2 = hoistedTotal <= maxLDS; + LLVM_DEBUG(llvm::dbgs() + << "V prefetch Phase 2 hoist decision: " + << (hoistVPhase2 ? "HOIST" : "DEFER") + << " (hoistedTotal=" << hoistedTotal << ", max=" << maxLDS + << ", sumWS=" << sumWSBytes + << ", gemm0=" << gemm0PeakBytes + << ", gemm1=" << gemm1PeakBytes << ")\n"); + } + if (prefetchFirstVTile) { ldsByteBufferV = createLDSByteBuffer( rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); @@ -3018,21 +3060,14 @@ struct GridwiseAttentionAccelRewritePattern gemm0MNMaxThreadwiseView, maxRowBuffer); // ================================================================ - // V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS - // before the sum reduction so that the sum reduction's internal - // LDS barrier also synchronises the V tile writes. This - // eliminates the dedicated V-tile LDS barrier that was - // previously required after the sum reduction, saving one - // s_barrier per iteration. 
- // - // Safety: AnnotateLiveness + ReuseLDS will see that V's live - // range (write here → read during GEMM1) overlaps with the sum - // workspace's live range, so they will NOT be aliased. The - // max-reduction workspace is already dead, so it CAN be - // aliased with V. The LDS increase is small and does not - // affect occupancy (VGPR-limited, not LDS-limited). + // V PREFETCH Phase 2 (hoisted path): Write V data from regs to + // LDS before the sum reduction so that the sum reduction's + // internal LDS barrier also synchronises the V tile writes, + // saving one s_barrier per outer-loop iteration. + // Only used when the LDS budget can accommodate V and the + // sum-reduction workspace being live simultaneously. // ================================================================ - if (prefetchFirstVTile) { + if (prefetchFirstVTile && hoistVPhase2) { loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, @@ -3069,6 +3104,25 @@ struct GridwiseAttentionAccelRewritePattern updateRowSum(rewriter, loc, gemm0SumThreadwiseView, gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, expMaxDiffRowBuffer); + + // ================================================================ + // V PREFETCH Phase 2 (deferred path): Write V data from regs to + // LDS after the sum reduction. This avoids V's LDS live range + // overlapping with the sum-reduction workspace, allowing + // ReuseLDS to alias them and stay within the hardware LDS budget. + // Costs one extra s_barrier vs the hoisted path. + // Phase 1 (global reads -> regs, before softmax) still hides the + // global memory latency across the entire softmax computation. + // ================================================================ + if (prefetchFirstVTile && !hoistVPhase2) { + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", + blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, + featuresAttr, matrixParamsV, matrixParamsKxQ); + LDSBarrierOp::create(rewriter, loc); + } } // Emit blockwise GEMM 1. From b2d7ff7f745f1b6ab61664ca585a1a9598256ba2 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 27 Feb 2026 15:15:37 +0000 Subject: [PATCH 06/18] Fix barrier issue --- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 9 ++---- .../Transforms/GridwiseGemmToBlockwise.cpp | 32 +++++++++++++++---- .../Dialect/Rock/Transforms/RockPipeline.cpp | 12 ++----- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index fed61f75e50c..84e7c859e577 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -240,13 +240,8 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(rock::createRockThreadwiseGemmLoweringPass()); funcPm.addPass(rock::createRockAnalyzeMemoryUsePass()); funcPm.addPass(rock::createRockSugarToLoopsPass()); - // Re-run the pipeline pass to coalesce LDS barriers inserted by - // ReuseLDS. This must run AFTER SugarToLoops so that - // TransformingForOps have been unrolled and the individual - // InBoundsStoreOps are visible. PushBarrierDown can then move - // barriers past LDS stores, and RemoveBackToBack removes adjacent - // barriers. This eliminates redundant s_waitcnt lgkmcnt(0) - // between independent LDS writes (e.g., in softmax reductions). 
+ // Re-run the pipeline pass to remove back-to-back LDS barriers + // that may appear after SugarToLoops unrolls TransformingForOps. funcPm.addPass(rock::createRockPipelinePass()); funcPm.addPass(rock::createRockCleanMathPass()); math::MathExtendToSupportedTypesOptions extendToLLVMTypesOptions; diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index ad7dbffb0c96..f9490364cdb7 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3060,12 +3060,19 @@ struct GridwiseAttentionAccelRewritePattern gemm0MNMaxThreadwiseView, maxRowBuffer); // ================================================================ - // V PREFETCH Phase 2 (hoisted path): Write V data from regs to - // LDS before the sum reduction so that the sum reduction's - // internal LDS barrier also synchronises the V tile writes, - // saving one s_barrier per outer-loop iteration. - // Only used when the LDS budget can accommodate V and the - // sum-reduction workspace being live simultaneously. + // V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS + // before the sum reduction so that the sum reduction's internal + // LDS barrier also synchronises the V tile writes. This + // eliminates the dedicated V-tile LDS barrier that was + // previously required after the sum reduction, saving one + // s_barrier per iteration. + // + // Safety: AnnotateLiveness + ReuseLDS will see that V's live + // range (write here -> read during GEMM1) overlaps with the sum + // workspace's live range, so they will NOT be aliased. The + // max-reduction workspace is already dead, so it CAN be + // aliased with V. The LDS increase is small and does not + // affect occupancy (VGPR-limited, not LDS-limited). // ================================================================ if (prefetchFirstVTile && hoistVPhase2) { loadAndStoreGemmInputTile( @@ -3329,7 +3336,8 @@ struct GridwiseAttentionAccelRewritePattern // PREFETCH PATH: First V tile already loaded into LDS. // ============================================================ // V data for tile 0 was prefetched before softmax (global read) - // and written to LDS after softmax (LDS write + barrier). + // and written to LDS before the sum reduction (LDS write synced + // by sum reduction's internal barrier). // The first GEMM1 iteration is peeled out of the loop so the // remaining iterations form a clean, pipelineable loop. @@ -3357,6 +3365,16 @@ struct GridwiseAttentionAccelRewritePattern /*repeats=*/1, directToLDS); peeledVRegBuf = peeledVBuf; } + // Barrier: ensure all threads have finished writing the softmax + // exp values to LDS (storeGemmInputTile above) before GEMM1 + // reads from them. Only needed when the softmax exp actually + // goes through LDS (!doBypassLDSSecondGemm). When LDS is + // bypassed, softmax exp stays in registers and V is already + // synced by either the sum reduction's internal barrier + // (hoisted path) or the deferred V Phase 2 barrier. 
+ if (!doBypassLDSSecondGemm) + LDSBarrierOp::create(rewriter, loc); + if (failed(emitGemm1Compute(zero, GemmLoadTileType::Default, peeledVRegBuf))) return failure(); diff --git a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp index c2e224cac449..281730ce140e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp @@ -835,16 +835,10 @@ void RockPipeline::runOnOperation() { } } - // Always run barrier coalescing, even when there are no loops to pipeline. - // This handles barriers inserted by other passes (e.g., softmax reductions - // in BlockwiseGemmToThreadwise) that are only exposed after SugarToLoops - // unrolls TransformingForOps into individual memref.store/load ops. + // Always run back-to-back barrier removal, even when there are no loops + // to pipeline. This handles barriers that become adjacent after other + // passes (e.g., after SugarToLoops unrolls TransformingForOps). { - RewritePatternSet patternsPushBarrier(&getContext()); - patternsPushBarrier.add(ctx); - if (failed(applyPatternsGreedily(func, std::move(patternsPushBarrier)))) - return signalPassFailure(); - RewritePatternSet patternsBackToBack(&getContext()); patternsBackToBack.add(ctx); if (failed(applyPatternsGreedily(func, std::move(patternsBackToBack)))) From e2ffd8f4ad9d4a90009dad912c427ce30876a0ea Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 27 Feb 2026 20:00:22 +0000 Subject: [PATCH 07/18] Reduction fix --- .../Transforms/BlockwiseGemmToThreadwise.cpp | 101 ++++++++---------- 1 file changed, 43 insertions(+), 58 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 397639e6ad00..ffa855fa6121 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -956,10 +956,6 @@ struct BlockwiseReduceRewritePattern if (!isa(acc.getType()) && isa(input.getType())) { // This means accumulator is a scalar type and input is a vector type, // therefore its a elementwise reduction between two operands. - // Pass `acc` as the accumulator to vector::ReductionOp so that the - // scalar accumulation is folded into the reduction intrinsic rather - // than emitting a separate arith::AddFOp / arith::MaximumFOp. - // This avoids redundant `fadd X, 0.0` when acc is the identity. vector::CombiningKind kind; if (rMethod == ReduceMethod::Sum) { kind = vector::CombiningKind::ADD; @@ -1478,67 +1474,56 @@ struct BlockwiseReduceRewritePattern } } - // After the branchless reduction, every thread with the same - // nrtid holds the identical fully-reduced value in accReg. + // Write the fully reduced value back to LDS at [nrtid, 0]. + // All threads with the same nrtid compute the same value, + // so concurrent writes to the same location are safe. // - // When each thread owns exactly 1 non-reduction position - // (inputThreadSubTile2dShape[nrDim] == 1), the output register - // only needs this single reduced value broadcast to all its - // elements. We can skip the LDS write-back + barrier + - // broadcast-read and instead fill the output register directly - // from the register, eliminating one s_barrier per reduction. 
- if (inputThreadSubTile2dShape[nrDim] == 1) { + // NOTE: We cannot use a FillOp shortcut here (even when + // inputThreadSubTile2dShape[nrDim] == 1) because nrtid + // (= tid % nonReduceMergeDimSize) does NOT necessarily + // correspond to the thread's actual non-reduction position + // in the MFMA layout. The ThreadwiseReadIntoOp uses the + // correct layout-aware view to read each thread's result. + { Value reducedVal = InBoundsLoadOp::create( rewriter, loc, elemType, accReg, zeroConstantOp); - FillOp::create(rewriter, loc, outputReg, reducedVal); - if (op.getExtraOutViewAttr()) { - FillOp::create(rewriter, loc, op.getExtraOut(), reducedVal); - } - } else { - // General case: thread has multiple non-reduction rows. - // Only nrtid's row was reduced; other rows' results live in - // LDS (written by other threads). Fall back to LDS round-trip. + SmallVector writeInits{nrtid, zeroConstantOp, + zeroConstantOp}; + SmallVector writeBounds{1, 1, 1}; + SmallVector writeStrides{1, 1, 1}; + + TransformingForOp writeLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{writeInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(writeBounds), + ArrayRef(writeStrides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { - Value reducedVal = InBoundsLoadOp::create( - rewriter, loc, elemType, accReg, zeroConstantOp); - SmallVector writeInits{nrtid, zeroConstantOp, - zeroConstantOp}; - SmallVector writeBounds{1, 1, 1}; - SmallVector writeStrides{1, 1, 1}; - - TransformingForOp writeLoop = TransformingForOp::create( - rewriter, loc, ArrayRef{writeInits}, - ArrayRef{threadToLDSViewTrs}, - ArrayRef(writeBounds), - ArrayRef(writeStrides), - /*forceUnroll=*/true, /*useIndexDiffs=*/true); - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(writeLoop.getBody()); - Block::BlockArgListType ldsCoords = - writeLoop.getLowerCoords(/*domain=*/0); - InBoundsStoreOp::create(rewriter, loc, reducedVal, - workspaceLDSBuffer, ldsCoords); - } + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(writeLoop.getBody()); + Block::BlockArgListType ldsCoords = + writeLoop.getLowerCoords(/*domain=*/0); + InBoundsStoreOp::create(rewriter, loc, reducedVal, + workspaceLDSBuffer, ldsCoords); } + } - LDSBarrierOp::create(rewriter, loc); - ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( - loc, rewriter, inputViewArrayAttr, axis, + LDSBarrierOp::create(rewriter, loc); + ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( + loc, rewriter, inputViewArrayAttr, axis, + /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); + ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer, + outputReg, reducedldsViewArrayAttr, + /*extraIndices=*/ValueRange{tid}, true, + false); + if (ArrayAttr outputViewArrayAttr = op.getExtraOutViewAttr()) { + ArrayAttr reducedldsViewArrayAttr2 = createLDSWorkspaceView( + loc, rewriter, outputViewArrayAttr, axis, /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); - ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer, - outputReg, reducedldsViewArrayAttr, - /*extraIndices=*/ValueRange{tid}, true, - false); - if (ArrayAttr outputViewArrayAttr = op.getExtraOutViewAttr()) { - ArrayAttr reducedldsViewArrayAttr2 = createLDSWorkspaceView( - loc, rewriter, outputViewArrayAttr, axis, - /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); - ThreadwiseReadIntoOp::create( - rewriter, loc, workspaceLDSBuffer, op.getExtraOut(), - reducedldsViewArrayAttr2, - /*extraIndices=*/ValueRange{tid}, true, false); - } + 
ThreadwiseReadIntoOp::create( + rewriter, loc, workspaceLDSBuffer, op.getExtraOut(), + reducedldsViewArrayAttr2, + /*extraIndices=*/ValueRange{tid}, true, false); } } } From dcceeeca6de19245b9288bb9e7cf3a51a3bfc476 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 27 Feb 2026 22:10:06 +0000 Subject: [PATCH 08/18] More fixes --- .../Transforms/GridwiseGemmToBlockwise.cpp | 57 +++++++++++-------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index f9490364cdb7..de679eed1178 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2845,7 +2845,7 @@ struct GridwiseAttentionAccelRewritePattern int64_t nonHoistedPeak = std::max(gemm0PeakBytes, gemm1PeakBytes); // Hoisting adds sumWSBytes on top (V displaced in merged color). int64_t hoistedTotal = nonHoistedPeak + sumWSBytes; - hoistVPhase2 = hoistedTotal <= maxLDS; + hoistVPhase2 = (hoistedTotal <= maxLDS); LLVM_DEBUG(llvm::dbgs() << "V prefetch Phase 2 hoist decision: " << (hoistVPhase2 ? "HOIST" : "DEFER") @@ -2856,8 +2856,7 @@ struct GridwiseAttentionAccelRewritePattern } if (prefetchFirstVTile) { - ldsByteBufferV = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + // Set up grid coordinates for the first V tile. gridCoordsGemm1 = layout::makeGxNGridLayout( rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, splitKVConst); @@ -2873,28 +2872,22 @@ struct GridwiseAttentionAccelRewritePattern // Phase 1: Issue global reads for V tile 0 into register buffer. // Only the GlobalRead stage is emitted; LDS write is deferred. + // A dummy LDS buffer is passed because the function signature + // requires one, but GlobalReadOnly does not write to LDS. + Value dummyLDS = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS, vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, matrixParamsKxQ); // Insert a scheduling barrier to prevent the LLVM backend scheduler // from sinking the V global loads past the softmax computation. - // Without this barrier, the scheduler moves the V loads to after - // softmax, defeating the latency hiding optimization. - // mask = none (0x0): full barrier, no instructions may cross. amdgpu::SchedBarrierOp::create( rewriter, loc, amdgpu::sched_barrier_opt_enum::none); - // Enable IGLP (Instruction-Group-Level Parallelism) scheduling. - // The softmax section produces v_exp_f32 (transcendental unit) and - // the subsequent S*V GEMM produces v_mfma (matrix core unit). - // These two execution units can operate in parallel. Variant 2 - // (MFMAExpInterleave) analyzes the dependency graph and creates - // scheduling groups that interleave TRANS and MFMA instructions, - // hiding transcendental latency behind matrix computation. amdgpu::IglpOptOp::create(rewriter, loc, /*variant=*/2); } @@ -3075,6 +3068,10 @@ struct GridwiseAttentionAccelRewritePattern // affect occupancy (VGPR-limited, not LDS-limited). 
// ================================================================ if (prefetchFirstVTile && hoistVPhase2) { + // Allocate V LDS buffer early (before the sum reduction) so that + // Phase 2 can write the prefetched V data from registers into LDS. + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, @@ -3122,6 +3119,12 @@ struct GridwiseAttentionAccelRewritePattern // global memory latency across the entire softmax computation. // ================================================================ if (prefetchFirstVTile && !hoistVPhase2) { + // Allocate V LDS buffer HERE (late) instead of before softmax. + // This makes ldsByteBufferV's live range start after the + // reduction, preventing ReuseLDS from aliasing it with + // buffers that are still being read by slow wavefronts. + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, @@ -3392,15 +3395,23 @@ struct GridwiseAttentionAccelRewritePattern rewriter.createOrFold(loc, 1); scf::ForOp g1MLoopOp = scf::ForOp::create( rewriter, loc, startG1M, endG1MLoop, oneVal); - // Mark loop for pipelining - bool g1DoubleBuffering = - loadType == GemmLoadTileType::DoubleBuffer || - loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; - int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2; - g1MLoopOp->setAttr( - PipelineAttr::getMnemonic(), - rock::PipelineAttr::get(rewriter.getContext(), - g1InitiationInterval)); + // Mark loop for pipelining — but only when the remaining loop + // has more than 1 iteration. Pipelining a 1-iteration loop + // (gemm1MBlocks == 2 → loop from 1 to 2) provides no overlap + // benefit and the RockPipelinePass currently drops the + // inter-stage LDS barriers from the epilogue, causing a data + // race between the V LDS write (prologue) and the GEMM1 V LDS + // read (epilogue). + if (gemm1MBlocks > 2) { + bool g1DoubleBuffering = + loadType == GemmLoadTileType::DoubleBuffer || + loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + int64_t g1InitiationInterval = g1DoubleBuffering ? 
1 : 2; + g1MLoopOp->setAttr( + PipelineAttr::getMnemonic(), + rock::PipelineAttr::get(rewriter.getContext(), + g1InitiationInterval)); + } { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); From 1bb557c1575d3aaa0ec4f11d58317189ec285915 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 3 Mar 2026 18:28:25 +0000 Subject: [PATCH 09/18] Remove iglp op --- .../include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 24 ------------------- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +------------ .../Transforms/GridwiseGemmToBlockwise.cpp | 2 -- 3 files changed, 1 insertion(+), 41 deletions(-) diff --git a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 6662a0fadb4d..27009cd0961e 100644 --- a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -852,30 +852,6 @@ def AMDGPU_SchedBarrierOp : }]; } -def AMDGPU_IglpOptOp : - AMDGPU_Op<"iglp_opt">, - Arguments<(ins I32Attr:$variant)> - { - let summary = "Hint to the AMDGPU instruction scheduler to apply an IGLP strategy"; - let description = [{ - `amdgpu.iglp_opt` provides a hint to the LLVM AMDGPU backend's instruction - scheduler to apply a specific Instruction-Group-Level-Parallelism (IGLP) - scheduling strategy. The `variant` attribute selects which strategy to use: - - - 0: `MFMASmallGemmOpt` — interleaves DS_READ with MFMA for small GEMMs. - - 1: `MFMASmallGemmSingleWaveOpt` — single-wave GEMM with DS_READ, V_PERM, - DS_WRITE, and VMEM interleaving. - - 2: `MFMAExpInterleave` — interleaves MFMA with transcendental (EXP) - instructions, with complex DAG analysis. - - 3: `MFMAExpSimpleInterleave` — simple TRANS/MFMA interleaving. - - This lowers to `llvm.amdgcn.iglp_opt(i32 variant)`. 
- }]; - let assemblyFormat = [{ - $variant attr-dict - }]; -} - def AMDGPU_MemoryCounterWaitOp : AMDGPU_Op<"memory_counter_wait">, Arguments<(ins diff --git a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 47f3e2570224..6c82f020106e 100644 --- a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -613,20 +613,6 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern { } }; -struct IglpOptOpLowering : public ConvertOpToLLVMPattern { - IglpOptOpLowering(const LLVMTypeConverter &converter, Chipset chipset) - : ConvertOpToLLVMPattern(converter), chipset(chipset) {} - - Chipset chipset; - - LogicalResult - matchAndRewrite(IglpOptOp op, IglpOptOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, static_cast(op.getVariant())); - return success(); - } -}; } // namespace @@ -2234,7 +2220,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RawBufferOpLowering, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, IglpOptOpLowering, MFMAOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index de679eed1178..0f3ad207dd36 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2887,8 +2887,6 @@ struct GridwiseAttentionAccelRewritePattern // from sinking the V global loads past the softmax computation. 
amdgpu::SchedBarrierOp::create( rewriter, loc, amdgpu::sched_barrier_opt_enum::none); - - amdgpu::IglpOptOp::create(rewriter, loc, /*variant=*/2); } int64_t prePadG0M = gemm0M; From b3fecc6d665592d689a44fc7bbc6a5e2bab87e0e Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 3 Mar 2026 18:44:43 +0000 Subject: [PATCH 10/18] Minor fixes --- .../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 4 +--- mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td | 4 ---- mlir/include/mlir/Dialect/Rock/Passes.td | 2 +- .../Rock/Transforms/BlockwiseGemmToThreadwise.cpp | 2 +- .../Rock/Transforms/GridwiseGemmToBlockwise.cpp | 10 +++++----- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 6c82f020106e..e9da09b9c8f6 100644 --- a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -613,7 +613,6 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern { } }; - } // namespace /// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL @@ -2220,8 +2219,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RawBufferOpLowering, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, MFMAOpLowering, - ScaledMFMAOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 1f9e0e37a7f4..793fbd0df126 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -464,10 +464,6 @@ def Rock_GemmLoadTileDirectToLDSDefault : I32EnumAttrCase<"DirectToLDSDefault", 3>; def Rock_GemmLoadTileDirectToLDSDoubleBuffer : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>; -// Split-phase load types for V prefetch in attention kernels. -// GlobalReadOnly: Only emit the global read stage (global -> register buffer). -// LDSWriteFromRegs: Only emit the LDS write stage (register buffer -> LDS). -// Both phases share a register buffer passed via destRegisters. 
def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>; def Rock_GemmLoadTileLDSWriteFromRegs diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 18408e58b6ba..1041ab12abdf 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -107,7 +107,7 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> { def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> { let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy"; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; } def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> { diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index ffa855fa6121..0171eb371113 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -1415,7 +1415,7 @@ struct BlockwiseReduceRewritePattern } } - // Branchless reduction: each thread reads ALL rTidDim partial + // Branchless reduction: each thread reads all rTidDim partial // values from LDS and reduces locally in registers. This avoids // creating conditional branches (scf.if) that split softmax into // multiple basic blocks. Without branches, the LLVM backend diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 0f3ad207dd36..d8f7e55fb6e2 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2422,11 +2422,11 @@ struct GridwiseAttentionAccelRewritePattern rewriter, loc, accelParamsGemm1.argTypeA, accelParamsGemm1.kBasePerThread, doubleBuffering ? accelParamsGemm1.mRepeats : 1, directToLDS); - auto [preAccelRegBufferQxKForLoad, preAccelRegBufferQxK] = - createRegInterrimBufferForAccel( - rewriter, loc, accelParamsGemm1.argTypeB, - accelParamsGemm1.kBasePerThread, - doBypassLDSSecondGemm ? accelParamsGemm1.nRepeats : 1, false); + auto preAccelRegBufferQxKPair = createRegInterrimBufferForAccel( + rewriter, loc, accelParamsGemm1.argTypeB, + accelParamsGemm1.kBasePerThread, + doBypassLDSSecondGemm ? 
accelParamsGemm1.nRepeats : 1, false); + Value preAccelRegBufferQxK = preAccelRegBufferQxKPair.second; Value accRegBufferGemm1; Value gemm1OutBuffer; From 7063f8d380c33d8af9ff6660eeba86fcb6a57473 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 3 Mar 2026 18:56:23 +0000 Subject: [PATCH 11/18] Add back AMDGPU dialect --- mlir/include/mlir/Dialect/Rock/Passes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 1041ab12abdf..18408e58b6ba 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -107,7 +107,7 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> { def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> { let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy"; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; } def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> { From fedc9ebd18ee2c7b014931e8de0287f46ce25759 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 3 Mar 2026 19:12:48 +0000 Subject: [PATCH 12/18] More cleanup --- .../Transforms/BlockwiseGemmToThreadwise.cpp | 15 +- .../BlockwiseLoadTileToThreadwise.cpp | 8 +- .../Transforms/GridwiseGemmToBlockwise.cpp | 146 ++++-------------- .../Dialect/Rock/Transforms/RockPipeline.cpp | 4 +- 4 files changed, 32 insertions(+), 141 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 0171eb371113..19bc6696d2b7 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -1418,15 +1418,13 @@ struct BlockwiseReduceRewritePattern // Branchless reduction: each thread reads all rTidDim partial // values from LDS and reduces locally in registers. This avoids // creating conditional branches (scf.if) that split softmax into - // multiple basic blocks. Without branches, the LLVM backend - // scheduler can keep V global loads (issued before softmax) in - // the same basic block, enabling sched_barrier to prevent them - // from being sunk past softmax computation. - // + // multiple basic blocks. // Trade-off: every thread does rTidCount LDS reads (instead of // log2(rTidCount) conditional reads in the tree reduction). For // typical attention configs where rTidCount is small (e.g., 4), // this is negligible overhead. + // TODO: We may have to use a heuristic to determine whether or not to + // use this depending on the size of rTidCount. { int64_t rTidCount = threadViewShape[rTidDim]; @@ -1477,13 +1475,6 @@ struct BlockwiseReduceRewritePattern // Write the fully reduced value back to LDS at [nrtid, 0]. // All threads with the same nrtid compute the same value, // so concurrent writes to the same location are safe. 
- // - // NOTE: We cannot use a FillOp shortcut here (even when - // inputThreadSubTile2dShape[nrDim] == 1) because nrtid - // (= tid % nonReduceMergeDimSize) does NOT necessarily - // correspond to the thread's actual non-reduction position - // in the MFMA layout. The ThreadwiseReadIntoOp uses the - // correct layout-aware view to read each thread's result. { Value reducedVal = InBoundsLoadOp::create( rewriter, loc, elemType, accReg, zeroConstantOp); diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index c2e70e99bb90..a301faa0a4ab 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -286,8 +286,6 @@ class LoweringBlockwiseLoadTileOp final if (isa(parentOp)) b.setInsertionPoint(op); - // ---- GlobalRead stage ---- - // Emit for all types EXCEPT LDSWriteFromRegs (which only does the write). if (!ldsWriteFromRegs) { // Use distinct stage name for split-phase V prefetch to avoid // conflicting with K/Q GlobalRead stages in the same parent scope. @@ -336,10 +334,6 @@ class LoweringBlockwiseLoadTileOp final Value one = b.createOrFold(loc, 1); indicesNext[0] = arith::AddIOp::create(b, loc, indicesNext[0], one).getResult(); - - // it's acceptable if the indices are out of bounds because we use - // GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch - // documentation in AMDGPUUsage.rst rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource, /*extraViews=*/b.getArrayAttr({}), /*extraIndices=*/indicesNext, @@ -350,7 +344,7 @@ class LoweringBlockwiseLoadTileOp final } } - // For GlobalReadOnly, we're done - skip all write stages. + // For GlobalReadOnly there's nothing further to do. if (globalReadOnly) { b.eraseOp(op); return success(); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index d8f7e55fb6e2..c41b58359eaf 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2788,47 +2788,19 @@ struct GridwiseAttentionAccelRewritePattern accelEmitterPtrGemm0->computeOutputConversion( rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll); - // ================================================================ - // V PREFETCH: Issue global reads for V tile 0 before softmax. - // ================================================================ - // By issuing V global reads here (before softmax computation), - // we overlap the ~120+ instructions of softmax work with the - // global memory access latency for V, matching CK's approach. - // - // The flow is: - // 1. Issue V global reads -> register buffer [HERE, before softmax] - // 2. Softmax computation [hides load latency] - // 3. Write V from registers -> LDS [after softmax] - // 4. GEMM1 first iteration uses V from LDS [peeled iteration] - // 5. Remaining GEMM1 iters: normal load+MMA [pipelineable loop] - // - // The split is implemented using two new GemmLoadTileType values: - // - GlobalReadOnly: emits only the GlobalRead stage - // (ThreadwiseReadIntoOp: global -> register buffer, no LDS write) - // - LDSWriteFromRegs: emits only the LDSWrite stage - // (ThreadwiseCopyOp + ThreadwiseWriteAllOp: regs -> LDS, - // no global read) - // Both phases share the same flat register buffer (vPrefetchRegs). 
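As a reading aid, the stage selection implied by the two split-phase load types can
be written as a standalone sketch. The enum below is a local stand-in for
rock::GemmLoadTileType (not the generated attribute), and treating BypassLDS as
"no LDS write" is an assumption based on the lowering above.

    #include <cassert>

    enum class LoadType { Default, BypassLDS, GlobalReadOnly, LDSWriteFromRegs };

    struct Stages {
      bool globalRead;
      bool ldsWrite;
    };

    Stages stagesFor(LoadType t) {
      Stages s;
      // LDSWriteFromRegs reuses data already sitting in registers, so it skips
      // the global read; GlobalReadOnly stops after filling the registers;
      // BypassLDS keeps the tile in registers and never touches LDS.
      s.globalRead = (t != LoadType::LDSWriteFromRegs);
      s.ldsWrite = (t != LoadType::GlobalReadOnly && t != LoadType::BypassLDS);
      return s;
    }

    int main() {
      assert(stagesFor(LoadType::GlobalReadOnly).globalRead);
      assert(!stagesFor(LoadType::GlobalReadOnly).ldsWrite);
      assert(!stagesFor(LoadType::LDSWriteFromRegs).globalRead);
      assert(stagesFor(LoadType::LDSWriteFromRegs).ldsWrite);
      assert(stagesFor(LoadType::Default).globalRead &&
             stagesFor(LoadType::Default).ldsWrite);
      return 0;
    }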
+ // V p: Issue global reads for V tile 0 before softmax + // to overlap softmax computation with V's global memory latency. + // Uses GlobalReadOnly (global -> regs) and LDSWriteFromRegs + // (regs -> LDS) to split the load across the softmax boundary. Value ldsByteBufferV; Value vPrefetchRegs; layout::GridCoordinates gridCoordsGemm1; bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS; - // Decide whether to hoist Phase 2 (V regs -> LDS write) before the - // sum reduction. Hoisting saves one LDS barrier by piggybacking on - // the sum reduction's internal barrier, but it makes V's LDS live - // range overlap with the sum-reduction workspace, preventing - // ReuseLDS from aliasing them. - // - // ReuseLDS uses greedy graph coloring that packs non-interfering - // buffers (like K and V) into merged color groups. When V interferes - // with sum_ws (due to hoisting), V gets displaced within the merged - // group by sum_ws's size, growing the group by exactly sumWSBytes. - // So: hoisted_total ≈ non_hoisted_peak + sumWSBytes. - // - // The non-hoisted peak is the max concurrent LDS from GEMM0 - // (Q+K buffers) or GEMM1 (V+gemm1_B buffers). We check if adding - // sumWSBytes would exceed the hardware LDS limit. + // Decide whether to hoist V regs->LDS write before the sum reduction. + // Hoisting saves one LDS barrier but extends V's LDS live range to + // overlap with the sum-reduction workspace, which may increase peak + // LDS usage. Only hoist if the resulting peak fits in hardware LDS. bool hoistVPhase2 = false; if (prefetchFirstVTile) { int64_t maxLDS = archInfo.maxSharedMemPerWG; @@ -2840,6 +2812,7 @@ struct GridwiseAttentionAccelRewritePattern int64_t gemm1PeakBytes = getPackedByteSize(gemm1KPerBlock * gemm1MPerBlock, elemTypeV) + getPackedByteSize(gemm1LDSByteBufferBSize, elemTypeV); + // The base peak without hoisting is determined by the larger of // GEMM0 and GEMM1 concurrent buffer sets. int64_t nonHoistedPeak = std::max(gemm0PeakBytes, gemm1PeakBytes); @@ -3050,21 +3023,9 @@ struct GridwiseAttentionAccelRewritePattern gemm0MNExpThreadwiseView, gemm0MNMaxThreadwiseView, maxRowBuffer); - // ================================================================ - // V PREFETCH Phase 2 (hoisted): Write V data from regs to LDS - // before the sum reduction so that the sum reduction's internal - // LDS barrier also synchronises the V tile writes. This - // eliminates the dedicated V-tile LDS barrier that was - // previously required after the sum reduction, saving one - // s_barrier per iteration. - // - // Safety: AnnotateLiveness + ReuseLDS will see that V's live - // range (write here -> read during GEMM1) overlaps with the sum - // workspace's live range, so they will NOT be aliased. The - // max-reduction workspace is already dead, so it CAN be - // aliased with V. The LDS increase is small and does not - // affect occupancy (VGPR-limited, not LDS-limited). - // ================================================================ + // V prefetch phase 2 (hoisted): Write V data from regs to LDS + // before the sum reduction. The sum reduction's internal LDS + // barrier synchronises the V tile writes, saving one barrier. if (prefetchFirstVTile && hoistVPhase2) { // Allocate V LDS buffer early (before the sum reduction) so that // Phase 2 can write the prefetched V data from registers into LDS. 
@@ -3076,9 +3037,6 @@ struct GridwiseAttentionAccelRewritePattern vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, matrixParamsKxQ); - // No LDSBarrierOp here — the barrier inside the sum - // BlockwiseBroadcastReduceOp (below) will synchronise both - // the V LDS writes and the softmax partial-sum LDS writes. } // Softmax sum reduction @@ -3107,20 +3065,10 @@ struct GridwiseAttentionAccelRewritePattern gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, expMaxDiffRowBuffer); - // ================================================================ - // V PREFETCH Phase 2 (deferred path): Write V data from regs to - // LDS after the sum reduction. This avoids V's LDS live range - // overlapping with the sum-reduction workspace, allowing - // ReuseLDS to alias them and stay within the hardware LDS budget. - // Costs one extra s_barrier vs the hoisted path. - // Phase 1 (global reads -> regs, before softmax) still hides the - // global memory latency across the entire softmax computation. - // ================================================================ + // V prefetch phase 2 (deferred path): Write V data from regs to + // LDS after the sum reduction to avoid overlapping with the + // sum-reduction workspace in LDS. Costs one extra barrier. if (prefetchFirstVTile && !hoistVPhase2) { - // Allocate V LDS buffer HERE (late) instead of before softmax. - // This makes ldsByteBufferV's live range start after the - // reduction, preventing ReuseLDS from aliasing it with - // buffers that are still being read by slow wavefronts. ldsByteBufferV = createLDSByteBuffer( rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( @@ -3181,11 +3129,8 @@ struct GridwiseAttentionAccelRewritePattern } } - // ================================================================ - // V load + GEMM1 loop: Two paths depending on V prefetch. - // ================================================================ - // For non-prefetch path: allocate V LDS buffer and grid coords - // (prefetch path already did this before softmax). + // V load + GEMM1 loop. For the non-prefetch path, allocate the + // V LDS buffer and grid coords here (prefetch already did this). if (!prefetchFirstVTile) { ldsByteBufferV = createLDSByteBuffer( rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); @@ -3194,12 +3139,7 @@ struct GridwiseAttentionAccelRewritePattern numChiplets, splitKVConst); } - // ---------------------------------------------------------------- - // Helper lambda: Emit GEMM1 MMA + PostProcess for a single V tile. - // Parameterized by V block index (g1MBlockIdx) to support both - // the peeled first iteration and the remaining loop iterations. - // This avoids duplicating ~100 lines of MMA + PostProcess code. - // ---------------------------------------------------------------- + // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile. auto emitGemm1Compute = [&](Value g1MBlockIdx, GemmLoadTileType vLoadType, Value vRegBuf) -> LogicalResult { @@ -3333,30 +3273,11 @@ struct GridwiseAttentionAccelRewritePattern }; // end emitGemm1Compute lambda if (prefetchFirstVTile) { - // ============================================================ - // PREFETCH PATH: First V tile already loaded into LDS. 
- // ============================================================ - // V data for tile 0 was prefetched before softmax (global read) - // and written to LDS before the sum reduction (LDS write synced - // by sum reduction's internal barrier). - // The first GEMM1 iteration is peeled out of the loop so the - // remaining iterations form a clean, pipelineable loop. - - // --- Peeled first iteration (g1m = 0) --- + // Prefetch path: V tile 0 is already in LDS. Peel the first + // GEMM1 iteration and loop over the remaining tiles. gridCoordsGemm1.m_block = zero; - // Use Default load type for the peeled iteration because the V - // data was written to LDS by the LDSWriteFromRegs phase. There is - // no BlockwiseLoadTileOp here to create an LDSRead stage, so the - // GEMM must read V directly from LDS. - // - // When double-buffering is active, preAccelRegBufferV is rank-2 - // (e.g. memref<3x2xvector<4xf16>>) because it was allocated with - // repeats=mRepeats. However, the Default load path in - // BlockwiseGemmAccelOp reads from LDS into the buffer WITHOUT - // slicing by the m-repeat loop variable. The downstream - // generateThreadwiseViewBufferA then creates a rank-1 view, - // leading to a memref.load rank mismatch. Fix: create a separate - // rank-1 register buffer for the peeled iteration. + // When double-buffering, preAccelRegBufferV is rank-2; the + // Default load path expects rank-1, so allocate a separate buf. Value peeledVRegBuf = preAccelRegBufferV; if (doubleBuffering) { auto [peeledVForLoad, peeledVBuf] = @@ -3366,13 +3287,6 @@ struct GridwiseAttentionAccelRewritePattern /*repeats=*/1, directToLDS); peeledVRegBuf = peeledVBuf; } - // Barrier: ensure all threads have finished writing the softmax - // exp values to LDS (storeGemmInputTile above) before GEMM1 - // reads from them. Only needed when the softmax exp actually - // goes through LDS (!doBypassLDSSecondGemm). When LDS is - // bypassed, softmax exp stays in registers and V is already - // synced by either the sum reduction's internal barrier - // (hoisted path) or the deferred V Phase 2 barrier. if (!doBypassLDSSecondGemm) LDSBarrierOp::create(rewriter, loc); @@ -3380,8 +3294,7 @@ struct GridwiseAttentionAccelRewritePattern peeledVRegBuf))) return failure(); - // --- Remaining iterations (g1m = 1..gemm1MBlocks-1) --- - // These form a standard pipelineable loop with V loads. + // Remaining iterations (g1m = 1..gemm1MBlocks-1). if (gemm1MBlocks > 1) { LDSBarrierOp::create(rewriter, loc); @@ -3393,13 +3306,8 @@ struct GridwiseAttentionAccelRewritePattern rewriter.createOrFold(loc, 1); scf::ForOp g1MLoopOp = scf::ForOp::create( rewriter, loc, startG1M, endG1MLoop, oneVal); - // Mark loop for pipelining — but only when the remaining loop - // has more than 1 iteration. Pipelining a 1-iteration loop - // (gemm1MBlocks == 2 → loop from 1 to 2) provides no overlap - // benefit and the RockPipelinePass currently drops the - // inter-stage LDS barriers from the epilogue, causing a data - // race between the V LDS write (prologue) and the GEMM1 V LDS - // read (epilogue). + // Only pipeline when >1 iteration remains; pipelining a + // single iteration causes barrier mismatches. if (gemm1MBlocks > 2) { bool g1DoubleBuffering = loadType == GemmLoadTileType::DoubleBuffer || @@ -3437,9 +3345,7 @@ struct GridwiseAttentionAccelRewritePattern } } } else { - // ============================================================ - // ORIGINAL PATH: No V prefetch (softmax disabled). 
- // ============================================================ + // Non-prefetch path (softmax disabled). Value endG1MLoop = rewriter.createOrFold(loc, gemm1MBlocks); scf::ForOp g1MLoopOp = diff --git a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp index 281730ce140e..64b2eb1f7197 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp @@ -166,7 +166,7 @@ struct PushBarrierDownRewritePattern if (!nextOp->getNextNode()) return failure(); - // Don't push past another barrier — RemoveBackToBack handles that. + // Don't push past another barrier, RemoveBackToBack handles that. // Without this check, two adjacent barriers would swap endlessly. if (isa(nextOp)) return failure(); @@ -178,7 +178,7 @@ struct PushBarrierDownRewritePattern bool moveDown = true; // Check if the operation accesses LDS. // We can move past LDS store-only operations because independent - // writes don't need ordering between them — the next barrier will + // writes don't need ordering between them, the next barrier will // ensure all writes complete before any subsequent reads. // We must stop at LDS reads. // We recognize store ops both before SugarToLoops (InBoundsStoreOp) From c298f92e02e4fdcf6d30ba94507cab788587dda9 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 9 Mar 2026 21:15:58 +0000 Subject: [PATCH 13/18] Clang-format --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 6 +- mlir/include/mlir/Dialect/Rock/Passes.td | 5 +- .../Transforms/BlockwiseGemmToThreadwise.cpp | 17 ++-- .../BlockwiseLoadTileToThreadwise.cpp | 24 +++-- .../Transforms/GridwiseGemmToBlockwise.cpp | 99 +++++++++---------- .../Dialect/Rock/Transforms/RockPipeline.cpp | 2 +- .../Transforms/ThreadwiseGemmLowering.cpp | 3 +- 7 files changed, 74 insertions(+), 82 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 793fbd0df126..fa123a57b5fb 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -464,10 +464,8 @@ def Rock_GemmLoadTileDirectToLDSDefault : I32EnumAttrCase<"DirectToLDSDefault", 3>; def Rock_GemmLoadTileDirectToLDSDoubleBuffer : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>; -def Rock_GemmLoadTileGlobalReadOnly - : I32EnumAttrCase<"GlobalReadOnly", 5>; -def Rock_GemmLoadTileLDSWriteFromRegs - : I32EnumAttrCase<"LDSWriteFromRegs", 6>; +def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>; +def Rock_GemmLoadTileLDSWriteFromRegs : I32EnumAttrCase<"LDSWriteFromRegs", 6>; def Rock_GemmLoadTileType : Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types", diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 18408e58b6ba..1774b6793e3a 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -107,7 +107,10 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> { def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> { let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy"; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", 
+ "gpu::GPUDialect", "vector::VectorDialect", + "memref::MemRefDialect", "linalg::LinalgDialect", + "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; } def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> { diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 19bc6696d2b7..dc751ee30917 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -1418,7 +1418,7 @@ struct BlockwiseReduceRewritePattern // Branchless reduction: each thread reads all rTidDim partial // values from LDS and reduces locally in registers. This avoids // creating conditional branches (scf.if) that split softmax into - // multiple basic blocks. + // multiple basic blocks. // Trade-off: every thread does rTidCount LDS reads (instead of // log2(rTidCount) conditional reads in the tree reduction). For // typical attention configs where rTidCount is small (e.g., 4), @@ -1429,8 +1429,8 @@ struct BlockwiseReduceRewritePattern int64_t rTidCount = threadViewShape[rTidDim]; // Accumulator for the full reduction. - auto accRegType = MemRefType::get( - {1}, elemType, AffineMap{}, privateMemoryAddressSpace); + auto accRegType = MemRefType::get({1}, elemType, AffineMap{}, + privateMemoryAddressSpace); Value accReg = GpuAllocOp::create(rewriter, loc, accRegType); FillOp::create(rewriter, loc, accReg, initVal); @@ -1463,8 +1463,8 @@ struct BlockwiseReduceRewritePattern InBoundsStoreOp::create(rewriter, loc, ldVal, accReg, zeroConstantOp); } else { - Value accVal = InBoundsLoadOp::create( - rewriter, loc, elemType, accReg, zeroConstantOp); + Value accVal = InBoundsLoadOp::create(rewriter, loc, elemType, + accReg, zeroConstantOp); Value reduced = createReducingOp(op, ldVal, accVal, rewriter); InBoundsStoreOp::create(rewriter, loc, reduced, accReg, zeroConstantOp); @@ -1476,8 +1476,8 @@ struct BlockwiseReduceRewritePattern // All threads with the same nrtid compute the same value, // so concurrent writes to the same location are safe. 
{ - Value reducedVal = InBoundsLoadOp::create( - rewriter, loc, elemType, accReg, zeroConstantOp); + Value reducedVal = InBoundsLoadOp::create(rewriter, loc, elemType, + accReg, zeroConstantOp); SmallVector writeInits{nrtid, zeroConstantOp, zeroConstantOp}; SmallVector writeBounds{1, 1, 1}; @@ -1486,8 +1486,7 @@ struct BlockwiseReduceRewritePattern TransformingForOp writeLoop = TransformingForOp::create( rewriter, loc, ArrayRef{writeInits}, ArrayRef{threadToLDSViewTrs}, - ArrayRef(writeBounds), - ArrayRef(writeStrides), + ArrayRef(writeBounds), ArrayRef(writeStrides), /*forceUnroll=*/true, /*useIndexDiffs=*/true); { PatternRewriter::InsertionGuard guard(rewriter); diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index a301faa0a4ab..f4587e78c905 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -244,10 +244,8 @@ class LoweringBlockwiseLoadTileOp final else b.setInsertionPoint(op); - bool globalReadOnly = - loadType == GemmLoadTileType::GlobalReadOnly; - bool ldsWriteFromRegs = - loadType == GemmLoadTileType::LDSWriteFromRegs; + bool globalReadOnly = loadType == GemmLoadTileType::GlobalReadOnly; + bool ldsWriteFromRegs = loadType == GemmLoadTileType::LDSWriteFromRegs; Value loadBuffer, storeBuffer; if (globalReadOnly || ldsWriteFromRegs) { @@ -258,8 +256,8 @@ class LoweringBlockwiseLoadTileOp final "destRegisters must be set for split-phase load types"); loadBuffer = destRegisters; if (ldsWriteFromRegs) { - storeBuffer = gpuAlloc(b, loc, copyPerThread, elementType, - AddressSpace::Private); + storeBuffer = + gpuAlloc(b, loc, copyPerThread, elementType, AddressSpace::Private); } } else if (loadType == GemmLoadTileType::BypassLDS) { auto privateMemoryAddressSpace = b.getAttr( @@ -320,13 +318,13 @@ class LoweringBlockwiseLoadTileOp final Value wrappedSource = transform(b, source, maybeBufferViews->gridSubTile); - ThreadwiseReadIntoOp::create( - b, loc, vectorOfBoolShapedLike(loadBuffer), wrappedSource, - loadBuffer, - /*dynamicValidities=*/ValueRange{}, - /*extraViews=*/b.getArrayAttr({}), - /*extraIndices=*/indices, forceUnroll, true, - /*ldsTransposeConfig=*/nullptr); + ThreadwiseReadIntoOp::create(b, loc, vectorOfBoolShapedLike(loadBuffer), + wrappedSource, loadBuffer, + /*dynamicValidities=*/ValueRange{}, + /*extraViews=*/b.getArrayAttr({}), + /*extraIndices=*/indices, forceUnroll, + true, + /*ldsTransposeConfig=*/nullptr); if (!globalReadOnly && rock::isGlobalPrefetchSupported(arch)) { // add one to k_loop to prefetch next iteration diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index c41b58359eaf..15e6410530f3 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -33,8 +33,8 @@ #include "mlir/Dialect/Rock/utility/math.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -2823,23 +2823,21 @@ struct GridwiseAttentionAccelRewritePattern << "V prefetch Phase 2 hoist decision: " << (hoistVPhase2 ? 
"HOIST" : "DEFER") << " (hoistedTotal=" << hoistedTotal << ", max=" << maxLDS - << ", sumWS=" << sumWSBytes - << ", gemm0=" << gemm0PeakBytes + << ", sumWS=" << sumWSBytes << ", gemm0=" << gemm0PeakBytes << ", gemm1=" << gemm1PeakBytes << ")\n"); } if (prefetchFirstVTile) { // Set up grid coordinates for the first V tile. gridCoordsGemm1 = layout::makeGxNGridLayout( - rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, - numChiplets, splitKVConst); + rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, + splitKVConst); gridCoordsGemm1.m_block = zero; // First V tile (block index 0) // Allocate a flat register buffer shared between the GlobalReadOnly // and LDSWriteFromRegs phases. Size must match what the lowering // computes: copyPerThread = (kPerBlock * dPerBlock) / blockSize. - int64_t vCopyPerThread = - (gemm1KPerBlock * gemm1MPerBlock) / blockSize; + int64_t vCopyPerThread = (gemm1KPerBlock * gemm1MPerBlock) / blockSize; vPrefetchRegs = gpuAlloc(rewriter, loc, vCopyPerThread, elemTypeV, gpu::AddressSpace::Private); @@ -2851,15 +2849,15 @@ struct GridwiseAttentionAccelRewritePattern rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS, - vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize, - elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, - matrixParamsV, matrixParamsKxQ); + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS, vPrefetchRegs, + GemmLoadTileType::GlobalReadOnly, "m", blockSize, elemTypeV, + elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, + matrixParamsKxQ); // Insert a scheduling barrier to prevent the LLVM backend scheduler // from sinking the V global loads past the softmax computation. - amdgpu::SchedBarrierOp::create( - rewriter, loc, amdgpu::sched_barrier_opt_enum::none); + amdgpu::SchedBarrierOp::create(rewriter, loc, + amdgpu::sched_barrier_opt_enum::none); } int64_t prePadG0M = gemm0M; @@ -3034,9 +3032,9 @@ struct GridwiseAttentionAccelRewritePattern loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", - blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, - featuresAttr, matrixParamsV, matrixParamsKxQ); + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); } // Softmax sum reduction @@ -3074,9 +3072,9 @@ struct GridwiseAttentionAccelRewritePattern loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", - blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, - featuresAttr, matrixParamsV, matrixParamsKxQ); + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); LDSBarrierOp::create(rewriter, loc); } } @@ -3140,9 +3138,9 @@ struct GridwiseAttentionAccelRewritePattern } // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile. - auto emitGemm1Compute = - [&](Value g1MBlockIdx, GemmLoadTileType vLoadType, - Value vRegBuf) -> LogicalResult { + auto emitGemm1Compute = [&](Value g1MBlockIdx, + GemmLoadTileType vLoadType, + Value vRegBuf) -> LogicalResult { // Emit GEMM 1 MMA. 
auto computeStage = StageOp::create(rewriter, loc, "MMA"); { @@ -3155,8 +3153,8 @@ struct GridwiseAttentionAccelRewritePattern zeroAccBuffer(rewriter, loc, matrixC); } else { if (gemm1MBlocks > 1) { - matrixC = createSliceOfFirstDim(rewriter, loc, matrixC, - g1MBlockIdx); + matrixC = + createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx); } } @@ -3206,20 +3204,19 @@ struct GridwiseAttentionAccelRewritePattern auto loadTypeKxD = doBypassLDSSecondGemm ? GemmLoadTileType::BypassLDS : GemmLoadTileType::Default; - blockwiseGemmAccel( - rewriter, loc, vLoadType, loadTypeKxD, vRegBuf, - preAccelRegBufferQxK, matrixC, matrixParamsV, matrixParamsKxQ, - ldsTileBufferV, gemm1LDSBufferB, - /*scaleA=*/nullptr, /*scaleB=*/nullptr, - /*bufferScaleA=*/nullptr, /*bufferScaleB=*/nullptr, - featuresAttr, op.getBlockSizeAttr(), gemm1TuningParams); + blockwiseGemmAccel(rewriter, loc, vLoadType, loadTypeKxD, vRegBuf, + preAccelRegBufferQxK, matrixC, matrixParamsV, + matrixParamsKxQ, ldsTileBufferV, gemm1LDSBufferB, + /*scaleA=*/nullptr, /*scaleB=*/nullptr, + /*bufferScaleA=*/nullptr, + /*bufferScaleB=*/nullptr, featuresAttr, + op.getBlockSizeAttr(), gemm1TuningParams); rock::YieldOp::create(rewriter, loc); } // Emit GEMM 1 PostProcess. - auto postProcessStage = - StageOp::create(rewriter, loc, "PostProcess"); + auto postProcessStage = StageOp::create(rewriter, loc, "PostProcess"); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart( @@ -3232,8 +3229,8 @@ struct GridwiseAttentionAccelRewritePattern if (!op.getEnableSoftmax() && gemm1MBlocks > 1) { gemm1OutBufferPerG1MBlock = createSliceOfFirstDim( rewriter, loc, gemm1OutBuffer, g1MBlockIdx); - matrixC = createSliceOfFirstDim(rewriter, loc, matrixC, - g1MBlockIdx); + matrixC = + createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx); } accelEmitterPtrGemm1->computeOutputConversion( @@ -3280,11 +3277,10 @@ struct GridwiseAttentionAccelRewritePattern // Default load path expects rank-1, so allocate a separate buf. Value peeledVRegBuf = preAccelRegBufferV; if (doubleBuffering) { - auto [peeledVForLoad, peeledVBuf] = - createRegInterrimBufferForAccel( - rewriter, loc, accelParamsGemm1.argTypeA, - accelParamsGemm1.kBasePerThread, - /*repeats=*/1, directToLDS); + auto [peeledVForLoad, peeledVBuf] = createRegInterrimBufferForAccel( + rewriter, loc, accelParamsGemm1.argTypeA, + accelParamsGemm1.kBasePerThread, + /*repeats=*/1, directToLDS); peeledVRegBuf = peeledVBuf; } if (!doBypassLDSSecondGemm) @@ -3298,14 +3294,13 @@ struct GridwiseAttentionAccelRewritePattern if (gemm1MBlocks > 1) { LDSBarrierOp::create(rewriter, loc); - Value startG1M = - rewriter.createOrFold(loc, 1); + Value startG1M = rewriter.createOrFold(loc, 1); Value endG1MLoop = rewriter.createOrFold(loc, gemm1MBlocks); Value oneVal = rewriter.createOrFold(loc, 1); - scf::ForOp g1MLoopOp = scf::ForOp::create( - rewriter, loc, startG1M, endG1MLoop, oneVal); + scf::ForOp g1MLoopOp = + scf::ForOp::create(rewriter, loc, startG1M, endG1MLoop, oneVal); // Only pipeline when >1 iteration remains; pipelining a // single iteration causes barrier mismatches. if (gemm1MBlocks > 2) { @@ -3313,10 +3308,9 @@ struct GridwiseAttentionAccelRewritePattern loadType == GemmLoadTileType::DoubleBuffer || loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; int64_t g1InitiationInterval = g1DoubleBuffering ? 
1 : 2; - g1MLoopOp->setAttr( - PipelineAttr::getMnemonic(), - rock::PipelineAttr::get(rewriter.getContext(), - g1InitiationInterval)); + g1MLoopOp->setAttr(PipelineAttr::getMnemonic(), + rock::PipelineAttr::get(rewriter.getContext(), + g1InitiationInterval)); } { OpBuilder::InsertionGuard guard(rewriter); @@ -3916,11 +3910,10 @@ void RockGridwiseGemmToBlockwisePass::runOnOperation() { ConversionTarget target(*ctx); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect< + arith::ArithDialect, rock::RockDialect, memref::MemRefDialect, + affine::AffineDialect, vector::VectorDialect, linalg::LinalgDialect, + scf::SCFDialect, math::MathDialect, amdgpu::AMDGPUDialect>(); target.addLegalOp(); RewritePatternSet patterns(ctx); diff --git a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp index 64b2eb1f7197..254eea17f42e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp @@ -22,12 +22,12 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Rock/IR/Rock.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Rock/Passes.h" #include "mlir/Dialect/Rock/Transforms/RockMultibuffer.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp index 025bbef05334..702c8ca5f72c 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp @@ -823,7 +823,8 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( // may have fewer dimensions (dstRank). The last dstRank elements of the // domain-1 coords correspond to the dest buffer dimensions. 
int64_t dstRank = dstBufferType.getRank(); - Block::BlockArgListType allDestCoords = loadLoop.getLowerCoords(/*domain=*/1); + Block::BlockArgListType allDestCoords = + loadLoop.getLowerCoords(/*domain=*/1); size_t dropCount = allDestCoords.size() - dstRank; SmallVector destCoords(allDestCoords.begin() + dropCount, allDestCoords.end()); From da064559a638e2755ca896651832e0afe8640e9a Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 10 Mar 2026 21:45:00 +0000 Subject: [PATCH 14/18] Fix some LIT tests --- .../gridwise_attention_accel_lowering.mlir | 97 +++++++++---------- ...gridwise_attention_accel_lowering_gqa.mlir | 6 ++ .../lowering_blockwise_broadcast_reduce.mlir | 92 +++++++++--------- .../test/Dialect/Rock/test_rock_pipeline.mlir | 8 +- .../toblockwise_attention_accel_lowering.mlir | 46 ++++----- 5 files changed, 123 insertions(+), 126 deletions(-) diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index 119dff173b7a..da89afdd99c0 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -106,6 +106,9 @@ // CHECK-DAG: %[[gemm0ValSubMaxExp:.+]] = math.exp2 %[[gemm0ValSubMax]] // CHECK-DAG: rock.in_bounds_store %[[gemm0ValSubMaxExp]] -> %[[gemm0NormExp:.+]][ + // V prefetch: write V data to LDS + // CHECK: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: %[[ldsReductionWS2:.+]] = rock.alloc() : memref<256xi8, #gpu.address_space> // CHECK: %[[ldsReductionWS2View:.+]] = memref.view %[[ldsReductionWS2]][{{.*}}][] : memref<256xi8, #gpu.address_space> to memref<64xf32, #gpu.address_space> // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} %[[gemm0NormExp]] into %[[gemm0NormExpSum:[0-9]+]] using %[[ldsReductionWS2View]] @@ -163,63 +166,51 @@ // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) - // CHECK-DAG: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // Gemm1 (unrolled) + // Iteration 0: V data already in LDS from prefetch + // CHECK: rock.stage + // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) + // CHECK: memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> + // CHECK: memref.view %[[ldsG1AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> + // CHECK: rock.blockwise_gemm_accel + // CHECK: {name = "MMA"} - // Repack G1B tile regs for better LDS store vectorization - // CHECK-DAG: %[[G1Bregs:.+]] = rock.alloc() : memref<16xf32, #gpu.address_space> - // CHECK-DAG: %[[G1BregsKpack:.+]] = rock.alloc() : memref<16xf32, #gpu.address_space> + // CHECK: rock.stage + // CHECK: rock.transforming_for + // CHECK: memref.load %[[gemm1AccBuf]][ + // CHECK: rock.in_bounds_store - // Gemm1 - // CHECK: scf.for %[[g1MIter:.+]] + // CHECK: memref.subview %[[attnOutBuf]] + // Reduction corrections + // CHECK: rock.transforming_for + // CHECK: arith.mulf + // CHECK: arith.addf + // CHECK: {name = "PostProcess"} + + // Iteration 1: Load next V tile // CHECK: rock.stage - // Load G1B tile from global to regs - // CHECK-DAG: %[[VTr0:.+]] = rock.transform %[[V]] by - // CHECK-DAG: %[[VTr1:.+]] = rock.transform %[[VTr0]] by - // CHECK-DAG: rock.threadwise_read_into {{.*}}(%[[VTr1]]) {{.*}} -> %[[G1Bregs]] : + // CHECK: rock.threadwise_read_into // CHECK: {name = "GlobalRead"} - // 
CHECK: rock.stage - // CHECK-DAG: %[[G1BregsTr0:.+]] = rock.transform %[[G1Bregs]] by - // CHECK-DAG: %[[G1BregsTr1:.+]] = rock.transform %[[G1BregsTr0]] by - // CHECK-DAG: %[[G1BregsKpackTr0:.+]] = rock.transform %[[G1BregsKpack]] by - // CHECK-DAG: %[[G1BregsKpackTr1:.+]] = rock.transform %[[G1BregsKpackTr0:.+]] by - // CHECK-DAG: rock.threadwise_copy %[[G1BregsTr1]] -> %[[G1BregsKpackTr1]] - // Store to LDS G1B tile buffer - // CHECK-DAG: %[[viewG1BStore:.+]] = memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK-DAG: %[[viewG1BStoreTr0:.+]] = rock.transform %[[viewG1BStore]] - // CHECK-DAG: %[[viewG1BStoreTr1:.+]] = rock.transform %[[viewG1BStoreTr0]] - // CHECK-DAG: %[[viewG1BStoreTr2:.+]] = rock.transform %[[viewG1BStoreTr1]] - // CHECK-DAG: %[[viewG1BStoreTr3:.+]] = rock.transform %[[viewG1BStoreTr2]] - // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1BregsKpack]] -> [](%[[viewG1BStoreTr3]]) + // CHECK: rock.stage + // CHECK: memref.view %[[ldsG0BStore]] + // CHECK: rock.threadwise_write_all // CHECK: {name = "LDSWrite"} - // Emit blockwise gemm1 - - // CHECK: rock.stage - // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) - // CHECK-DAG: %[[view2G1AStore:.+]] = memref.view %[[ldsG1AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK-DAG: %[[view2G1BStore:.+]] = memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += %[[preAccelRegB:.+]] from %[[view2G1BStore]] * %[[preAccelRegA:.+]] from %[[view2G1AStore]] + // CHECK: rock.stage + // CHECK: rock.fill + // CHECK: memref.view %[[ldsG0BStore]] + // CHECK: memref.view %[[ldsG1AStore]] + // CHECK: rock.blockwise_gemm_accel // CHECK: {name = "MMA"} - // CHECK: rock.stage + // CHECK: rock.stage // CHECK: rock.transforming_for - // CHECK: %[[tmp1:.+]] = memref.load %[[gemm1AccBuf]][ - // CHECK: rock.in_bounds_store %[[tmp1]] -> %[[gemm1AccBufScalar:.+]][ - - // CHECK: %[[sliceAttnOutBuf:.+]] = memref.subview %[[attnOutBuf]] - // Reduction corrections + // CHECK: memref.subview %[[attnOutBuf]] // CHECK: rock.transforming_for - // CHECK-DAG: %[[maxdiffexp:.+]] = rock.in_bounds_load %[[maxdiffexpbuf]] - // CHECK-DAG: %[[attnOutVal:.+]] = rock.in_bounds_load %[[sliceAttnOutBuf]] - // CHECK-DAG: %[[gemm1Val:.+]] = rock.in_bounds_load %[[gemm1AccBufScalar]] - - // CHECK-DAG: %[[attnOutBufMul:.+]] = arith.mulf %[[attnOutVal]], %[[maxdiffexp]] - // CHECK-DAG: %[[newattnOutVal:.+]] = arith.addf %[[attnOutBufMul]], %[[gemm1Val]] - // CHECK-DAG: rock.in_bounds_store %[[newattnOutVal]] -> %[[sliceAttnOutBuf]] - // CHECK : } + // CHECK: arith.mulf + // CHECK: arith.addf // CHECK: {name = "PostProcess"} - // CHECK : } // CHECK : } // CHECK : %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] // CHECK : rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}}(%[[O]]) @@ -255,7 +246,7 @@ func.func @gridwise_attn_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64 // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = 
arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -301,7 +292,7 @@ func.func @gridwise_attn_causal_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memr // CHECK-NEXT: %[[minQPlusOne:.+]] = arith.addi %[[minQEffective]], %[[c1]] : index // CHECK-NEXT: %[[firstCausalMaskIter:.+]] = arith.divui %[[minQPlusOne]], %[[c32]] : index // CHECK: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -390,7 +381,7 @@ func.func @gridwise_attn_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref< // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -437,7 +428,7 @@ func.func @gridwise_attn_softmaxtype(%arg0: memref<1x384x64xf16>, %arg1: memref< // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.blockwise_broadcast_reduce max {{.*}} memref<16xf32, #gpu.address_space> using memref<64xf32, #gpu.address_space> into memref<16xf32, #gpu.address_space> @@ -477,7 +468,7 @@ func.func @gridwise_attn_softmaxtype_with_scaling(%arg0: memref<1x384x64xf16>, % // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store {{.*}} -> {{.*}} : vector<16xf16> -> memref<16xf16, #gpu.address_space>, index // CHECK: linalg.generic @@ -547,7 +538,7 @@ func.func @gridwise_attn_splitkv_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[endIter]], %[[c1]] : index // CHECK-NEXT: %[[someWorkToDo:.+]] = arith.cmpi ugt, %[[endIter]], %[[startIter]] : index // CHECK-NEXT: scf.if %[[someWorkToDo]] - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[startIter]] to 
%[[endIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[startIter]] to %[[endIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -593,6 +584,8 @@ func.func @multiple_linalg_generics_in_presoftmax_ops(%arg0: memref<59136xf16>, // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store // CHECK-SAME: %[[GEMM0_BUFFER_FLAT]] + // V prefetch: issue global reads before softmax + // CHECK: {name = "VGlobalRead"} // CHECK: %[[ARG2_BUFFER:.*]] = rock.alloc() : memref<16xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into // CHECK-SAME: (%arg2) @@ -665,6 +658,8 @@ func.func @multiple_linalg_generics_in_presoftmax_ops_with_transforms_inbetween( // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store // CHECK-SAME: %[[GEMM0_BUFFER_FLAT]] + // V prefetch: issue global reads before softmax + // CHECK: {name = "VGlobalRead"} // CHECK: %[[ARG2_BUFFER:.*]] = rock.alloc() : memref<16xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into // CHECK-SAME: (%arg2) diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir index dbc1b50fe2fd..7cd09b8b1934 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir @@ -31,6 +31,12 @@ func.func @gridwise_attn_causal_scale_gqa(%arg0: memref<8192xf16>, %arg1: memref // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store %{{.*}} -> %[[gemmOut:.+]][{{.*}}] + // V prefetch: issue global reads before softmax + // CHECK: rock.alloc() : memref<32xf16, #gpu.address_space> + // CHECK: rock.blockwise_load_tile + // CHECK-SAME: GlobalReadOnly + // CHECK: amdgpu.sched_barrier + // fusion // CHECK: %[[loadInto:.+]] = rock.alloc() : memref<32xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into diff --git a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir index 4f678f3118d8..5d265199bb68 100644 --- a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir +++ b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir @@ -56,9 +56,8 @@ // CHECK: rock.transforming_for {{.*}} (%[[LD_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]), {{.*}}, (%[[LDS_ST_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP13]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]) {{.*}} bounds [1, 1, 20] strides [1, 1, 4] { // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load {{.*}}[%[[LD_COORD]]] // CHECK: %[[TO_REDUCE_ACC:.*]] = rock.in_bounds_load %[[TO_REDUCE_ACC_MEMREF]][%[[ZERO]]] - // CHECK: %[[MAX_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]] : vector<4xf32> into f32 - // CHECK: %[[ACC_NEW:.*]] = arith.maxnumf %[[TO_REDUCE_ACC]], %[[MAX_REDUCE]] - // CHECK: rock.in_bounds_store %[[ACC_NEW]] -> %arg2[%[[LDS_ST_COORD]]] + // CHECK: %[[MAX_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]], %[[TO_REDUCE_ACC]] : vector<4xf32> into f32 + // CHECK: rock.in_bounds_store %[[MAX_REDUCE]] -> %[[TO_REDUCE_ACC_MEMREF]][%[[ZERO]]] // CHECK: rock.lds_barrier // CHECK: rock.threadwise_read_into {{.*}}(%arg2) {{.*}} -> %arg1 @@ -83,12 +82,14 @@ func.func 
@rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: %[[RTID:.*]] = arith.divsi %[[TID]], %c3 // CHECK: %[[NRTID:.*]] = arith.remsi %[[TID]], %c3 - // Threadwise partial reduction into LDS uses rDimPerRThread=5 + // Threadwise partial reduction uses rDimPerRThread=5 // CHECK: rock.transforming_for // CHECK-SAME: bounds [1, 1, 5] - // CHECK: %[[PLUS_ONE:.*]] = arith.addi %[[RTID]], %c1 - // CHECK: %[[BCHECK:.*]] = arith.cmpi slt, %[[PLUS_ONE]], %c2 - // CHECK: scf.if %[[BCHECK]] + // Write reduced value back to LDS + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: rock.lds_barrier + // Branchless reduction across rThreads reads from LDS // CHECK: rock.lds_barrier // CHECK: rock.threadwise_read_into rock.blockwise_broadcast_reduce sum [#inputView][#inputView_tid][#inputView_iter]%input_reg into %output_reg using %ws_lds {axis = 1 : index, blockSize = 10 : i32, nrDimPerThread = 3 : index} : memref<3xf32, #gpu.address_space> using memref<30xf32, #gpu.address_space> into memref<3xf32, #gpu.address_space> @@ -123,7 +124,7 @@ func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: func @rock_blockwise_reducesum_nr_threads_lt_blocksize -// CHECK-DAG: %[[ZEROFP:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[ZEROFP:.*]] = arith.constant -0.000000e+00 : f32 // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TID0:.*]] = rock.workitem_id : index @@ -135,50 +136,47 @@ func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: rock.transforming_for {{.*}} (%[[LDS_LD_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) {{.*}} bounds [1, 1, 4] strides [1, 1, 4] { // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD]]] // CHECK: %[[TO_REDUCE_ACC:.*]] = rock.in_bounds_load {{.*}}[%c0] - // CHECK: %[[SUM_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]] : vector<4xf32> into f32 - // CHECK: %[[ACC_NEW:.*]] = arith.addf %[[TO_REDUCE_ACC]], %[[SUM_REDUCE]] - // CHECK: rock.in_bounds_store %[[ACC_NEW]] -> {{.*}}[%c0] {{.*}} #gpu.address_space> + // CHECK: %[[SUM_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]], %[[TO_REDUCE_ACC]] : vector<4xf32> into f32 + // CHECK: rock.in_bounds_store %[[SUM_REDUCE]] -> {{.*}}[%c0] {{.*}} #gpu.address_space> +// Write partial result to LDS // CHECK: rock.transforming_for {{.*}}[#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) {{.*}} bounds [1, 1, 1] strides [1, 1, 1] { // CHECK: rock.in_bounds_load {{.*}} : memref<1xf32, #gpu.address_space>, index -> f32 // CHECK: rock.in_bounds_store {{.*}} : f32 -> memref<80xf32, #gpu.address_space>, index // CHECK: rock.lds_barrier -// Partial threadwise reductions done now... 
- -// CHECK: %[[PLUS_FOUR_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c4 -// CHECK: %[[PLUS_FOUR_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_FOUR_OFFSET]], %c5 -// CHECK: scf.if %[[PLUS_FOUR_BCHECK]] { - // CHECK: rock.transforming_for - // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) - // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_FOUR_OFFSET]], %c0) - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]] - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]] - // CHECK: arith.addf - // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]] -// CHECK: rock.lds_barrier - -// CHECK: %[[PLUS_TWO_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c2 -// CHECK: %[[PLUS_TWO_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_TWO_OFFSET]], %c4 -// CHECK: scf.if %[[PLUS_TWO_BCHECK]] { - // CHECK: rock.transforming_for - // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) - // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_TWO_OFFSET]], %c0) - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]] - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]] - // CHECK: arith.addf - // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]] -// CHECK: rock.lds_barrier - -// CHECK: %[[PLUS_ONE_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c1 -// CHECK: %[[PLUS_ONE_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_ONE_OFFSET]], %c2 -// CHECK: scf.if %[[PLUS_ONE_BCHECK]] { - // CHECK: rock.transforming_for - // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) - // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_ONE_OFFSET]], %c0) - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]] - // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]] - // CHECK: arith.addf - // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]] +// Branchless reduction: alloc private accumulator, each thread reads all +// rTid partial values from LDS and reduces locally (no tree reduction). 
+// CHECK: %[[BRLESS_ACC:.+]] = rock.alloc() : memref<1xf32, #gpu.address_space> +// Iteration 0: load rThread=0 partial from LDS into accumulator (no add) +// CHECK: rock.transforming_for {{.*}} [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], {{.*}}) {{.*}} bounds [1, 1, 1] + // CHECK: %[[LDS_INIT:.+]] = rock.in_bounds_load %arg2[{{.*}}] : memref<80xf32, #gpu.address_space>, index -> f32 + // CHECK: rock.in_bounds_store %[[LDS_INIT]] -> %[[BRLESS_ACC]][{{.*}}] : f32 -> memref<1xf32, #gpu.address_space>, index +// Iteration 1: load rThread=1 partial from LDS and accumulate +// CHECK: rock.transforming_for {{.*}} [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], {{.*}}) {{.*}} bounds [1, 1, 1] + // CHECK: %[[LDS_V1:.+]] = rock.in_bounds_load %arg2[{{.*}}] : memref<80xf32, #gpu.address_space>, index -> f32 + // CHECK: %[[ACC_V1:.+]] = rock.in_bounds_load %[[BRLESS_ACC]][{{.*}}] : memref<1xf32, #gpu.address_space>, index -> f32 + // CHECK: %[[SUM_1:.+]] = arith.addf %[[ACC_V1]], %[[LDS_V1]] : f32 + // CHECK: rock.in_bounds_store %[[SUM_1]] -> %[[BRLESS_ACC]][{{.*}}] : f32 -> memref<1xf32, #gpu.address_space>, index +// Iterations 2-4: same pattern (load from LDS, load acc, addf, store acc) +// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1] + // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space> + // CHECK: rock.in_bounds_load %[[BRLESS_ACC]] + // CHECK: arith.addf + // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]] +// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1] + // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space> + // CHECK: rock.in_bounds_load %[[BRLESS_ACC]] + // CHECK: arith.addf + // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]] +// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1] + // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space> + // CHECK: rock.in_bounds_load %[[BRLESS_ACC]] + // CHECK: arith.addf + // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]] +// Write fully reduced value back to LDS +// CHECK: %[[FINAL_RED:.+]] = rock.in_bounds_load %[[BRLESS_ACC]][{{.*}}] : memref<1xf32, #gpu.address_space>, index -> f32 +// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1] + // CHECK: rock.in_bounds_store %[[FINAL_RED]] -> %arg2[{{.*}}] : f32 -> memref<80xf32, #gpu.address_space>, index // CHECK: rock.lds_barrier // All reductions are done and stored for each point in joint non-reduction space. 
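The updated CHECK lines above describe the branchless cross-thread reduction: instead of a log2(rTidCount) tree of guarded LDS exchanges, every thread walks all rTidCount partial values for its row and accumulates them in a private register. A minimal host-side sketch of that access pattern (plain C++ standing in for the generated kernel; ldsPartials, rTidCount, and nrtid are illustrative names, not identifiers from this patch series):

    #include <cstdio>
    #include <vector>

    // Branchless reduction: every thread reads all rTidCount partials for its
    // row from the (simulated) LDS workspace and accumulates them locally, so
    // no bound-checked, scf.if-style guard is needed.
    float branchlessReduceSum(const std::vector<float> &ldsPartials, int nrtid,
                              int rTidCount) {
      float acc = 0.0f; // init value of a sum reduction
      for (int r = 0; r < rTidCount; ++r)
        acc += ldsPartials[nrtid * rTidCount + r];
      return acc; // identical result for every thread sharing this nrtid
    }

    int main() {
      // Two non-reduction rows, five reduction threads per row, matching the
      // five per-partial transforming_for blocks in the updated test above.
      std::vector<float> lds = {1, 2, 3, 4, 5, 10, 20, 30, 40, 50};
      std::printf("%f %f\n", branchlessReduceSum(lds, 0, 5),
                  branchlessReduceSum(lds, 1, 5)); // prints 15 and 150
      return 0;
    }

The cost is rTidCount LDS reads per thread instead of log2(rTidCount) guarded ones, which is what the repeated load/addf/store CHECK pattern in the test reflects.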
diff --git a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir index 582bdc24dcf2..b118115d5a95 100644 --- a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir +++ b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir @@ -242,15 +242,11 @@ func.func @rock_pipeline_no_stages_ii_1(%input : memref<16xi8, #gpu.address_spac // CHECK: %[[rawRegA:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space> // CHECK: %[[rawRegB:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space> - // CHECK: %[[lds0View:.*]] = memref.view {{.*}} - // CHECK: %[[rawRegAView:.*]] = memref.view {{.*}} - // CHECK: %[[rawRegBView:.*]] = memref.view {{.*}} - // CHECK: scf.for // CHECK-SAME: %[[c0]] to %[[c16]] // CHECK-NOT: name = "__fwd_barrier__" - // CHECK: rock.extract_multibuffer(%[[lds0View]]) - // CHECK: rock.extract_multibuffer(%[[lds0View]]) + // CHECK: rock.extract_multibuffer(%[[rawRegA]]) + // CHECK: rock.extract_multibuffer(%[[lds0]]) scf.for %arg3 = %c0 to %c16 step %c1 { %a = memref.load %input[%arg3] : memref<16xi8, #gpu.address_space> memref.store %a, %regA[%arg3] : memref<16xi8, #gpu.address_space> diff --git a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir index cc5b9879ccee..8f1b39b14a42 100644 --- a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir @@ -60,6 +60,9 @@ // CHECK-DAG: %[[gemm0ValSubMaxExp:.+]] = math.exp2 %[[gemm0ValSubMax]] // CHECK-DAG: rock.in_bounds_store %[[gemm0ValSubMaxExp]] -> %[[gemm0NormExp:.+]][ + // V prefetch: write V data to LDS + // CHECK: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: %[[ldsReductionWS2:.+]] = rock.alloc() : memref<256xi8, #gpu.address_space> // CHECK: %[[ldsReductionWS2View:.+]] = memref.view %[[ldsReductionWS2]][{{.*}}][] : memref<256xi8, #gpu.address_space> to memref<64xf32, #gpu.address_space> // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} %[[gemm0NormExp]] into %[[gemm0NormExpSum:[0-9]+]] using %[[ldsReductionWS2View]] @@ -90,8 +93,9 @@ // CHECK-DAG: %[[gemm0NormExpTr3:.+]] = rock.transform %[[gemm0NormExpTr2]] // CHECK-DAG: %[[gemm0NormExpTr4:.+]] = rock.transform %[[gemm0NormExpTr3]] // CHECK-DAG: %[[gemm0NormExpTr5:.+]] = rock.transform %[[gemm0NormExpTr4]] - - // CHECK-DAG: %[[ldsG1BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + + // G1 A buffer alloc + // CHECK: %[[ldsG1BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> // Viewing another set of register with kPack packing // CHECK: %[[G1AregsKpackTr0:.+]] = rock.transform %[[G1AregsKpack:.+]] by @@ -116,39 +120,37 @@ // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) - - // CHECK-DAG: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> - - // Gemm1 - // CHECK: scf.for %[[g1MIter:.+]] - // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]] -> %[[preAccelRegV:[0-9]+]] {{.*}}#rock - // Emit blockwise gemm1 - // rock.stage + // Gemm1 (unrolled) + // Iteration 0: V data already in LDS from prefetch // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) // CHECK: %[[view2G1AStore:.+]] = memref.view %[[ldsG0AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> // CHECK: %[[view2G1BStore:.+]] = memref.view %[[ldsG1BStore]][{{.*}}][] : memref<4096xi8, 
#gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += %[[preAccelRegV]] from %[[view2G1AStore]] * %[[preAccelRegA:[0-9]+]] from %[[view2G1BStore]] + // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += {{.*}} from %[[view2G1AStore]] * {{.*}} from %[[view2G1BStore]] // CHECK: {name = "MMA"} - // rock.stage // CHECK: rock.transforming_for // CHECK: %[[tmp1:.+]] = memref.load %[[gemm1AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp1]] -> %[[gemm1AccBufScalar:.+]][ - // CHECK: %[[sliceAttnOutBuf:.+]] = memref.subview %[[attnOutBuf]] + // CHECK: memref.subview %[[attnOutBuf]] // Reduction corrections // CHECK: rock.transforming_for - // CHECK-DAG: %[[maxdiffexp:.+]] = rock.in_bounds_load %[[maxdiffexpbuf]] - // CHECK-DAG: %[[attnOutVal:.+]] = rock.in_bounds_load %[[sliceAttnOutBuf]] - // CHECK-DAG: %[[gemm1Val:.+]] = rock.in_bounds_load %[[gemm1AccBufScalar]] - - // CHECK-DAG: %[[attnOutBufMul:.+]] = arith.mulf %[[attnOutVal]], %[[maxdiffexp]] - // CHECK-DAG: %[[newattnOutVal:.+]] = arith.addf %[[attnOutBufMul]], %[[gemm1Val]] - // CHECK-DAG: rock.in_bounds_store %[[newattnOutVal]] -> %[[sliceAttnOutBuf]] - // CHECK : } + // CHECK: arith.mulf + // CHECK: arith.addf + // CHECK: {name = "PostProcess"} + + // Iteration 1: Load next V tile + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]] + // CHECK-DAG: rock.fill({{.*}}, %[[zeroVecF32]]) + // CHECK: memref.view %[[ldsG0AStore]] + // CHECK: memref.view %[[ldsG1BStore]] + // CHECK: rock.blockwise_gemm_accel + // CHECK: {name = "MMA"} + // CHECK: rock.transforming_for + // CHECK: memref.subview %[[attnOutBuf]] + // CHECK: rock.transforming_for // CHECK: {name = "PostProcess"} - // CHECK : {pipeline = #rock.pipeline<2>} // CHECK : } // CHECK : %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] // CHECK : rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}}(%[[O]]) From 6018bd8bf5481c93ab2e17781cbfe497bf9dfcf3 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 11 Mar 2026 14:06:29 +0000 Subject: [PATCH 15/18] Update pipelines test --- mlir/test/rocmlir-driver/pipelines.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir index 70c3857d14e2..9d130db811c6 100644 --- a/mlir/test/rocmlir-driver/pipelines.mlir +++ b/mlir/test/rocmlir-driver/pipelines.mlir @@ -41,6 +41,7 @@ // GPU-NEXT:rock-threadwise-gemm-lowering, // GPU-NEXT:rock-analyze-memory-use, // GPU-NEXT:rock-sugar-to-loops, +// GPU-NEXT:rock-pipeline{rock-pipeline-remove-stages=true}, // GPU-NEXT:rock-clean-math, // GPU-NEXT:math-extend-to-supported-types{extra-types={f16} target-type=f32}, // GPU-NEXT:rock-buffer-load-merge, From b4ac474597d0c611ebe35195c875de842930cdd1 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Mar 2026 17:34:36 +0000 Subject: [PATCH 16/18] Navi3X fixes --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 4 +- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 3 + mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 12 +- .../BlockwiseLoadTileToThreadwise.cpp | 27 ++-- .../Transforms/GridwiseGemmToBlockwise.cpp | 137 +++++++----------- 5 files changed, 80 insertions(+), 103 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index fa123a57b5fb..8fa3b78e01f6 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -466,6 +466,7 @@ def 
Rock_GemmLoadTileDirectToLDSDoubleBuffer : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>; def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>; def Rock_GemmLoadTileLDSWriteFromRegs : I32EnumAttrCase<"LDSWriteFromRegs", 6>; +def Rock_GemmLoadTileLDSReadOnly : I32EnumAttrCase<"LDSReadOnly", 7>; def Rock_GemmLoadTileType : Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types", @@ -474,7 +475,8 @@ def Rock_GemmLoadTileType Rock_GemmLoadTileDirectToLDSDefault, Rock_GemmLoadTileDirectToLDSDoubleBuffer, Rock_GemmLoadTileGlobalReadOnly, - Rock_GemmLoadTileLDSWriteFromRegs]> { + Rock_GemmLoadTileLDSWriteFromRegs, + Rock_GemmLoadTileLDSReadOnly]> { let cppNamespace = "::mlir::rock"; let genSpecializedAttr = 0; } diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index e4b0ee7496d6..f80303e44b12 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1621,6 +1621,9 @@ def Rock_BlockwiseLoadTileOp - DoubleBuffer: Creates three stages, (1) load from memory, (2) write to LDS, (3) load to registers. - DirectToLDSDefault: Same as Default, but a single stage loads from memory and writes to LDS. - DirectToLDSDoubleBuffer: Same as DoubleBuffer, but a single stage loads from memory and writes to LDS. + - GlobalReadOnly: Loads from global memory into flat registers. + - LDSWriteFromRegs: Writes from flat registers into LDS. + - LDSReadOnly: Reads from LDS into registers. `isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes. `elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs. diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 52914cb74829..c8f749a3f869 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2482,13 +2482,17 @@ void BlockwiseLoadTileOp::getEffects( loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; bool singleBuffer = loadType == GemmLoadTileType::Default || loadType == GemmLoadTileType::DirectToLDSDefault; + bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly; - effects.emplace_back(read, &getSourceMutable()); + // LDSReadOnly does not read from global source. + if (!ldsReadOnly) + effects.emplace_back(read, &getSourceMutable()); if (loadType != GemmLoadTileType::BypassLDS) { assert(getDestLDS() != nullptr); - effects.emplace_back(write, &getDestLDSMutable()[0]); - // DoubleBuffer means we write to LDS and then, load from it - if (doubleBuffer) + // LDSReadOnly only reads from LDS, it does not write to it. 
+ if (!ldsReadOnly) + effects.emplace_back(write, &getDestLDSMutable()[0]); + if (doubleBuffer || ldsReadOnly) effects.emplace_back(read, &getDestLDSMutable()[0]); } if (!singleBuffer) { diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index f4587e78c905..e7755ce09369 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -203,6 +203,7 @@ class LoweringBlockwiseLoadTileOp final loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; bool doubleBuffer = loadType == GemmLoadTileType::DoubleBuffer || loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly; // Build LDS transpose config attribute if enabled // The decision was already made in GridwiseGemmToBlockwise pass @@ -248,7 +249,10 @@ class LoweringBlockwiseLoadTileOp final bool ldsWriteFromRegs = loadType == GemmLoadTileType::LDSWriteFromRegs; Value loadBuffer, storeBuffer; - if (globalReadOnly || ldsWriteFromRegs) { + if (ldsReadOnly) { + // LDSReadOnly: no load/store buffers needed, we only read from LDS + // into destRegisters via generateReadLoop. + } else if (globalReadOnly || ldsWriteFromRegs) { // Split-phase load: use the externally-allocated destRegisters buffer // as the shared loadBuffer between the GlobalReadOnly and // LDSWriteFromRegs phases. @@ -284,7 +288,7 @@ class LoweringBlockwiseLoadTileOp final if (isa(parentOp)) b.setInsertionPoint(op); - if (!ldsWriteFromRegs) { + if (!ldsWriteFromRegs && !ldsReadOnly) { // Use distinct stage name for split-phase V prefetch to avoid // conflicting with K/Q GlobalRead stages in the same parent scope. StringRef globalReadStageName = @@ -402,7 +406,7 @@ class LoweringBlockwiseLoadTileOp final rock::YieldOp::create(b, loc); } } else { - if (!directToLDS) { + if (!directToLDS && !ldsReadOnly) { // Use distinct stage name for split-phase V write to avoid // conflicting with K/Q LDSWrite stages in nested loops. StringRef ldsWriteStageName = @@ -483,15 +487,18 @@ class LoweringBlockwiseLoadTileOp final } } - if (doubleBuffer) { - // Pipeline pass will remove this if the loop uses pipelining - LDSBarrierOp::create(b, loc); + if (doubleBuffer || ldsReadOnly) { + if (doubleBuffer) { + // Pipeline pass will remove this if the loop uses pipelining + LDSBarrierOp::create(b, loc); + } - // If we are running double-buffered pipelines, it makes sense to also - // parallelize the LDSRead/MMA stages. We do this here, by splitting the - // MMA loop in two separate stages + // Split the MMA loop into LDSRead and MMA stages so they can be + // parallelized. For LDSReadOnly, use a distinct stage name to avoid + // conflicting with GEMM0's LDSRead stages in the same parent scope. + StringRef ldsReadStageName = ldsReadOnly ? 
"VLDSRead" : "LDSRead"; auto [stageLDSRead, stageLDSReadNew] = - createOrGetStage(b, loc, "LDSRead", parentOp); + createOrGetStage(b, loc, ldsReadStageName, parentOp); { // Read from LDS into registers PatternRewriter::InsertionGuard guard(b); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 15e6410530f3..c71c7dbf0a08 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2417,11 +2417,17 @@ struct GridwiseAttentionAccelRewritePattern Value gemm0ExpOutBufferToLDS = createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter); + // When prefetching V, the peeled GEMM1 iteration uses LDSReadOnly + // (LDS -> accel regs) followed by a DoubleBuffer GEMM (reads from regs). + // Both the LDSReadOnly and DoubleBuffer paths require the register + // buffer to have mRepeats slots. + bool willPrefetchV = op.getEnableSoftmax() && !directToLDS; auto [preAccelRegBufferVForLoad, preAccelRegBufferV] = createRegInterrimBufferForAccel( rewriter, loc, accelParamsGemm1.argTypeA, accelParamsGemm1.kBasePerThread, - doubleBuffering ? accelParamsGemm1.mRepeats : 1, directToLDS); + (doubleBuffering || willPrefetchV) ? accelParamsGemm1.mRepeats : 1, + directToLDS); auto preAccelRegBufferQxKPair = createRegInterrimBufferForAccel( rewriter, loc, accelParamsGemm1.argTypeB, accelParamsGemm1.kBasePerThread, @@ -2788,71 +2794,43 @@ struct GridwiseAttentionAccelRewritePattern accelEmitterPtrGemm0->computeOutputConversion( rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll); - // V p: Issue global reads for V tile 0 before softmax + // V prefetch: Issue global reads for V tile 0 before softmax // to overlap softmax computation with V's global memory latency. - // Uses GlobalReadOnly (global -> regs) and LDSWriteFromRegs - // (regs -> LDS) to split the load across the softmax boundary. + // Uses a three-phase split-load approach: + // Phase 1 (before softmax): GlobalReadOnly — global -> flat regs + // Phase 2 (after softmax): LDSWriteFromRegs — flat regs -> LDS + // Phase 3 (peeled iter): LDSReadOnly — LDS -> accel regs + // The final GEMM1 compute uses DoubleBuffer mode (reads from regs). + // This avoids using Default GEMM reads from LDS, which generate + // incompatible transforms on some architectures (e.g. gfx1100). Value ldsByteBufferV; Value vPrefetchRegs; layout::GridCoordinates gridCoordsGemm1; bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS; - // Decide whether to hoist V regs->LDS write before the sum reduction. - // Hoisting saves one LDS barrier but extends V's LDS live range to - // overlap with the sum-reduction workspace, which may increase peak - // LDS usage. Only hoist if the resulting peak fits in hardware LDS. - bool hoistVPhase2 = false; - if (prefetchFirstVTile) { - int64_t maxLDS = archInfo.maxSharedMemPerWG; - int64_t sumWSBytes = - getPackedByteSize(reductionWorkspaceSize, elemTypeSoftmax); - int64_t gemm0PeakBytes = - getPackedByteSize(ldsByteBufferQSize, elemTypeQ) + - getPackedByteSize(gemm0KPerBlock * gemm0MPerBlock, elemTypeK); - int64_t gemm1PeakBytes = - getPackedByteSize(gemm1KPerBlock * gemm1MPerBlock, elemTypeV) + - getPackedByteSize(gemm1LDSByteBufferBSize, elemTypeV); - - // The base peak without hoisting is determined by the larger of - // GEMM0 and GEMM1 concurrent buffer sets. 
- int64_t nonHoistedPeak = std::max(gemm0PeakBytes, gemm1PeakBytes); - // Hoisting adds sumWSBytes on top (V displaced in merged color). - int64_t hoistedTotal = nonHoistedPeak + sumWSBytes; - hoistVPhase2 = (hoistedTotal <= maxLDS); - LLVM_DEBUG(llvm::dbgs() - << "V prefetch Phase 2 hoist decision: " - << (hoistVPhase2 ? "HOIST" : "DEFER") - << " (hoistedTotal=" << hoistedTotal << ", max=" << maxLDS - << ", sumWS=" << sumWSBytes << ", gemm0=" << gemm0PeakBytes - << ", gemm1=" << gemm1PeakBytes << ")\n"); - } - if (prefetchFirstVTile) { // Set up grid coordinates for the first V tile. gridCoordsGemm1 = layout::makeGxNGridLayout( rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, splitKVConst); - gridCoordsGemm1.m_block = zero; // First V tile (block index 0) + gridCoordsGemm1.m_block = zero; + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); // Allocate a flat register buffer shared between the GlobalReadOnly - // and LDSWriteFromRegs phases. Size must match what the lowering - // computes: copyPerThread = (kPerBlock * dPerBlock) / blockSize. + // and LDSWriteFromRegs phases. int64_t vCopyPerThread = (gemm1KPerBlock * gemm1MPerBlock) / blockSize; vPrefetchRegs = gpuAlloc(rewriter, loc, vCopyPerThread, elemTypeV, gpu::AddressSpace::Private); // Phase 1: Issue global reads for V tile 0 into register buffer. // Only the GlobalRead stage is emitted; LDS write is deferred. - // A dummy LDS buffer is passed because the function signature - // requires one, but GlobalReadOnly does not write to LDS. - Value dummyLDS = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); loadAndStoreGemmInputTile( rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS, vPrefetchRegs, - GemmLoadTileType::GlobalReadOnly, "m", blockSize, elemTypeV, - elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, - matrixParamsKxQ); + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); // Insert a scheduling barrier to prevent the LLVM backend scheduler // from sinking the V global loads past the softmax computation. @@ -3021,22 +2999,6 @@ struct GridwiseAttentionAccelRewritePattern gemm0MNExpThreadwiseView, gemm0MNMaxThreadwiseView, maxRowBuffer); - // V prefetch phase 2 (hoisted): Write V data from regs to LDS - // before the sum reduction. The sum reduction's internal LDS - // barrier synchronises the V tile writes, saving one barrier. - if (prefetchFirstVTile && hoistVPhase2) { - // Allocate V LDS buffer early (before the sum reduction) so that - // Phase 2 can write the prefetched V data from registers into LDS. 
- ldsByteBufferV = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); - loadAndStoreGemmInputTile( - rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize, - elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, - matrixParamsV, matrixParamsKxQ); - } - // Softmax sum reduction Value ldsReductionWorkspaceByteSecondBuffer = createLDSByteBuffer( rewriter, loc, reductionWorkspaceSize, elemTypeSoftmax); @@ -3063,12 +3025,10 @@ struct GridwiseAttentionAccelRewritePattern gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, expMaxDiffRowBuffer); - // V prefetch phase 2 (deferred path): Write V data from regs to - // LDS after the sum reduction to avoid overlapping with the - // sum-reduction workspace in LDS. Costs one extra barrier. - if (prefetchFirstVTile && !hoistVPhase2) { - ldsByteBufferV = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + // V prefetch phase 2: Write prefetched V data from regs to LDS. + // The global reads issued before softmax should have completed + // during softmax computation, so this write is latency-free. + if (prefetchFirstVTile) { loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, @@ -3270,24 +3230,24 @@ struct GridwiseAttentionAccelRewritePattern }; // end emitGemm1Compute lambda if (prefetchFirstVTile) { - // Prefetch path: V tile 0 is already in LDS. Peel the first - // GEMM1 iteration and loop over the remaining tiles. + // Prefetch path: V tile 0 is already in LDS from the split-phase + // GlobalReadOnly + LDSWriteFromRegs above. Read from LDS into + // accel-shaped registers using LDSReadOnly, then compute GEMM1 + // with DoubleBuffer (which reads from registers, not LDS). + GemmLoadTileType vLoadType = GemmLoadTileType::DoubleBuffer; gridCoordsGemm1.m_block = zero; - // When double-buffering, preAccelRegBufferV is rank-2; the - // Default load path expects rank-1, so allocate a separate buf. - Value peeledVRegBuf = preAccelRegBufferV; - if (doubleBuffering) { - auto [peeledVForLoad, peeledVBuf] = createRegInterrimBufferForAccel( - rewriter, loc, accelParamsGemm1.argTypeA, - accelParamsGemm1.kBasePerThread, - /*repeats=*/1, directToLDS); - peeledVRegBuf = peeledVBuf; - } - if (!doBypassLDSSecondGemm) - LDSBarrierOp::create(rewriter, loc); - if (failed(emitGemm1Compute(zero, GemmLoadTileType::Default, - peeledVRegBuf))) + // Phase 3: LDS -> accel regs via LDSReadOnly. + // Uses "VLDSRead" stage name to avoid conflicting with GEMM0's + // "LDSRead" stages in the same parent scope. + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, GemmLoadTileType::LDSReadOnly, "m", + blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, + featuresAttr, matrixParamsV, matrixParamsKxQ); + + if (failed(emitGemm1Compute(zero, vLoadType, preAccelRegBufferV))) return failure(); // Remaining iterations (g1m = 1..gemm1MBlocks-1). @@ -3304,9 +3264,10 @@ struct GridwiseAttentionAccelRewritePattern // Only pipeline when >1 iteration remains; pipelining a // single iteration causes barrier mismatches. if (gemm1MBlocks > 2) { + // vLoadType is always DoubleBuffer in the prefetch path. 
bool g1DoubleBuffering = - loadType == GemmLoadTileType::DoubleBuffer || - loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + vLoadType == GemmLoadTileType::DoubleBuffer || + vLoadType == GemmLoadTileType::DirectToLDSDoubleBuffer; int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2; g1MLoopOp->setAttr(PipelineAttr::getMnemonic(), rock::PipelineAttr::get(rewriter.getContext(), @@ -3319,18 +3280,18 @@ struct GridwiseAttentionAccelRewritePattern gridCoordsGemm1.m_block = g1MLoopIndVar; - // Normal V tile load (global -> regs -> LDS) + // Normal V tile load (global -> LDS -> regs via DoubleBuffer) loadAndStoreGemmInputTile( rewriter, loc, inV, /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - preAccelRegBufferVForLoad, loadType, "m", blockSize, + preAccelRegBufferVForLoad, vLoadType, "m", blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, matrixParamsKxQ); // Conservative barrier before MMA LDSBarrierOp::create(rewriter, loc); - if (failed(emitGemm1Compute(g1MLoopIndVar, loadType, + if (failed(emitGemm1Compute(g1MLoopIndVar, vLoadType, preAccelRegBufferV))) return failure(); From 5cad4fb11d8ae437d9a303f1ad367db1621a55e2 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Mar 2026 21:10:17 +0000 Subject: [PATCH 17/18] Add barrier --- .../Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index c71c7dbf0a08..d3fd7d8f9c85 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3247,6 +3247,13 @@ struct GridwiseAttentionAccelRewritePattern blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, matrixParamsKxQ); + // Barrier to synchronize the B-side (softmax output) LDS write + // from storeGemmInputTile above with the GEMM1 compute's Default + // LDS read of gemm1LDSByteBufferB. Without this, threads may + // read B-side data that other threads haven't finished writing. 
+ if (!doBypassLDSSecondGemm) + LDSBarrierOp::create(rewriter, loc); + if (failed(emitGemm1Compute(zero, vLoadType, preAccelRegBufferV))) return failure(); From 2366068739320caa1ebae0e8e9a9805c595261e4 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Mar 2026 23:04:40 +0000 Subject: [PATCH 18/18] Fix up some more LIT tests --- .../gridwise_attention_accel_lowering.mlir | 63 ++++++++----------- .../toblockwise_attention_accel_lowering.mlir | 25 +++++--- 2 files changed, 42 insertions(+), 46 deletions(-) diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index da89afdd99c0..1ce0a2e6437c 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -88,6 +88,13 @@ // CHECK: rock.transforming_for // CHECK: %[[tmp:.+]] = memref.load %[[gemm0AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp]] -> %[[gemm0AccBufScalar:.+]][ + + // V prefetch: issue global reads before softmax + // CHECK: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: rock.stage + // CHECK: rock.threadwise_read_into + // CHECK: {name = "VGlobalRead"} + // CHECK: linalg.generic {{.*}} ins(%[[gemm0AccBufScalar]] {{.*}} outs(%[[gemm0AccBufScalar]] // CHECK: %[[gemm0Scaled:.+]] = arith.mulf %in, %[[ln2Recip]] : f32 // CHECK: linalg.yield %[[gemm0Scaled]] @@ -106,9 +113,6 @@ // CHECK-DAG: %[[gemm0ValSubMaxExp:.+]] = math.exp2 %[[gemm0ValSubMax]] // CHECK-DAG: rock.in_bounds_store %[[gemm0ValSubMaxExp]] -> %[[gemm0NormExp:.+]][ - // V prefetch: write V data to LDS - // CHECK: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> - // CHECK: %[[ldsReductionWS2:.+]] = rock.alloc() : memref<256xi8, #gpu.address_space> // CHECK: %[[ldsReductionWS2View:.+]] = memref.view %[[ldsReductionWS2]][{{.*}}][] : memref<256xi8, #gpu.address_space> to memref<64xf32, #gpu.address_space> // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} %[[gemm0NormExp]] into %[[gemm0NormExpSum:[0-9]+]] using %[[ldsReductionWS2View]] @@ -132,6 +136,12 @@ // CHECK-DAG: %[[tilesumadd:.+]] = arith.addf %[[rowsummul]], %[[tilesum]] // CHECK-DAG: %[[tilesumadd]] -> %[[sumRowBuf]] + // V prefetch: write V data to LDS + // CHECK: rock.stage + // CHECK: rock.threadwise_write_all + // CHECK: {name = "VLDSWrite"} + // CHECK: rock.lds_barrier + // Viewing first gemm output as K x D // CHECK-DAG: %[[gemm0NormExpTr0:.+]] = rock.transform %[[gemm0NormExp]] // CHECK-DAG: %[[gemm0NormExpTr1:.+]] = rock.transform %[[gemm0NormExpTr0]] @@ -163,57 +173,34 @@ // CHECK-DAG: %[[viewG1AStoreTr6:.+]] = rock.transform %[[viewG1AStoreTr5]] // CHECK-DAG: %[[viewG1AStoreTr7:.+]] = rock.transform %[[viewG1AStoreTr6]] - // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) - - // Gemm1 (unrolled) - // Iteration 0: V data already in LDS from prefetch + // CHECK: rock.stage + // CHECK: memref.view %[[ldsG0BStore]] + // CHECK: rock.threadwise_read_into + // CHECK: {name = "VLDSRead"} + // CHECK: rock.lds_barrier // CHECK: rock.stage // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) - // CHECK: memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> // CHECK: memref.view %[[ldsG1AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> // CHECK: rock.blockwise_gemm_accel // CHECK: {name = 
"MMA"} - // CHECK: rock.stage // CHECK: rock.transforming_for // CHECK: memref.load %[[gemm1AccBuf]][ // CHECK: rock.in_bounds_store - // CHECK: memref.subview %[[attnOutBuf]] // Reduction corrections // CHECK: rock.transforming_for // CHECK: arith.mulf // CHECK: arith.addf // CHECK: {name = "PostProcess"} - - // Iteration 1: Load next V tile - // CHECK: rock.stage - // CHECK: rock.threadwise_read_into - // CHECK: {name = "GlobalRead"} - - // CHECK: rock.stage - // CHECK: memref.view %[[ldsG0BStore]] - // CHECK: rock.threadwise_write_all - // CHECK: {name = "LDSWrite"} - - // CHECK: rock.stage - // CHECK: rock.fill - // CHECK: memref.view %[[ldsG0BStore]] - // CHECK: memref.view %[[ldsG1AStore]] - // CHECK: rock.blockwise_gemm_accel - // CHECK: {name = "MMA"} - - // CHECK: rock.stage - // CHECK: rock.transforming_for - // CHECK: memref.subview %[[attnOutBuf]] - // CHECK: rock.transforming_for - // CHECK: arith.mulf - // CHECK: arith.addf - // CHECK: {name = "PostProcess"} -// CHECK : } -// CHECK : %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] -// CHECK : rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}}(%[[O]]) + // CHECK: rock.lds_barrier + // CHECK: affine.for + // CHECK: memref.subview %[[attnOutBuf]] + // CHECK: rock.transforming_for + // CHECK: arith.divf + // CHECK: %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] + // CHECK: rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}} memref<1x64x384xf32> func.func @gridwise_attn_simple(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64x384xf32>, %arg2: memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>) attributes {block_size = 64 : i32, grid_size = 24 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-"} { %0 = rock.transform %arg0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 64, 384] -> [1, 384, 64]> : memref<1x384x64xf32> to memref<1x64x384xf32> diff --git a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir index 8f1b39b14a42..6eff8177627c 100644 --- a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir @@ -42,6 +42,11 @@ // CHECK: rock.transforming_for // CHECK: %[[tmp:.+]] = memref.load %[[gemm0AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp]] -> %[[gemm0AccBufScalar:.+]][ + + // V prefetch: issue global reads before softmax + // CHECK: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: linalg.generic {{.*}} ins(%[[gemm0AccBufScalar]] {{.*}} outs(%[[gemm0AccBufScalar]] // CHECK: %[[gemm0Scaled:.+]] = arith.mulf %in, %[[ln2Recip]] : f32 // CHECK: linalg.yield %[[gemm0Scaled]] @@ -60,9 +65,6 @@ // CHECK-DAG: %[[gemm0ValSubMaxExp:.+]] = math.exp2 %[[gemm0ValSubMax]] // CHECK-DAG: rock.in_bounds_store %[[gemm0ValSubMaxExp]] -> %[[gemm0NormExp:.+]][ - // V prefetch: write V data to LDS - // CHECK: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> - // CHECK: %[[ldsReductionWS2:.+]] = rock.alloc() : memref<256xi8, #gpu.address_space> // CHECK: %[[ldsReductionWS2View:.+]] = memref.view %[[ldsReductionWS2]][{{.*}}][] : memref<256xi8, #gpu.address_space> to memref<64xf32, #gpu.address_space> // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} %[[gemm0NormExp]] into %[[gemm0NormExpSum:[0-9]+]] using %[[ldsReductionWS2View]] @@ -86,6 
+88,10 @@ // CHECK-DAG: %[[tilesumadd:.+]] = arith.addf %[[rowsummul]], %[[tilesum]] // CHECK-DAG: %[[tilesumadd]] -> %[[sumRowBuf]] + // V prefetch: write V data to LDS from regs + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: rock.lds_barrier + // Viewing first gemm output as K x D // CHECK-DAG: %[[gemm0NormExpTr0:.+]] = rock.transform %[[gemm0NormExp]] // CHECK-DAG: %[[gemm0NormExpTr1:.+]] = rock.transform %[[gemm0NormExpTr0]] @@ -121,12 +127,14 @@ // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) + // V LDS Read + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // Gemm1 (unrolled) - // Iteration 0: V data already in LDS from prefetch + // Iteration 0: V data already in registers from LDS read // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) - // CHECK: %[[view2G1AStore:.+]] = memref.view %[[ldsG0AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> // CHECK: %[[view2G1BStore:.+]] = memref.view %[[ldsG1BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += {{.*}} from %[[view2G1AStore]] * {{.*}} from %[[view2G1BStore]] + // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += {{.*}} * {{.*}} from %[[view2G1BStore]] // CHECK: {name = "MMA"} // CHECK: rock.transforming_for @@ -139,11 +147,12 @@ // CHECK: arith.mulf // CHECK: arith.addf // CHECK: {name = "PostProcess"} + // CHECK: rock.lds_barrier // Iteration 1: Load next V tile - // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]] + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: rock.lds_barrier // CHECK-DAG: rock.fill({{.*}}, %[[zeroVecF32]]) - // CHECK: memref.view %[[ldsG0AStore]] // CHECK: memref.view %[[ldsG1BStore]] // CHECK: rock.blockwise_gemm_accel // CHECK: {name = "MMA"}
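Taken together, the CHECK updates in this last patch trace the final V-prefetch ordering: V's global reads are issued before the softmax block, the register-to-LDS write lands after the softmax sum reduction, and a separate LDS read stages V for the peeled first GEMM1 iteration. A rough host-side C++ sketch of that ordering, offered only as an illustration (the stage functions below are placeholders invented for this sketch, not APIs from the patch series):

    #include <cstdio>
    #include <vector>

    // Phase 1 (GlobalReadOnly): global memory -> per-thread registers.
    std::vector<float> issueVGlobalRead(const std::vector<float> &vGlobalTile) {
      return vGlobalTile;
    }

    // Softmax work runs while the V loads above are still in flight, hiding
    // their latency; the sched_barrier keeps the backend from sinking the
    // loads below this point.
    void softmaxRowsInPlace(std::vector<float> &rows) {
      (void)rows; // stand-in for the max / exp2 / sum-reduction sequence
    }

    // Phase 2 (LDSWriteFromRegs): registers -> LDS, after the sum reduction.
    void writeVRegsToLds(const std::vector<float> &vRegs,
                         std::vector<float> &lds) {
      lds = vRegs;
    }

    // Phase 3 (LDSReadOnly): LDS -> accel-shaped registers for the peeled
    // first GEMM1 iteration; later V tiles use the normal load path.
    std::vector<float> readVLdsToAccelRegs(const std::vector<float> &lds) {
      return lds;
    }

    int main() {
      std::vector<float> vGlobalTile = {1, 2, 3, 4};
      std::vector<float> softmaxRows = {0.5f, 0.25f};
      std::vector<float> lds;

      std::vector<float> vRegs = issueVGlobalRead(vGlobalTile); // before softmax
      softmaxRowsInPlace(softmaxRows);              // V load latency hidden here
      writeVRegsToLds(vRegs, lds);                  // after the sum reduction
      std::vector<float> accelRegs = readVLdsToAccelRegs(lds); // peeled GEMM1
      std::printf("staged %zu V values for GEMM1\n", accelRegs.size());
      return 0;
    }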