diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 796e6d6dae9c..8fa3b78e01f6 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -464,13 +464,19 @@ def Rock_GemmLoadTileDirectToLDSDefault : I32EnumAttrCase<"DirectToLDSDefault", 3>; def Rock_GemmLoadTileDirectToLDSDoubleBuffer : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>; +def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>; +def Rock_GemmLoadTileLDSWriteFromRegs : I32EnumAttrCase<"LDSWriteFromRegs", 6>; +def Rock_GemmLoadTileLDSReadOnly : I32EnumAttrCase<"LDSReadOnly", 7>; def Rock_GemmLoadTileType : Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types", [Rock_GemmLoadTileBypassLDS, Rock_GemmLoadTileDefault, Rock_GemmLoadTileDoubleBuffer, Rock_GemmLoadTileDirectToLDSDefault, - Rock_GemmLoadTileDirectToLDSDoubleBuffer]> { + Rock_GemmLoadTileDirectToLDSDoubleBuffer, + Rock_GemmLoadTileGlobalReadOnly, + Rock_GemmLoadTileLDSWriteFromRegs, + Rock_GemmLoadTileLDSReadOnly]> { let cppNamespace = "::mlir::rock"; let genSpecializedAttr = 0; } diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index e4b0ee7496d6..f80303e44b12 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1621,6 +1621,9 @@ def Rock_BlockwiseLoadTileOp - DoubleBuffer: Creates three stages, (1) load from memory, (2) write to LDS, (3) load to registers. - DirectToLDSDefault: Same as Default, but a single stage loads from memory and writes to LDS. - DirectToLDSDoubleBuffer: Same as DoubleBuffer, but a single stage loads from memory and writes to LDS. + - GlobalReadOnly: Loads from global memory into flat registers. + - LDSWriteFromRegs: Writes from flat registers into LDS. + - LDSReadOnly: Reads from LDS into registers. `isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes. `elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs. 
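The three new cases split the existing Default load path (global memory -> flat registers -> LDS -> accel registers) into independently schedulable phases, so that unrelated work (for example the softmax block in the attention kernel) can be overlapped with the global-read latency. Below is a minimal standalone C++ sketch of the intended phase ordering only; it is illustrative, not rocMLIR API: plain arrays stand in for global memory, LDS, and the two register buffers, and the "softmax or other work" step is a hypothetical placeholder for whatever is scheduled between the phases.

#include <array>
#include <cstdio>

enum class LoadPhase { GlobalReadOnly, LDSWriteFromRegs, LDSReadOnly };

// Toy per-thread tile copy: each call performs exactly one phase of the
// split load, mirroring the three new GemmLoadTileType cases.
void runPhase(LoadPhase p,
              const std::array<float, 4> &global,
              std::array<float, 4> &flatRegs,
              std::array<float, 4> &lds,
              std::array<float, 4> &accelRegs) {
  switch (p) {
  case LoadPhase::GlobalReadOnly:   // global memory -> flat registers
    flatRegs = global;
    break;
  case LoadPhase::LDSWriteFromRegs: // flat registers -> LDS
    lds = flatRegs;
    break;
  case LoadPhase::LDSReadOnly:      // LDS -> accel-shaped registers
    accelRegs = lds;
    break;
  }
}

int main() {
  std::array<float, 4> global{1, 2, 3, 4}, regs{}, lds{}, accel{};
  runPhase(LoadPhase::GlobalReadOnly, global, regs, lds, accel);
  // ... latency-hiding work (e.g. softmax) would run here ...
  runPhase(LoadPhase::LDSWriteFromRegs, global, regs, lds, accel);
  // barrier: all threads must finish the LDS write before any LDS read
  runPhase(LoadPhase::LDSReadOnly, global, regs, lds, accel);
  std::printf("%g %g %g %g\n", accel[0], accel[1], accel[2], accel[3]);
}

The key property the sketch shows is that phase 1 can be issued early and phases 2 and 3 deferred, which is exactly how the attention lowering below uses the new enum cases for the V-tile prefetch.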
diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 1041ab12abdf..1774b6793e3a 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -107,7 +107,10 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> { def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> { let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy"; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", + "gpu::GPUDialect", "vector::VectorDialect", + "memref::MemRefDialect", "linalg::LinalgDialect", + "scf::SCFDialect", "amdgpu::AMDGPUDialect"]; } def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> { diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 52914cb74829..c8f749a3f869 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2482,13 +2482,17 @@ void BlockwiseLoadTileOp::getEffects( loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; bool singleBuffer = loadType == GemmLoadTileType::Default || loadType == GemmLoadTileType::DirectToLDSDefault; + bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly; - effects.emplace_back(read, &getSourceMutable()); + // LDSReadOnly does not read from global source. + if (!ldsReadOnly) + effects.emplace_back(read, &getSourceMutable()); if (loadType != GemmLoadTileType::BypassLDS) { assert(getDestLDS() != nullptr); - effects.emplace_back(write, &getDestLDSMutable()[0]); - // DoubleBuffer means we write to LDS and then, load from it - if (doubleBuffer) + // LDSReadOnly only reads from LDS, it does not write to it. + if (!ldsReadOnly) + effects.emplace_back(write, &getDestLDSMutable()[0]); + if (doubleBuffer || ldsReadOnly) effects.emplace_back(read, &getDestLDSMutable()[0]); } if (!singleBuffer) { diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index d258bdc6e916..84e7c859e577 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -240,6 +240,9 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(rock::createRockThreadwiseGemmLoweringPass()); funcPm.addPass(rock::createRockAnalyzeMemoryUsePass()); funcPm.addPass(rock::createRockSugarToLoopsPass()); + // Re-run the pipeline pass to remove back-to-back LDS barriers + // that may appear after SugarToLoops unrolls TransformingForOps. + funcPm.addPass(rock::createRockPipelinePass()); funcPm.addPass(rock::createRockCleanMathPass()); math::MathExtendToSupportedTypesOptions extendToLLVMTypesOptions; extendToLLVMTypesOptions.extraTypeStrs = {"f16"}; diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index a9d701e7c458..dc751ee30917 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -928,8 +928,13 @@ struct BlockwiseReduceRewritePattern } } else { if (rMethod == ReduceMethod::Sum) { + // Use -0.0 (negative zero) instead of +0.0. 
In IEEE 754, -0.0 is the + // true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0 + // and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the + // redundant `v_add_f32 v, 0, v` that +0.0 generates via + // llvm.vector.reduce.fadd. return createConstantFloatOp(rewriter, op.getLoc(), elementType, - elementType, 0.0); + elementType, -0.0f); } else { // Op verifier gurantees this. assert(rMethod == ReduceMethod::Max); @@ -964,7 +969,7 @@ struct BlockwiseReduceRewritePattern kind = vector::CombiningKind::MAXNUMF; } } - input = vector::ReductionOp::create(builder, loc, kind, input); + return vector::ReductionOp::create(builder, loc, kind, input, acc); } if (rMethod == ReduceMethod::Sum) { @@ -1410,67 +1415,93 @@ struct BlockwiseReduceRewritePattern } } - // This RAII scope would do the following : - // LDS[rtid] = reduce(LDS[rtid], LDS[rtid + offset]) - // where offset is a power of 2. - // Initial it starts with power = ceil(|rtid|, power of 2) / 2 - // Then keep on reducing the power. + // Branchless reduction: each thread reads all rTidDim partial + // values from LDS and reduces locally in registers. This avoids + // creating conditional branches (scf.if) that split softmax into + // multiple basic blocks. + // Trade-off: every thread does rTidCount LDS reads (instead of + // log2(rTidCount) conditional reads in the tree reduction). For + // typical attention configs where rTidCount is small (e.g., 4), + // this is negligible overhead. + // TODO: We may have to use a heuristic to determine whether or not to + // use this depending on the size of rTidCount. { - int64_t ceilPowerOf2 = - llvm::PowerOf2Ceil(threadViewShape[rTidDim]) / 2; - int64_t maxActiveReductionThreads = threadViewShape[rTidDim]; - for (int64_t offset = ceilPowerOf2; offset >= 1; - offset = offset >> 1) { - Value offsetVal = - arith::ConstantIndexOp::create(rewriter, loc, offset); - Value rtidPlusOffsetVal = - arith::AddIOp::create(rewriter, loc, rtid, offsetVal); - Value maxActiveReductionThreadsVal = arith::ConstantIndexOp::create( - rewriter, loc, maxActiveReductionThreads); - maxActiveReductionThreads = - llvm::PowerOf2Ceil(maxActiveReductionThreads) >> 1; - Value isValid = arith::CmpIOp::create( - rewriter, loc, arith::CmpIPredicate::slt, rtidPlusOffsetVal, - maxActiveReductionThreadsVal); - scf::IfOp ifb = scf::IfOp::create(rewriter, loc, isValid, - /*withElseRegion=*/false); + int64_t rTidCount = threadViewShape[rTidDim]; + + // Accumulator for the full reduction. + auto accRegType = MemRefType::get({1}, elemType, AffineMap{}, + privateMemoryAddressSpace); + Value accReg = GpuAllocOp::create(rewriter, loc, accRegType); + FillOp::create(rewriter, loc, accReg, initVal); + + // Read all rTidCount partial values from LDS and reduce. + // Every thread with the same nrtid computes the identical + // fully-reduced value. 
+ for (int64_t i = 0; i < rTidCount; i++) { + Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i); + SmallVector readInits{nrtid, iVal, zeroConstantOp}; + SmallVector bounds{1, 1, 1}; + SmallVector strides{1, 1, 1}; + + TransformingForOp readLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{readInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(bounds), ArrayRef(strides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); { - OpBuilder thenb = ifb.getThenBodyBuilder(); - SmallVector firstInits{nrtid, rtid, zeroConstantOp}; - SmallVector secondInits{nrtid, rtidPlusOffsetVal, - zeroConstantOp}; - SmallVector bounds{1, 1, 1}; - SmallVector strides{1, 1, 1}; - - TransformingForOp reductionLoop = TransformingForOp::create( - thenb, loc, ArrayRef{firstInits, secondInits}, - ArrayRef{threadToLDSViewTrs, threadToLDSViewTrs}, - ArrayRef(bounds), ArrayRef(strides), - /*forceUnroll=*/true, /*useIndexDiffs=*/true); - { - PatternRewriter::InsertionGuard guard(thenb); - thenb.setInsertionPointToStart(reductionLoop.getBody()); - Block::BlockArgListType firstLDSLoadCoords = - reductionLoop.getLowerCoords(/*domain=*/0); - Value firstLoadVal = InBoundsLoadOp::create( - thenb, loc, elemType, workspaceLDSBuffer, - firstLDSLoadCoords); - Block::BlockArgListType secondLDSLoadCoords = - reductionLoop.getLowerCoords(/*domain=*/1); - Value secondLoadVal = InBoundsLoadOp::create( - thenb, loc, elemType, workspaceLDSBuffer, - secondLDSLoadCoords); - Value reduced = - createReducingOp(op, firstLoadVal, secondLoadVal, thenb); - InBoundsStoreOp::create(thenb, loc, reduced, workspaceLDSBuffer, - firstLDSLoadCoords); + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(readLoop.getBody()); + Block::BlockArgListType ldsCoords = + readLoop.getLowerCoords(/*domain=*/0); + Value ldVal = InBoundsLoadOp::create( + rewriter, loc, elemType, workspaceLDSBuffer, ldsCoords); + if (i == 0) { + // First iteration: store the loaded value directly to the + // accumulator. This avoids a redundant reduction with the + // identity element (e.g., `0.0 + x` for sum, `max(-inf, x)` + // for max). + InBoundsStoreOp::create(rewriter, loc, ldVal, accReg, + zeroConstantOp); + } else { + Value accVal = InBoundsLoadOp::create(rewriter, loc, elemType, + accReg, zeroConstantOp); + Value reduced = createReducingOp(op, ldVal, accVal, rewriter); + InBoundsStoreOp::create(rewriter, loc, reduced, accReg, + zeroConstantOp); } } - LDSBarrierOp::create(rewriter, loc); } + + // Write the fully reduced value back to LDS at [nrtid, 0]. + // All threads with the same nrtid compute the same value, + // so concurrent writes to the same location are safe. 
+ { + Value reducedVal = InBoundsLoadOp::create(rewriter, loc, elemType, + accReg, zeroConstantOp); + SmallVector writeInits{nrtid, zeroConstantOp, + zeroConstantOp}; + SmallVector writeBounds{1, 1, 1}; + SmallVector writeStrides{1, 1, 1}; + + TransformingForOp writeLoop = TransformingForOp::create( + rewriter, loc, ArrayRef{writeInits}, + ArrayRef{threadToLDSViewTrs}, + ArrayRef(writeBounds), ArrayRef(writeStrides), + /*forceUnroll=*/true, /*useIndexDiffs=*/true); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(writeLoop.getBody()); + Block::BlockArgListType ldsCoords = + writeLoop.getLowerCoords(/*domain=*/0); + InBoundsStoreOp::create(rewriter, loc, reducedVal, + workspaceLDSBuffer, ldsCoords); + } + } + + LDSBarrierOp::create(rewriter, loc); ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView( - loc, rewriter, inputViewArrayAttr, axis, /*makeRDimZero-*/ true, - partialRegTensorShape[rDim]); + loc, rewriter, inputViewArrayAttr, axis, + /*makeRDimZero-*/ true, partialRegTensorShape[rDim]); ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer, outputReg, reducedldsViewArrayAttr, /*extraIndices=*/ValueRange{tid}, true, diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 0c309db0a2c2..e7755ce09369 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -203,6 +203,7 @@ class LoweringBlockwiseLoadTileOp final loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; bool doubleBuffer = loadType == GemmLoadTileType::DoubleBuffer || loadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly; // Build LDS transpose config attribute if enabled // The decision was already made in GridwiseGemmToBlockwise pass @@ -244,8 +245,25 @@ class LoweringBlockwiseLoadTileOp final else b.setInsertionPoint(op); + bool globalReadOnly = loadType == GemmLoadTileType::GlobalReadOnly; + bool ldsWriteFromRegs = loadType == GemmLoadTileType::LDSWriteFromRegs; + Value loadBuffer, storeBuffer; - if (loadType == GemmLoadTileType::BypassLDS) { + if (ldsReadOnly) { + // LDSReadOnly: no load/store buffers needed, we only read from LDS + // into destRegisters via generateReadLoop. + } else if (globalReadOnly || ldsWriteFromRegs) { + // Split-phase load: use the externally-allocated destRegisters buffer + // as the shared loadBuffer between the GlobalReadOnly and + // LDSWriteFromRegs phases. + assert(destRegisters && + "destRegisters must be set for split-phase load types"); + loadBuffer = destRegisters; + if (ldsWriteFromRegs) { + storeBuffer = + gpuAlloc(b, loc, copyPerThread, elementType, AddressSpace::Private); + } + } else if (loadType == GemmLoadTileType::BypassLDS) { auto privateMemoryAddressSpace = b.getAttr( gpu::GPUDialect::getPrivateAddressSpace()); auto accelParams = accelEmitterPtr->getParams(); @@ -270,57 +288,68 @@ class LoweringBlockwiseLoadTileOp final if (isa(parentOp)) b.setInsertionPoint(op); - auto [stageGlobalRead, stageGlobalReadNew] = - createOrGetStage(b, loc, "GlobalRead", parentOp); - { - PatternRewriter::InsertionGuard guard(b); - b.setInsertionPointToStart(&stageGlobalRead.getRegion().back()); - - FailureOr maybeBufferViews; - if (loadType == GemmLoadTileType::BypassLDS) { - // Check if the other operand uses LDS transpose load - bool otherOperandUsesLdsTranspose = - isA ? 
matrixParamsB.getLdsTransposeEnabled() - : matrixParamsA.getLdsTransposeEnabled(); - maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( - b, loc, kIters, bidGridLengths, blockSize, vecDimInfo.inDPerThread, - dName, isKContiguousDim, false, - /*doSplitKAcrossThreadsFirst=*/false, otherOperandUsesLdsTranspose); - } else { - maybeBufferViews = getLoadRegsAsTileViews( - b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize, - kPerBlock, dPerBlock, vecDimInfo.inKPerThread, - vecDimInfo.inDPerThread, isKContiguousDim, directToLDS); - } - if (failed(maybeBufferViews)) - return failure(); - - Value wrappedSource = transform(b, source, maybeBufferViews->gridSubTile); - - ThreadwiseReadIntoOp::create(b, loc, vectorOfBoolShapedLike(loadBuffer), - wrappedSource, loadBuffer, - /*dynamicValidities=*/ValueRange{}, - /*extraViews=*/b.getArrayAttr({}), - /*extraIndices=*/indices, forceUnroll, true, - /*ldsTransposeConfig=*/nullptr); - - if (rock::isGlobalPrefetchSupported(arch)) { - // add one to k_loop to prefetch next iteration - SmallVector indicesNext(indices.begin(), indices.end()); - Value one = b.createOrFold(loc, 1); - indicesNext[0] = - arith::AddIOp::create(b, loc, indicesNext[0], one).getResult(); - - // it's acceptable if the indices are out of bounds because we use - // GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch - // documentation in AMDGPUUsage.rst - rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource, - /*extraViews=*/b.getArrayAttr({}), - /*extraIndices=*/indicesNext, - forceUnroll, true); + if (!ldsWriteFromRegs && !ldsReadOnly) { + // Use distinct stage name for split-phase V prefetch to avoid + // conflicting with K/Q GlobalRead stages in the same parent scope. + StringRef globalReadStageName = + globalReadOnly ? "VGlobalRead" : "GlobalRead"; + auto [stageGlobalRead, stageGlobalReadNew] = + createOrGetStage(b, loc, globalReadStageName, parentOp); + { + PatternRewriter::InsertionGuard guard(b); + b.setInsertionPointToStart(&stageGlobalRead.getRegion().back()); + + FailureOr maybeBufferViews; + if (loadType == GemmLoadTileType::BypassLDS) { + // Check if the other operand uses LDS transpose load + bool otherOperandUsesLdsTranspose = + isA ? 
matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); + maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( + b, loc, kIters, bidGridLengths, blockSize, + vecDimInfo.inDPerThread, dName, isKContiguousDim, false, + /*doSplitKAcrossThreadsFirst=*/false, + otherOperandUsesLdsTranspose); + } else { + maybeBufferViews = getLoadRegsAsTileViews( + b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize, + kPerBlock, dPerBlock, vecDimInfo.inKPerThread, + vecDimInfo.inDPerThread, isKContiguousDim, directToLDS); + } + if (failed(maybeBufferViews)) + return failure(); + + Value wrappedSource = + transform(b, source, maybeBufferViews->gridSubTile); + + ThreadwiseReadIntoOp::create(b, loc, vectorOfBoolShapedLike(loadBuffer), + wrappedSource, loadBuffer, + /*dynamicValidities=*/ValueRange{}, + /*extraViews=*/b.getArrayAttr({}), + /*extraIndices=*/indices, forceUnroll, + true, + /*ldsTransposeConfig=*/nullptr); + + if (!globalReadOnly && rock::isGlobalPrefetchSupported(arch)) { + // add one to k_loop to prefetch next iteration + SmallVector indicesNext(indices.begin(), indices.end()); + Value one = b.createOrFold(loc, 1); + indicesNext[0] = + arith::AddIOp::create(b, loc, indicesNext[0], one).getResult(); + rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource, + /*extraViews=*/b.getArrayAttr({}), + /*extraIndices=*/indicesNext, + forceUnroll, true); + } + if (stageGlobalReadNew) + rock::YieldOp::create(b, loc); } - if (stageGlobalReadNew) - rock::YieldOp::create(b, loc); + } + + // For GlobalReadOnly there's nothing further to do. + if (globalReadOnly) { + b.eraseOp(op); + return success(); } if (loadType == GemmLoadTileType::BypassLDS) { @@ -377,9 +406,13 @@ class LoweringBlockwiseLoadTileOp final rock::YieldOp::create(b, loc); } } else { - if (!directToLDS) { + if (!directToLDS && !ldsReadOnly) { + // Use distinct stage name for split-phase V write to avoid + // conflicting with K/Q LDSWrite stages in nested loops. + StringRef ldsWriteStageName = + ldsWriteFromRegs ? "VLDSWrite" : "LDSWrite"; auto [stageLDSWrite, stageLDSWriteNew] = - createOrGetStage(b, loc, "LDSWrite", parentOp); + createOrGetStage(b, loc, ldsWriteStageName, parentOp); { PatternRewriter::InsertionGuard guard(b); b.setInsertionPointToStart(&stageLDSWrite.getRegion().back()); @@ -454,15 +487,18 @@ class LoweringBlockwiseLoadTileOp final } } - if (doubleBuffer) { - // Pipeline pass will remove this if the loop uses pipelining - LDSBarrierOp::create(b, loc); + if (doubleBuffer || ldsReadOnly) { + if (doubleBuffer) { + // Pipeline pass will remove this if the loop uses pipelining + LDSBarrierOp::create(b, loc); + } - // If we are running double-buffered pipelines, it makes sense to also - // parallelize the LDSRead/MMA stages. We do this here, by splitting the - // MMA loop in two separate stages + // Split the MMA loop into LDSRead and MMA stages so they can be + // parallelized. For LDSReadOnly, use a distinct stage name to avoid + // conflicting with GEMM0's LDSRead stages in the same parent scope. + StringRef ldsReadStageName = ldsReadOnly ? 
"VLDSRead" : "LDSRead"; auto [stageLDSRead, stageLDSReadNew] = - createOrGetStage(b, loc, "LDSRead", parentOp); + createOrGetStage(b, loc, ldsReadStageName, parentOp); { // Read from LDS into registers PatternRewriter::InsertionGuard guard(b); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index cf5a0ce439de..d3fd7d8f9c85 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -33,6 +33,7 @@ #include "mlir/Dialect/Rock/utility/math.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -2416,16 +2417,22 @@ struct GridwiseAttentionAccelRewritePattern Value gemm0ExpOutBufferToLDS = createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter); + // When prefetching V, the peeled GEMM1 iteration uses LDSReadOnly + // (LDS -> accel regs) followed by a DoubleBuffer GEMM (reads from regs). + // Both the LDSReadOnly and DoubleBuffer paths require the register + // buffer to have mRepeats slots. + bool willPrefetchV = op.getEnableSoftmax() && !directToLDS; auto [preAccelRegBufferVForLoad, preAccelRegBufferV] = createRegInterrimBufferForAccel( rewriter, loc, accelParamsGemm1.argTypeA, accelParamsGemm1.kBasePerThread, - doubleBuffering ? accelParamsGemm1.mRepeats : 1, directToLDS); - auto [preAccelRegBufferQxKForLoad, preAccelRegBufferQxK] = - createRegInterrimBufferForAccel( - rewriter, loc, accelParamsGemm1.argTypeB, - accelParamsGemm1.kBasePerThread, - doBypassLDSSecondGemm ? accelParamsGemm1.nRepeats : 1, false); + (doubleBuffering || willPrefetchV) ? accelParamsGemm1.mRepeats : 1, + directToLDS); + auto preAccelRegBufferQxKPair = createRegInterrimBufferForAccel( + rewriter, loc, accelParamsGemm1.argTypeB, + accelParamsGemm1.kBasePerThread, + doBypassLDSSecondGemm ? accelParamsGemm1.nRepeats : 1, false); + Value preAccelRegBufferQxK = preAccelRegBufferQxKPair.second; Value accRegBufferGemm1; Value gemm1OutBuffer; @@ -2787,6 +2794,50 @@ struct GridwiseAttentionAccelRewritePattern accelEmitterPtrGemm0->computeOutputConversion( rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll); + // V prefetch: Issue global reads for V tile 0 before softmax + // to overlap softmax computation with V's global memory latency. + // Uses a three-phase split-load approach: + // Phase 1 (before softmax): GlobalReadOnly — global -> flat regs + // Phase 2 (after softmax): LDSWriteFromRegs — flat regs -> LDS + // Phase 3 (peeled iter): LDSReadOnly — LDS -> accel regs + // The final GEMM1 compute uses DoubleBuffer mode (reads from regs). + // This avoids using Default GEMM reads from LDS, which generate + // incompatible transforms on some architectures (e.g. gfx1100). + Value ldsByteBufferV; + Value vPrefetchRegs; + layout::GridCoordinates gridCoordsGemm1; + bool prefetchFirstVTile = op.getEnableSoftmax() && !directToLDS; + + if (prefetchFirstVTile) { + // Set up grid coordinates for the first V tile. 
+ gridCoordsGemm1 = layout::makeGxNGridLayout( + rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, + splitKVConst); + gridCoordsGemm1.m_block = zero; + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + + // Allocate a flat register buffer shared between the GlobalReadOnly + // and LDSWriteFromRegs phases. + int64_t vCopyPerThread = (gemm1KPerBlock * gemm1MPerBlock) / blockSize; + vPrefetchRegs = gpuAlloc(rewriter, loc, vCopyPerThread, elemTypeV, + gpu::AddressSpace::Private); + + // Phase 1: Issue global reads for V tile 0 into register buffer. + // Only the GlobalRead stage is emitted; LDS write is deferred. + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); + + // Insert a scheduling barrier to prevent the LLVM backend scheduler + // from sinking the V global loads past the softmax computation. + amdgpu::SchedBarrierOp::create(rewriter, loc, + amdgpu::sched_barrier_opt_enum::none); + } + int64_t prePadG0M = gemm0M; if (op.getPrePadG0M().has_value()) { prePadG0M = op.getPrePadG0M().value().getSExtValue(); @@ -2973,6 +3024,19 @@ struct GridwiseAttentionAccelRewritePattern updateRowSum(rewriter, loc, gemm0SumThreadwiseView, gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, expMaxDiffRowBuffer); + + // V prefetch phase 2: Write prefetched V data from regs to LDS. + // The global reads issued before softmax should have completed + // during softmax computation, so this write is latency-free. + if (prefetchFirstVTile) { + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); + LDSBarrierOp::create(rewriter, loc); + } } // Emit blockwise GEMM 1. @@ -3023,36 +3087,21 @@ struct GridwiseAttentionAccelRewritePattern } } - Value ldsByteBufferV = createLDSByteBuffer( - rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); - Value endG1MLoop = - rewriter.createOrFold(loc, gemm1MBlocks); - - auto gridCoordsGemm1 = layout::makeGxNGridLayout( - rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets, - splitKVConst); - scf::ForOp g1MLoopOp = - createMainLoop(rewriter, loc, endG1MLoop, loadType); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); - Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); - - gridCoordsGemm1.m_block = g1MLoopIndVar; - - loadAndStoreGemmInputTile( - rewriter, loc, inV, - /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, - preAccelRegBufferVForLoad, loadType, "m", blockSize, elemTypeV, - elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, - matrixParamsKxQ); - - // Conservative barrier: Ensure all LDS writes complete - // before MMA stage reads from LDS. RockPipelinePass will remove this - // and add optimized barriers when pipelining. - LDSBarrierOp::create(rewriter, loc); + // V load + GEMM1 loop. For the non-prefetch path, allocate the + // V LDS buffer and grid coords here (prefetch already did this). 
+ if (!prefetchFirstVTile) { + ldsByteBufferV = createLDSByteBuffer( + rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV); + gridCoordsGemm1 = layout::makeGxNGridLayout( + rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, + numChiplets, splitKVConst); + } - // Emit GEMM 1. + // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile. + auto emitGemm1Compute = [&](Value g1MBlockIdx, + GemmLoadTileType vLoadType, + Value vRegBuf) -> LogicalResult { + // Emit GEMM 1 MMA. auto computeStage = StageOp::create(rewriter, loc, "MMA"); { PatternRewriter::InsertionGuard guard(rewriter); @@ -3064,8 +3113,8 @@ struct GridwiseAttentionAccelRewritePattern zeroAccBuffer(rewriter, loc, matrixC); } else { if (gemm1MBlocks > 1) { - matrixC = createSliceOfFirstDim(rewriter, loc, matrixC, - g1MLoopIndVar); + matrixC = + createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx); } } @@ -3115,17 +3164,18 @@ struct GridwiseAttentionAccelRewritePattern auto loadTypeKxD = doBypassLDSSecondGemm ? GemmLoadTileType::BypassLDS : GemmLoadTileType::Default; - blockwiseGemmAccel( - rewriter, loc, loadType, loadTypeKxD, preAccelRegBufferV, - preAccelRegBufferQxK, matrixC, matrixParamsV, matrixParamsKxQ, - ldsTileBufferV, gemm1LDSBufferB, - /*scaleA=*/nullptr, /*scaleB=*/nullptr, - /*bufferScaleA=*/nullptr, /*bufferScaleB=*/nullptr, - featuresAttr, op.getBlockSizeAttr(), gemm1TuningParams); + blockwiseGemmAccel(rewriter, loc, vLoadType, loadTypeKxD, vRegBuf, + preAccelRegBufferQxK, matrixC, matrixParamsV, + matrixParamsKxQ, ldsTileBufferV, gemm1LDSBufferB, + /*scaleA=*/nullptr, /*scaleB=*/nullptr, + /*bufferScaleA=*/nullptr, + /*bufferScaleB=*/nullptr, featuresAttr, + op.getBlockSizeAttr(), gemm1TuningParams); rock::YieldOp::create(rewriter, loc); } + // Emit GEMM 1 PostProcess. auto postProcessStage = StageOp::create(rewriter, loc, "PostProcess"); { PatternRewriter::InsertionGuard guard(rewriter); @@ -3138,9 +3188,9 @@ struct GridwiseAttentionAccelRewritePattern Value matrixC = accRegBufferGemm1; if (!op.getEnableSoftmax() && gemm1MBlocks > 1) { gemm1OutBufferPerG1MBlock = createSliceOfFirstDim( - rewriter, loc, gemm1OutBuffer, g1MLoopIndVar); + rewriter, loc, gemm1OutBuffer, g1MBlockIdx); matrixC = - createSliceOfFirstDim(rewriter, loc, matrixC, g1MLoopIndVar); + createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx); } accelEmitterPtrGemm1->computeOutputConversion( @@ -3149,7 +3199,7 @@ struct GridwiseAttentionAccelRewritePattern Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer; if (gemm1MBlocks > 1) { attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim( - rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar); + rewriter, loc, attentionOutAccBuffer, g1MBlockIdx); } FailureOr maybeInvertedGemm1threadSubTileMaps = invertTransforms(rewriter, loc, @@ -3176,10 +3226,116 @@ struct GridwiseAttentionAccelRewritePattern rock::YieldOp::create(rewriter, loc); } - // Conservative barrier: Ensure all LDS reads complete before the next - // iteration writes to LDS. RockPipelinePass will remove this and add - // optimized barriers when pipelining. - LDSBarrierOp::create(rewriter, loc); + return success(); + }; // end emitGemm1Compute lambda + + if (prefetchFirstVTile) { + // Prefetch path: V tile 0 is already in LDS from the split-phase + // GlobalReadOnly + LDSWriteFromRegs above. Read from LDS into + // accel-shaped registers using LDSReadOnly, then compute GEMM1 + // with DoubleBuffer (which reads from registers, not LDS). 
+ GemmLoadTileType vLoadType = GemmLoadTileType::DoubleBuffer; + gridCoordsGemm1.m_block = zero; + + // Phase 3: LDS -> accel regs via LDSReadOnly. + // Uses "VLDSRead" stage name to avoid conflicting with GEMM0's + // "LDSRead" stages in the same parent scope. + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, GemmLoadTileType::LDSReadOnly, "m", + blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams, + featuresAttr, matrixParamsV, matrixParamsKxQ); + + // Barrier to synchronize the B-side (softmax output) LDS write + // from storeGemmInputTile above with the GEMM1 compute's Default + // LDS read of gemm1LDSByteBufferB. Without this, threads may + // read B-side data that other threads haven't finished writing. + if (!doBypassLDSSecondGemm) + LDSBarrierOp::create(rewriter, loc); + + if (failed(emitGemm1Compute(zero, vLoadType, preAccelRegBufferV))) + return failure(); + + // Remaining iterations (g1m = 1..gemm1MBlocks-1). + if (gemm1MBlocks > 1) { + LDSBarrierOp::create(rewriter, loc); + + Value startG1M = rewriter.createOrFold(loc, 1); + Value endG1MLoop = + rewriter.createOrFold(loc, gemm1MBlocks); + Value oneVal = + rewriter.createOrFold(loc, 1); + scf::ForOp g1MLoopOp = + scf::ForOp::create(rewriter, loc, startG1M, endG1MLoop, oneVal); + // Only pipeline when >1 iteration remains; pipelining a + // single iteration causes barrier mismatches. + if (gemm1MBlocks > 2) { + // vLoadType is always DoubleBuffer in the prefetch path. + bool g1DoubleBuffering = + vLoadType == GemmLoadTileType::DoubleBuffer || + vLoadType == GemmLoadTileType::DirectToLDSDoubleBuffer; + int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2; + g1MLoopOp->setAttr(PipelineAttr::getMnemonic(), + rock::PipelineAttr::get(rewriter.getContext(), + g1InitiationInterval)); + } + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); + Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); + + gridCoordsGemm1.m_block = g1MLoopIndVar; + + // Normal V tile load (global -> LDS -> regs via DoubleBuffer) + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, vLoadType, "m", blockSize, + elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr, + matrixParamsV, matrixParamsKxQ); + + // Conservative barrier before MMA + LDSBarrierOp::create(rewriter, loc); + + if (failed(emitGemm1Compute(g1MLoopIndVar, vLoadType, + preAccelRegBufferV))) + return failure(); + + // Conservative barrier before next iteration's LDS writes + LDSBarrierOp::create(rewriter, loc); + } + } + } else { + // Non-prefetch path (softmax disabled). 
+ Value endG1MLoop = + rewriter.createOrFold(loc, gemm1MBlocks); + scf::ForOp g1MLoopOp = + createMainLoop(rewriter, loc, endG1MLoop, loadType); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); + Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); + + gridCoordsGemm1.m_block = g1MLoopIndVar; + + loadAndStoreGemmInputTile( + rewriter, loc, inV, + /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV, + preAccelRegBufferVForLoad, loadType, "m", blockSize, elemTypeV, + elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV, + matrixParamsKxQ); + + // Conservative barrier before MMA + LDSBarrierOp::create(rewriter, loc); + + if (failed(emitGemm1Compute(g1MLoopIndVar, loadType, + preAccelRegBufferV))) + return failure(); + + // Conservative barrier before next iteration's LDS writes + LDSBarrierOp::create(rewriter, loc); + } } } } @@ -3722,10 +3878,10 @@ void RockGridwiseGemmToBlockwisePass::runOnOperation() { ConversionTarget target(*ctx); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect< + arith::ArithDialect, rock::RockDialect, memref::MemRefDialect, + affine::AffineDialect, vector::VectorDialect, linalg::LinalgDialect, + scf::SCFDialect, math::MathDialect, amdgpu::AMDGPUDialect>(); target.addLegalOp(); RewritePatternSet patterns(ctx); diff --git a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp index 17e5e2534056..254eea17f42e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/Rock/utility/loweringUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/PassManager.h" @@ -165,17 +166,32 @@ struct PushBarrierDownRewritePattern if (!nextOp->getNextNode()) return failure(); + // Don't push past another barrier, RemoveBackToBack handles that. + // Without this check, two adjacent barriers would swap endlessly. + if (isa(nextOp)) + return failure(); + // We assume that operations that have a body may modify LDS if (nextOp->getNumRegions() > 0 && !dyn_cast(nextOp)) return failure(); bool moveDown = true; - // Make sure that the "nextOp" doesn't modify LDS + // Check if the operation accesses LDS. + // We can move past LDS store-only operations because independent + // writes don't need ordering between them, the next barrier will + // ensure all writes complete before any subsequent reads. + // We must stop at LDS reads. + // We recognize store ops both before SugarToLoops (InBoundsStoreOp) + // and after (memref::StoreOp, vector::TransferWriteOp, vector::StoreOp). for (Value operand : nextOp->getOperands()) { auto maybeAlloc = rock::findGpuAlloc(operand); if (succeeded(maybeAlloc) && - getAddressSpace(*maybeAlloc) == AddressSpace::Workgroup) - moveDown = false; + getAddressSpace(*maybeAlloc) == AddressSpace::Workgroup) { + // This operation touches LDS. Check if it's a write-only op. 
+ if (!isa(nextOp)) + moveDown = false; + } } if (moveDown) { @@ -815,12 +831,17 @@ void RockPipeline::runOnOperation() { if (failed( applyPatternsGreedily(func, std::move(patternsRemoveStages)))) return signalPassFailure(); - - RewritePatternSet patternsBackToBack(&getContext()); - patternsBackToBack.add(ctx); - if (failed(applyPatternsGreedily(func, std::move(patternsBackToBack)))) - return signalPassFailure(); } } } + + // Always run back-to-back barrier removal, even when there are no loops + // to pipeline. This handles barriers that become adjacent after other + // passes (e.g., after SugarToLoops unrolls TransformingForOps). + { + RewritePatternSet patternsBackToBack(&getContext()); + patternsBackToBack.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patternsBackToBack)))) + return signalPassFailure(); + } } diff --git a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp index 5d17b2dc18b8..702c8ca5f72c 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp @@ -818,6 +818,17 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value validity = loadLoop.getValidity(/*domain=*/0); Value destIndex = loadLoop.getLowerCoords(/*domain=*/1)[extraIdxCount]; + // Compute the full set of coordinates needed to index into `dest`. + // Domain 1 lower coords have (extraIdxCount + 1) elements, but `dest` + // may have fewer dimensions (dstRank). The last dstRank elements of the + // domain-1 coords correspond to the dest buffer dimensions. + int64_t dstRank = dstBufferType.getRank(); + Block::BlockArgListType allDestCoords = + loadLoop.getLowerCoords(/*domain=*/1); + size_t dropCount = allDestCoords.size() - dstRank; + SmallVector destCoords(allDestCoords.begin() + dropCount, + allDestCoords.end()); + for (Value dynamicValidity : adaptor.getDynamicValidities()) { Value validityHere = vector::ExtractOp::create( b, loc, b.getI1Type(), dynamicValidity, destIndex, @@ -845,7 +856,7 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value loaded = GlobalLoadOp::create(b, loc, loadType, buffer, validity, loadLoop.getLowerCoords(/*domain=*/0), needs64BitIdx); - InBoundsStoreOp::create(b, loc, loaded, dest, destIndex); + InBoundsStoreOp::create(b, loc, loaded, dest, destCoords); } else if (isGlobalToLDS) { int64_t loadTypeByteWidth = getByteWidth(loadType); if (loadTypeByteWidth != 16 && loadTypeByteWidth != 4) @@ -924,11 +935,14 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( } if (!isDstVectorBuffer && !isSrcVectorBuffer) { - InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destIndex); + InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destCoords); } else if (!isDstVectorBuffer && isSrcVectorBuffer) { - destIndex = arith::MulIOp::create( - b, loc, destIndex, ConstantIndexOp::create(b, loc, vectorSrcLen)); - InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, destIndex); + SmallVector scaledDestCoords(destCoords); + scaledDestCoords.back() = arith::MulIOp::create( + b, loc, scaledDestCoords.back(), + ConstantIndexOp::create(b, loc, vectorSrcLen)); + InBoundsStoreOp::create(b, loc, ifb.getResult(0), dest, + scaledDestCoords); } else { // Destination is a vector buffer Value idx = loadLoop.getLowerCoords(/*domain=*/1)[extraIdxCount]; @@ -944,7 +958,10 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( idx = b.createOrFold(loc, idx, srcVecLenVal); } if 
(vectorSrcLen == vectorDstLen) { - memref::StoreOp::create(b, loc, ifb.getResult(0), dest, idx); + SmallVector vecDestCoords(destCoords); + vecDestCoords.back() = idx; + memref::StoreOp::create(b, loc, ifb.getResult(0), dest, + vecDestCoords); } else { // When the vector types differ, we need to find the gcd // to make it work for the both source and dest. @@ -973,15 +990,16 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite( Value storeVecStart = b.createOrFold( loc, elementOffset, b.createOrFold(loc, vectorDstLen)); + SmallVector vecDestCoords(destCoords); + vecDestCoords.back() = storeVecStart; Value storeVec = memref::LoadOp::create(b, loc, dstVectorType, dest, - ValueRange{storeVecStart}); + vecDestCoords); Value storeSliceStart = b.createOrFold( loc, elementOffset, b.createOrFold(loc, vectorDstLen)); Value newStoreVec = InsertSliceOp::create( b, loc, dstVectorType, value, storeVec, storeSliceStart); - memref::StoreOp::create(b, loc, newStoreVec, dest, - ValueRange{storeVecStart}); + memref::StoreOp::create(b, loc, newStoreVec, dest, vecDestCoords); } } } diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index 119dff173b7a..1ce0a2e6437c 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -88,6 +88,13 @@ // CHECK: rock.transforming_for // CHECK: %[[tmp:.+]] = memref.load %[[gemm0AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp]] -> %[[gemm0AccBufScalar:.+]][ + + // V prefetch: issue global reads before softmax + // CHECK: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: rock.stage + // CHECK: rock.threadwise_read_into + // CHECK: {name = "VGlobalRead"} + // CHECK: linalg.generic {{.*}} ins(%[[gemm0AccBufScalar]] {{.*}} outs(%[[gemm0AccBufScalar]] // CHECK: %[[gemm0Scaled:.+]] = arith.mulf %in, %[[ln2Recip]] : f32 // CHECK: linalg.yield %[[gemm0Scaled]] @@ -129,6 +136,12 @@ // CHECK-DAG: %[[tilesumadd:.+]] = arith.addf %[[rowsummul]], %[[tilesum]] // CHECK-DAG: %[[tilesumadd]] -> %[[sumRowBuf]] + // V prefetch: write V data to LDS + // CHECK: rock.stage + // CHECK: rock.threadwise_write_all + // CHECK: {name = "VLDSWrite"} + // CHECK: rock.lds_barrier + // Viewing first gemm output as K x D // CHECK-DAG: %[[gemm0NormExpTr0:.+]] = rock.transform %[[gemm0NormExp]] // CHECK-DAG: %[[gemm0NormExpTr1:.+]] = rock.transform %[[gemm0NormExpTr0]] @@ -160,69 +173,34 @@ // CHECK-DAG: %[[viewG1AStoreTr6:.+]] = rock.transform %[[viewG1AStoreTr5]] // CHECK-DAG: %[[viewG1AStoreTr7:.+]] = rock.transform %[[viewG1AStoreTr6]] - // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) - - // CHECK-DAG: %[[ldsG0BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> - - // Repack G1B tile regs for better LDS store vectorization - // CHECK-DAG: %[[G1Bregs:.+]] = rock.alloc() : memref<16xf32, #gpu.address_space> - // CHECK-DAG: %[[G1BregsKpack:.+]] = rock.alloc() : memref<16xf32, #gpu.address_space> - - // Gemm1 - // CHECK: scf.for %[[g1MIter:.+]] + // CHECK: rock.stage + // CHECK: memref.view %[[ldsG0BStore]] + // CHECK: rock.threadwise_read_into + // CHECK: {name = "VLDSRead"} + // CHECK: rock.lds_barrier // CHECK: rock.stage - // Load G1B tile from global to regs - // CHECK-DAG: %[[VTr0:.+]] = rock.transform %[[V]] by - // CHECK-DAG: %[[VTr1:.+]] = rock.transform %[[VTr0]] by - // CHECK-DAG: 
rock.threadwise_read_into {{.*}}(%[[VTr1]]) {{.*}} -> %[[G1Bregs]] : - // CHECK: {name = "GlobalRead"} - - // CHECK: rock.stage - // CHECK-DAG: %[[G1BregsTr0:.+]] = rock.transform %[[G1Bregs]] by - // CHECK-DAG: %[[G1BregsTr1:.+]] = rock.transform %[[G1BregsTr0]] by - // CHECK-DAG: %[[G1BregsKpackTr0:.+]] = rock.transform %[[G1BregsKpack]] by - // CHECK-DAG: %[[G1BregsKpackTr1:.+]] = rock.transform %[[G1BregsKpackTr0:.+]] by - // CHECK-DAG: rock.threadwise_copy %[[G1BregsTr1]] -> %[[G1BregsKpackTr1]] - // Store to LDS G1B tile buffer - // CHECK-DAG: %[[viewG1BStore:.+]] = memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK-DAG: %[[viewG1BStoreTr0:.+]] = rock.transform %[[viewG1BStore]] - // CHECK-DAG: %[[viewG1BStoreTr1:.+]] = rock.transform %[[viewG1BStoreTr0]] - // CHECK-DAG: %[[viewG1BStoreTr2:.+]] = rock.transform %[[viewG1BStoreTr1]] - // CHECK-DAG: %[[viewG1BStoreTr3:.+]] = rock.transform %[[viewG1BStoreTr2]] - // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1BregsKpack]] -> [](%[[viewG1BStoreTr3]]) - // CHECK: {name = "LDSWrite"} - - // Emit blockwise gemm1 - - // CHECK: rock.stage // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) - // CHECK-DAG: %[[view2G1AStore:.+]] = memref.view %[[ldsG1AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK-DAG: %[[view2G1BStore:.+]] = memref.view %[[ldsG0BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += %[[preAccelRegB:.+]] from %[[view2G1BStore]] * %[[preAccelRegA:.+]] from %[[view2G1AStore]] + // CHECK: memref.view %[[ldsG1AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> + // CHECK: rock.blockwise_gemm_accel // CHECK: {name = "MMA"} - - // CHECK: rock.stage + // CHECK: rock.stage // CHECK: rock.transforming_for - // CHECK: %[[tmp1:.+]] = memref.load %[[gemm1AccBuf]][ - // CHECK: rock.in_bounds_store %[[tmp1]] -> %[[gemm1AccBufScalar:.+]][ - - // CHECK: %[[sliceAttnOutBuf:.+]] = memref.subview %[[attnOutBuf]] + // CHECK: memref.load %[[gemm1AccBuf]][ + // CHECK: rock.in_bounds_store + // CHECK: memref.subview %[[attnOutBuf]] // Reduction corrections // CHECK: rock.transforming_for - // CHECK-DAG: %[[maxdiffexp:.+]] = rock.in_bounds_load %[[maxdiffexpbuf]] - // CHECK-DAG: %[[attnOutVal:.+]] = rock.in_bounds_load %[[sliceAttnOutBuf]] - // CHECK-DAG: %[[gemm1Val:.+]] = rock.in_bounds_load %[[gemm1AccBufScalar]] - - // CHECK-DAG: %[[attnOutBufMul:.+]] = arith.mulf %[[attnOutVal]], %[[maxdiffexp]] - // CHECK-DAG: %[[newattnOutVal:.+]] = arith.addf %[[attnOutBufMul]], %[[gemm1Val]] - // CHECK-DAG: rock.in_bounds_store %[[newattnOutVal]] -> %[[sliceAttnOutBuf]] - // CHECK : } + // CHECK: arith.mulf + // CHECK: arith.addf // CHECK: {name = "PostProcess"} - // CHECK : } -// CHECK : } -// CHECK : %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] -// CHECK : rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}}(%[[O]]) + // CHECK: rock.lds_barrier + // CHECK: affine.for + // CHECK: memref.subview %[[attnOutBuf]] + // CHECK: rock.transforming_for + // CHECK: arith.divf + // CHECK: %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] + // CHECK: rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}} memref<1x64x384xf32> func.func @gridwise_attn_simple(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64x384xf32>, %arg2: 
memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>) attributes {block_size = 64 : i32, grid_size = 24 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-"} { %0 = rock.transform %arg0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 64, 384] -> [1, 384, 64]> : memref<1x384x64xf32> to memref<1x64x384xf32> @@ -255,7 +233,7 @@ func.func @gridwise_attn_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64 // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -301,7 +279,7 @@ func.func @gridwise_attn_causal_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memr // CHECK-NEXT: %[[minQPlusOne:.+]] = arith.addi %[[minQEffective]], %[[c1]] : index // CHECK-NEXT: %[[firstCausalMaskIter:.+]] = arith.divui %[[minQPlusOne]], %[[c32]] : index // CHECK: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -390,7 +368,7 @@ func.func @gridwise_attn_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref< // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -437,7 +415,7 @@ func.func @gridwise_attn_softmaxtype(%arg0: memref<1x384x64xf16>, %arg1: memref< // CHECK: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.blockwise_broadcast_reduce max {{.*}} memref<16xf32, #gpu.address_space> using memref<64xf32, #gpu.address_space> into memref<16xf32, #gpu.address_space> @@ -477,7 +455,7 @@ func.func @gridwise_attn_softmaxtype_with_scaling(%arg0: memref<1x384x64xf16>, % // CHECK: %[[num:.+]] = arith.addi 
%[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[c0]] to %[[numIter]] step %[[c1]] { // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store {{.*}} -> {{.*}} : vector<16xf16> -> memref<16xf16, #gpu.address_space>, index // CHECK: linalg.generic @@ -547,7 +525,7 @@ func.func @gridwise_attn_splitkv_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[endIter]], %[[c1]] : index // CHECK-NEXT: %[[someWorkToDo:.+]] = arith.cmpi ugt, %[[endIter]], %[[startIter]] : index // CHECK-NEXT: scf.if %[[someWorkToDo]] - // CHECK-NEXT: scf.for %[[iterIndex:.+]] = %[[startIter]] to %[[endIter]] step %[[c1]] { + // CHECK: scf.for %[[iterIndex:.+]] = %[[startIter]] to %[[endIter]] step %[[c1]] { // CHECK: %[[comparison:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index // CHECK-NEXT: scf.if %[[comparison]] { // CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] @@ -593,6 +571,8 @@ func.func @multiple_linalg_generics_in_presoftmax_ops(%arg0: memref<59136xf16>, // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store // CHECK-SAME: %[[GEMM0_BUFFER_FLAT]] + // V prefetch: issue global reads before softmax + // CHECK: {name = "VGlobalRead"} // CHECK: %[[ARG2_BUFFER:.*]] = rock.alloc() : memref<16xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into // CHECK-SAME: (%arg2) @@ -665,6 +645,8 @@ func.func @multiple_linalg_generics_in_presoftmax_ops_with_transforms_inbetween( // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store // CHECK-SAME: %[[GEMM0_BUFFER_FLAT]] + // V prefetch: issue global reads before softmax + // CHECK: {name = "VGlobalRead"} // CHECK: %[[ARG2_BUFFER:.*]] = rock.alloc() : memref<16xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into // CHECK-SAME: (%arg2) diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir index dbc1b50fe2fd..7cd09b8b1934 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir @@ -31,6 +31,12 @@ func.func @gridwise_attn_causal_scale_gqa(%arg0: memref<8192xf16>, %arg1: memref // CHECK: rock.transforming_for // CHECK: rock.in_bounds_store %{{.*}} -> %[[gemmOut:.+]][{{.*}}] + // V prefetch: issue global reads before softmax + // CHECK: rock.alloc() : memref<32xf16, #gpu.address_space> + // CHECK: rock.blockwise_load_tile + // CHECK-SAME: GlobalReadOnly + // CHECK: amdgpu.sched_barrier + // fusion // CHECK: %[[loadInto:.+]] = rock.alloc() : memref<32xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into diff --git a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir index 4f678f3118d8..5d265199bb68 100644 --- a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir +++ b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir @@ -56,9 +56,8 @@ // CHECK: rock.transforming_for {{.*}} (%[[LD_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]), {{.*}}, (%[[LDS_ST_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], 
#[[TMAP11]], #[[TMAP13]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]) {{.*}} bounds [1, 1, 20] strides [1, 1, 4] { // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load {{.*}}[%[[LD_COORD]]] // CHECK: %[[TO_REDUCE_ACC:.*]] = rock.in_bounds_load %[[TO_REDUCE_ACC_MEMREF]][%[[ZERO]]] - // CHECK: %[[MAX_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]] : vector<4xf32> into f32 - // CHECK: %[[ACC_NEW:.*]] = arith.maxnumf %[[TO_REDUCE_ACC]], %[[MAX_REDUCE]] - // CHECK: rock.in_bounds_store %[[ACC_NEW]] -> %arg2[%[[LDS_ST_COORD]]] + // CHECK: %[[MAX_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]], %[[TO_REDUCE_ACC]] : vector<4xf32> into f32 + // CHECK: rock.in_bounds_store %[[MAX_REDUCE]] -> %[[TO_REDUCE_ACC_MEMREF]][%[[ZERO]]] // CHECK: rock.lds_barrier // CHECK: rock.threadwise_read_into {{.*}}(%arg2) {{.*}} -> %arg1 @@ -83,12 +82,14 @@ func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: %[[RTID:.*]] = arith.divsi %[[TID]], %c3 // CHECK: %[[NRTID:.*]] = arith.remsi %[[TID]], %c3 - // Threadwise partial reduction into LDS uses rDimPerRThread=5 + // Threadwise partial reduction uses rDimPerRThread=5 // CHECK: rock.transforming_for // CHECK-SAME: bounds [1, 1, 5] - // CHECK: %[[PLUS_ONE:.*]] = arith.addi %[[RTID]], %c1 - // CHECK: %[[BCHECK:.*]] = arith.cmpi slt, %[[PLUS_ONE]], %c2 - // CHECK: scf.if %[[BCHECK]] + // Write reduced value back to LDS + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: rock.lds_barrier + // Branchless reduction across rThreads reads from LDS // CHECK: rock.lds_barrier // CHECK: rock.threadwise_read_into rock.blockwise_broadcast_reduce sum [#inputView][#inputView_tid][#inputView_iter]%input_reg into %output_reg using %ws_lds {axis = 1 : index, blockSize = 10 : i32, nrDimPerThread = 3 : index} : memref<3xf32, #gpu.address_space> using memref<30xf32, #gpu.address_space> into memref<3xf32, #gpu.address_space> @@ -123,7 +124,7 @@ func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: func @rock_blockwise_reducesum_nr_threads_lt_blocksize -// CHECK-DAG: %[[ZEROFP:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[ZEROFP:.*]] = arith.constant -0.000000e+00 : f32 // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TID0:.*]] = rock.workitem_id : index @@ -135,50 +136,47 @@ func.func @rock_blockwise_reducesum_rthreads_fix(%input_reg : memref<3xf32, #gpu // CHECK: rock.transforming_for {{.*}} (%[[LDS_LD_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) {{.*}} bounds [1, 1, 4] strides [1, 1, 4] { // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD]]] // CHECK: %[[TO_REDUCE_ACC:.*]] = rock.in_bounds_load {{.*}}[%c0] - // CHECK: %[[SUM_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]] : vector<4xf32> into f32 - // CHECK: %[[ACC_NEW:.*]] = arith.addf %[[TO_REDUCE_ACC]], %[[SUM_REDUCE]] - // CHECK: rock.in_bounds_store %[[ACC_NEW]] -> {{.*}}[%c0] {{.*}} #gpu.address_space> + // CHECK: %[[SUM_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]], %[[TO_REDUCE_ACC]] : vector<4xf32> into f32 + // CHECK: rock.in_bounds_store %[[SUM_REDUCE]] -> {{.*}}[%c0] {{.*}} #gpu.address_space> +// Write partial result to LDS // CHECK: rock.transforming_for {{.*}}[#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0) {{.*}} bounds [1, 1, 1] strides [1, 1, 1] { // CHECK: rock.in_bounds_load {{.*}} : memref<1xf32, 
  // CHECK: rock.in_bounds_store {{.*}} : f32 -> memref<80xf32, #gpu.address_space<workgroup>>, index
// CHECK: rock.lds_barrier
-// Partial threadwise reductions done now...
-
-// CHECK: %[[PLUS_FOUR_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c4
-// CHECK: %[[PLUS_FOUR_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_FOUR_OFFSET]], %c5
-// CHECK: scf.if %[[PLUS_FOUR_BCHECK]] {
-  // CHECK: rock.transforming_for
-  // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0)
-  // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_FOUR_OFFSET]], %c0)
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]]
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]]
-  // CHECK: arith.addf
-  // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]]
-// CHECK: rock.lds_barrier
-
-// CHECK: %[[PLUS_TWO_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c2
-// CHECK: %[[PLUS_TWO_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_TWO_OFFSET]], %c4
-// CHECK: scf.if %[[PLUS_TWO_BCHECK]] {
-  // CHECK: rock.transforming_for
-  // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0)
-  // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_TWO_OFFSET]], %c0)
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]]
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]]
-  // CHECK: arith.addf
-  // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]]
-// CHECK: rock.lds_barrier
-
-// CHECK: %[[PLUS_ONE_OFFSET:.*]] = arith.addi %[[PRT_THREAD_IDX]], %c1
-// CHECK: %[[PLUS_ONE_BCHECK:.*]] = arith.cmpi slt, %[[PLUS_ONE_OFFSET]], %c2
-// CHECK: scf.if %[[PLUS_ONE_BCHECK]] {
-  // CHECK: rock.transforming_for
-  // CHECK-SAME: (%[[LDS_LD_COORD1A:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PRT_THREAD_IDX]], %c0)
-  // CHECK-SAME: (%[[LDS_LD_COORD1B:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], %[[PLUS_ONE_OFFSET]], %c0)
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1A]]]
-  // CHECK: rock.in_bounds_load {{.*}}[%[[LDS_LD_COORD1B]]]
-  // CHECK: arith.addf
-  // CHECK: rock.in_bounds_store {{.*}}[%[[LDS_LD_COORD1A]]]
+// Branchless reduction: alloc private accumulator, each thread reads all
+// rTid partial values from LDS and reduces locally (no tree reduction).
+// CHECK: %[[BRLESS_ACC:.+]] = rock.alloc() : memref<1xf32, #gpu.address_space<private>>
+// Iteration 0: load rThread=0 partial from LDS into accumulator (no add)
+// CHECK: rock.transforming_for {{.*}} [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], {{.*}}) {{.*}} bounds [1, 1, 1]
+ // CHECK: %[[LDS_INIT:.+]] = rock.in_bounds_load %arg2[{{.*}}] : memref<80xf32, #gpu.address_space<workgroup>>, index -> f32
+ // CHECK: rock.in_bounds_store %[[LDS_INIT]] -> %[[BRLESS_ACC]][{{.*}}] : f32 -> memref<1xf32, #gpu.address_space<private>>, index
+// Iteration 1: load rThread=1 partial from LDS and accumulate
+// CHECK: rock.transforming_for {{.*}} [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[PRT_GROUP_IDX]], {{.*}}) {{.*}} bounds [1, 1, 1]
+ // CHECK: %[[LDS_V1:.+]] = rock.in_bounds_load %arg2[{{.*}}] : memref<80xf32, #gpu.address_space<workgroup>>, index -> f32
+ // CHECK: %[[ACC_V1:.+]] = rock.in_bounds_load %[[BRLESS_ACC]][{{.*}}] : memref<1xf32, #gpu.address_space<private>>, index -> f32
+ // CHECK: %[[SUM_1:.+]] = arith.addf %[[ACC_V1]], %[[LDS_V1]] : f32
+ // CHECK: rock.in_bounds_store %[[SUM_1]] -> %[[BRLESS_ACC]][{{.*}}] : f32 -> memref<1xf32, #gpu.address_space<private>>, index
+// Iterations 2-4: same pattern (load from LDS, load acc, addf, store acc)
+// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1]
+ // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space<workgroup>>
+ // CHECK: rock.in_bounds_load %[[BRLESS_ACC]]
+ // CHECK: arith.addf
+ // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]]
+// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1]
+ // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space<workgroup>>
+ // CHECK: rock.in_bounds_load %[[BRLESS_ACC]]
+ // CHECK: arith.addf
+ // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]]
+// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1]
+ // CHECK: rock.in_bounds_load %arg2{{.*}} : memref<80xf32, #gpu.address_space<workgroup>>
+ // CHECK: rock.in_bounds_load %[[BRLESS_ACC]]
+ // CHECK: arith.addf
+ // CHECK: rock.in_bounds_store {{.*}} -> %[[BRLESS_ACC]]
+// Write fully reduced value back to LDS
+// CHECK: %[[FINAL_RED:.+]] = rock.in_bounds_load %[[BRLESS_ACC]][{{.*}}] : memref<1xf32, #gpu.address_space<private>>, index -> f32
+// CHECK: rock.transforming_for {{.*}} bounds [1, 1, 1]
+ // CHECK: rock.in_bounds_store %[[FINAL_RED]] -> %arg2[{{.*}}] : f32 -> memref<80xf32, #gpu.address_space<workgroup>>, index
 // CHECK: rock.lds_barrier
 // All reductions are done and stored for each point in joint non-reduction space.
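
Beyond the branchless rThread loop, the hunks above fold the old reduce-then-combine pair (a vector.reduction followed by arith.maxnumf or arith.addf on the running value) into the optional accumulator operand of vector.reduction, and they seed the sum accumulator with -0.0 rather than +0.0. The latter matters because -0.0 is the identity of IEEE-754 addition: x + (-0.0) stays x for every x, whereas (-0.0) + (+0.0) rounds to +0.0 and would silently flip the sign of an all-negative-zero partial sum. A minimal sketch of the two shapes, with made-up inputs rather than values from the test:

  // Illustrative only: %v and %acc are placeholders, not names from the test.
  func.func @fused_acc_reduction(%v: vector<4xf32>, %acc: f32) -> (f32, f32) {
    // Old shape: reduce the vector, then combine with the running value.
    %partial = vector.reduction <add>, %v : vector<4xf32> into f32
    %two_ops = arith.addf %acc, %partial : f32
    // New shape: the running value rides along as the fused accumulator.
    %one_op = vector.reduction <add>, %v, %acc : vector<4xf32> into f32
    return %two_ops, %one_op : f32, f32
  }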
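
The new CHECK lines describe the branchless cross-rThread step: each thread owns a one-element private accumulator, seeds it from the first partial in LDS, then walks the remaining partials of its group and folds them in with arith.addf, with a single lds_barrier after the write-back instead of one barrier per tree level and no divergent scf.if guards. A self-contained sketch of that shape in plain memref/scf form (the buffer size, the trip count of 5, and the contiguous group layout are illustrative; the actual lowering goes through rock.transforming_for with rock.in_bounds_load/in_bounds_store on transformed views):

  func.func @branchless_rthread_reduce(
      %lds: memref<80xf32, #gpu.address_space<workgroup>>,
      %groupBase: index) -> f32 {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c5 = arith.constant 5 : index
    // Private accumulator; iteration 0 seeds it directly from LDS, so no
    // identity element is needed here.
    %acc = memref.alloca() : memref<1xf32, #gpu.address_space<private>>
    %seed = memref.load %lds[%groupBase] : memref<80xf32, #gpu.address_space<workgroup>>
    memref.store %seed, %acc[%c0] : memref<1xf32, #gpu.address_space<private>>
    // Remaining partials: load, add, store back; no bounds checks, no
    // per-step barrier.
    scf.for %r = %c1 to %c5 step %c1 {
      %off = arith.addi %groupBase, %r : index
      %partial = memref.load %lds[%off] : memref<80xf32, #gpu.address_space<workgroup>>
      %running = memref.load %acc[%c0] : memref<1xf32, #gpu.address_space<private>>
      %next = arith.addf %running, %partial : f32
      memref.store %next, %acc[%c0] : memref<1xf32, #gpu.address_space<private>>
    }
    %result = memref.load %acc[%c0] : memref<1xf32, #gpu.address_space<private>>
    return %result : f32
  }

Reading every partial on every thread costs a few redundant LDS loads, but it removes the divergent guards and the per-level barriers of the old tree reduction.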
diff --git a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir index 582bdc24dcf2..b118115d5a95 100644 --- a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir +++ b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir @@ -242,15 +242,11 @@ func.func @rock_pipeline_no_stages_ii_1(%input : memref<16xi8, #gpu.address_spac // CHECK: %[[rawRegA:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space> // CHECK: %[[rawRegB:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space> - // CHECK: %[[lds0View:.*]] = memref.view {{.*}} - // CHECK: %[[rawRegAView:.*]] = memref.view {{.*}} - // CHECK: %[[rawRegBView:.*]] = memref.view {{.*}} - // CHECK: scf.for // CHECK-SAME: %[[c0]] to %[[c16]] // CHECK-NOT: name = "__fwd_barrier__" - // CHECK: rock.extract_multibuffer(%[[lds0View]]) - // CHECK: rock.extract_multibuffer(%[[lds0View]]) + // CHECK: rock.extract_multibuffer(%[[rawRegA]]) + // CHECK: rock.extract_multibuffer(%[[lds0]]) scf.for %arg3 = %c0 to %c16 step %c1 { %a = memref.load %input[%arg3] : memref<16xi8, #gpu.address_space> memref.store %a, %regA[%arg3] : memref<16xi8, #gpu.address_space> diff --git a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir index cc5b9879ccee..6eff8177627c 100644 --- a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir @@ -42,6 +42,11 @@ // CHECK: rock.transforming_for // CHECK: %[[tmp:.+]] = memref.load %[[gemm0AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp]] -> %[[gemm0AccBufScalar:.+]][ + + // V prefetch: issue global reads before softmax + // CHECK: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: linalg.generic {{.*}} ins(%[[gemm0AccBufScalar]] {{.*}} outs(%[[gemm0AccBufScalar]] // CHECK: %[[gemm0Scaled:.+]] = arith.mulf %in, %[[ln2Recip]] : f32 // CHECK: linalg.yield %[[gemm0Scaled]] @@ -83,6 +88,10 @@ // CHECK-DAG: %[[tilesumadd:.+]] = arith.addf %[[rowsummul]], %[[tilesum]] // CHECK-DAG: %[[tilesumadd]] -> %[[sumRowBuf]] + // V prefetch: write V data to LDS from regs + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: rock.lds_barrier + // Viewing first gemm output as K x D // CHECK-DAG: %[[gemm0NormExpTr0:.+]] = rock.transform %[[gemm0NormExp]] // CHECK-DAG: %[[gemm0NormExpTr1:.+]] = rock.transform %[[gemm0NormExpTr0]] @@ -90,8 +99,9 @@ // CHECK-DAG: %[[gemm0NormExpTr3:.+]] = rock.transform %[[gemm0NormExpTr2]] // CHECK-DAG: %[[gemm0NormExpTr4:.+]] = rock.transform %[[gemm0NormExpTr3]] // CHECK-DAG: %[[gemm0NormExpTr5:.+]] = rock.transform %[[gemm0NormExpTr4]] - - // CHECK-DAG: %[[ldsG1BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> + + // G1 A buffer alloc + // CHECK: %[[ldsG1BStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> // Viewing another set of register with kPack packing // CHECK: %[[G1AregsKpackTr0:.+]] = rock.transform %[[G1AregsKpack:.+]] by @@ -116,39 +126,40 @@ // Store to LDS G1A tile buffer // CHECK-DAG: rock.threadwise_write_all {{.*}} %[[G1AregsKpack]] -> [](%[[viewG1AStoreTr7]]) - - // CHECK-DAG: %[[ldsG0AStore:.+]] = rock.alloc() : memref<4096xi8, #gpu.address_space> - // Gemm1 - // CHECK: scf.for %[[g1MIter:.+]] - // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]] -> %[[preAccelRegV:[0-9]+]] {{.*}}#rock + // V LDS Read + // 
CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock - // Emit blockwise gemm1 - // rock.stage + // Gemm1 (unrolled) + // Iteration 0: V data already in registers from LDS read // CHECK-DAG: rock.fill(%[[gemm1AccBuf:.+]], %[[zeroVecF32]]) - // CHECK: %[[view2G1AStore:.+]] = memref.view %[[ldsG0AStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> // CHECK: %[[view2G1BStore:.+]] = memref.view %[[ldsG1BStore]][{{.*}}][] : memref<4096xi8, #gpu.address_space> to memref<1024xf32, #gpu.address_space> - // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += %[[preAccelRegV]] from %[[view2G1AStore]] * %[[preAccelRegA:[0-9]+]] from %[[view2G1BStore]] + // CHECK: rock.blockwise_gemm_accel %[[gemm1AccBuf]] += {{.*}} * {{.*}} from %[[view2G1BStore]] // CHECK: {name = "MMA"} - // rock.stage // CHECK: rock.transforming_for // CHECK: %[[tmp1:.+]] = memref.load %[[gemm1AccBuf]][ // CHECK: rock.in_bounds_store %[[tmp1]] -> %[[gemm1AccBufScalar:.+]][ - // CHECK: %[[sliceAttnOutBuf:.+]] = memref.subview %[[attnOutBuf]] + // CHECK: memref.subview %[[attnOutBuf]] // Reduction corrections // CHECK: rock.transforming_for - // CHECK-DAG: %[[maxdiffexp:.+]] = rock.in_bounds_load %[[maxdiffexpbuf]] - // CHECK-DAG: %[[attnOutVal:.+]] = rock.in_bounds_load %[[sliceAttnOutBuf]] - // CHECK-DAG: %[[gemm1Val:.+]] = rock.in_bounds_load %[[gemm1AccBufScalar]] - - // CHECK-DAG: %[[attnOutBufMul:.+]] = arith.mulf %[[attnOutVal]], %[[maxdiffexp]] - // CHECK-DAG: %[[newattnOutVal:.+]] = arith.addf %[[attnOutBufMul]], %[[gemm1Val]] - // CHECK-DAG: rock.in_bounds_store %[[newattnOutVal]] -> %[[sliceAttnOutBuf]] - // CHECK : } + // CHECK: arith.mulf + // CHECK: arith.addf + // CHECK: {name = "PostProcess"} + // CHECK: rock.lds_barrier + + // Iteration 1: Load next V tile + // CHECK: rock.blockwise_load_tile %[[V]]{{.*}} LDS -> %[[ldsG0AStore]]{{.*}}#rock + // CHECK: rock.lds_barrier + // CHECK-DAG: rock.fill({{.*}}, %[[zeroVecF32]]) + // CHECK: memref.view %[[ldsG1BStore]] + // CHECK: rock.blockwise_gemm_accel + // CHECK: {name = "MMA"} + // CHECK: rock.transforming_for + // CHECK: memref.subview %[[attnOutBuf]] + // CHECK: rock.transforming_for // CHECK: {name = "PostProcess"} - // CHECK : {pipeline = #rock.pipeline<2>} // CHECK : } // CHECK : %[[flatAttnOutBuf:.+]] = memref.collapse_shape %[[attnOutBuf]] // CHECK : rock.threadwise_write_all {{.*}} %[[flatAttnOutBuf]] -> {{.*}}(%[[O]]) diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir index 70c3857d14e2..9d130db811c6 100644 --- a/mlir/test/rocmlir-driver/pipelines.mlir +++ b/mlir/test/rocmlir-driver/pipelines.mlir @@ -41,6 +41,7 @@ // GPU-NEXT:rock-threadwise-gemm-lowering, // GPU-NEXT:rock-analyze-memory-use, // GPU-NEXT:rock-sugar-to-loops, +// GPU-NEXT:rock-pipeline{rock-pipeline-remove-stages=true}, // GPU-NEXT:rock-clean-math, // GPU-NEXT:math-extend-to-supported-types{extra-types={f16} target-type=f32}, // GPU-NEXT:rock-buffer-load-merge,
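
Taken together, the attention checks above now describe the V tile's load split into three stages around softmax: the global read is issued before the softmax post-processing, the register-to-LDS write plus an lds_barrier happens just before gemm1, and the LDS read happens inside the now-unrolled gemm1, with the next tile's load issued at the top of the following iteration. The driver pipeline listing correspondingly includes a rock-pipeline{rock-pipeline-remove-stages=true} entry between rock-sugar-to-loops and rock-clean-math. A generic sketch of the three-stage split using only standard dialects (all names, sizes, and the copy loops are invented; the real kernels use rock.blockwise_load_tile and rock.lds_barrier):

  func.func @split_tile_load(
      %global: memref<64xf32>,
      %lds: memref<64xf32, #gpu.address_space<workgroup>>,
      %stage1Regs: memref<4xf32, #gpu.address_space<private>>,
      %gemmRegs: memref<4xf32, #gpu.address_space<private>>,
      %tid: index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %base = arith.muli %tid, %c4 : index
    // Stage 1 (issued early, before the softmax work): global -> registers.
    scf.for %i = %c0 to %c4 step %c1 {
      %off = arith.addi %base, %i : index
      %v = memref.load %global[%off] : memref<64xf32>
      memref.store %v, %stage1Regs[%i] : memref<4xf32, #gpu.address_space<private>>
    }
    // ... softmax-style work on other buffers would run here ...
    // Stage 2: registers -> LDS, then a single barrier (rock.lds_barrier in
    // the real kernel).
    scf.for %i = %c0 to %c4 step %c1 {
      %off = arith.addi %base, %i : index
      %v = memref.load %stage1Regs[%i] : memref<4xf32, #gpu.address_space<private>>
      memref.store %v, %lds[%off] : memref<64xf32, #gpu.address_space<workgroup>>
    }
    gpu.barrier
    // Stage 3: LDS -> registers for the consumer (gemm1 in the real kernel).
    scf.for %i = %c0 to %c4 step %c1 {
      %off = arith.addi %base, %i : index
      %v = memref.load %lds[%off] : memref<64xf32, #gpu.address_space<workgroup>>
      memref.store %v, %gemmRegs[%i] : memref<4xf32, #gpu.address_space<private>>
    }
    return
  }

Splitting the load this way lets the global read's latency overlap the softmax arithmetic instead of stalling right before the second GEMM.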