3333#include " mlir/Dialect/Rock/utility/math.h"
3434#include " mlir/Dialect/Rock/utility/transformMapUtils.h"
3535
36- #include " mlir/Dialect/Affine/IR/AffineOps.h"
3736#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
37+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
3838#include " mlir/Dialect/Arith/IR/Arith.h"
3939#include " mlir/Dialect/Func/IR/FuncOps.h"
4040#include " mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -2823,23 +2823,21 @@ struct GridwiseAttentionAccelRewritePattern
28232823 << " V prefetch Phase 2 hoist decision: "
28242824 << (hoistVPhase2 ? " HOIST" : " DEFER" )
28252825 << " (hoistedTotal=" << hoistedTotal << " , max=" << maxLDS
2826- << " , sumWS=" << sumWSBytes
2827- << " , gemm0=" << gemm0PeakBytes
2826+ << " , sumWS=" << sumWSBytes << " , gemm0=" << gemm0PeakBytes
28282827 << " , gemm1=" << gemm1PeakBytes << " )\n " );
28292828 }
28302829
28312830 if (prefetchFirstVTile) {
28322831 // Set up grid coordinates for the first V tile.
28332832 gridCoordsGemm1 = layout::makeGxNGridLayout (
2834- rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch,
2835- numChiplets, splitKVConst);
2833+ rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets,
2834+ splitKVConst);
28362835 gridCoordsGemm1.m_block = zero; // First V tile (block index 0)
28372836
28382837 // Allocate a flat register buffer shared between the GlobalReadOnly
28392838 // and LDSWriteFromRegs phases. Size must match what the lowering
28402839 // computes: copyPerThread = (kPerBlock * dPerBlock) / blockSize.
2841- int64_t vCopyPerThread =
2842- (gemm1KPerBlock * gemm1MPerBlock) / blockSize;
2840+ int64_t vCopyPerThread = (gemm1KPerBlock * gemm1MPerBlock) / blockSize;
28432841 vPrefetchRegs = gpuAlloc (rewriter, loc, vCopyPerThread, elemTypeV,
28442842 gpu::AddressSpace::Private);
28452843
@@ -2851,15 +2849,15 @@ struct GridwiseAttentionAccelRewritePattern
28512849 rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
28522850 loadAndStoreGemmInputTile (
28532851 rewriter, loc, inV,
2854- /* kIter=*/ mLoopIV , tid, gridCoordsGemm1, dummyLDS,
2855- vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, " m" , blockSize,
2856- elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr,
2857- matrixParamsV, matrixParamsKxQ);
2852+ /* kIter=*/ mLoopIV , tid, gridCoordsGemm1, dummyLDS, vPrefetchRegs,
2853+ GemmLoadTileType::GlobalReadOnly, " m" , blockSize, elemTypeV ,
2854+ elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV ,
2855+ matrixParamsKxQ);
28582856
28592857 // Insert a scheduling barrier to prevent the LLVM backend scheduler
28602858 // from sinking the V global loads past the softmax computation.
2861- amdgpu::SchedBarrierOp::create (
2862- rewriter, loc, amdgpu::sched_barrier_opt_enum::none);
2859+ amdgpu::SchedBarrierOp::create (rewriter, loc,
2860+ amdgpu::sched_barrier_opt_enum::none);
28632861 }
28642862
28652863 int64_t prePadG0M = gemm0M;
@@ -3034,9 +3032,9 @@ struct GridwiseAttentionAccelRewritePattern
30343032 loadAndStoreGemmInputTile (
30353033 rewriter, loc, inV,
30363034 /* kIter=*/ mLoopIV , tid, gridCoordsGemm1, ldsByteBufferV,
3037- vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, " m" ,
3038- blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
3039- featuresAttr, matrixParamsV, matrixParamsKxQ);
3035+ vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, " m" , blockSize,
3036+ elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr ,
3037+ matrixParamsV, matrixParamsKxQ);
30403038 }
30413039
30423040 // Softmax sum reduction
@@ -3074,9 +3072,9 @@ struct GridwiseAttentionAccelRewritePattern
30743072 loadAndStoreGemmInputTile (
30753073 rewriter, loc, inV,
30763074 /* kIter=*/ mLoopIV , tid, gridCoordsGemm1, ldsByteBufferV,
3077- vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, " m" ,
3078- blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
3079- featuresAttr, matrixParamsV, matrixParamsKxQ);
3075+ vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, " m" , blockSize,
3076+ elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr ,
3077+ matrixParamsV, matrixParamsKxQ);
30803078 LDSBarrierOp::create (rewriter, loc);
30813079 }
30823080 }
@@ -3140,9 +3138,9 @@ struct GridwiseAttentionAccelRewritePattern
31403138 }
31413139
31423140 // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile.
3143- auto emitGemm1Compute =
3144- [&](Value g1MBlockIdx, GemmLoadTileType vLoadType,
3145- Value vRegBuf) -> LogicalResult {
3141+ auto emitGemm1Compute = [&](Value g1MBlockIdx,
3142+ GemmLoadTileType vLoadType,
3143+ Value vRegBuf) -> LogicalResult {
31463144 // Emit GEMM 1 MMA.
31473145 auto computeStage = StageOp::create (rewriter, loc, " MMA" );
31483146 {
@@ -3155,8 +3153,8 @@ struct GridwiseAttentionAccelRewritePattern
31553153 zeroAccBuffer (rewriter, loc, matrixC);
31563154 } else {
31573155 if (gemm1MBlocks > 1 ) {
3158- matrixC = createSliceOfFirstDim (rewriter, loc, matrixC,
3159- g1MBlockIdx);
3156+ matrixC =
3157+ createSliceOfFirstDim (rewriter, loc, matrixC, g1MBlockIdx);
31603158 }
31613159 }
31623160
@@ -3206,20 +3204,19 @@ struct GridwiseAttentionAccelRewritePattern
32063204 auto loadTypeKxD = doBypassLDSSecondGemm
32073205 ? GemmLoadTileType::BypassLDS
32083206 : GemmLoadTileType::Default;
3209- blockwiseGemmAccel (
3210- rewriter, loc, vLoadType, loadTypeKxD, vRegBuf ,
3211- preAccelRegBufferQxK, matrixC, matrixParamsV, matrixParamsKxQ ,
3212- ldsTileBufferV, gemm1LDSBufferB ,
3213- /* scaleA= */ nullptr , /* scaleB =*/ nullptr ,
3214- /* bufferScaleA= */ nullptr , /* bufferScaleB=*/ nullptr ,
3215- featuresAttr, op.getBlockSizeAttr (), gemm1TuningParams);
3207+ blockwiseGemmAccel (rewriter, loc, vLoadType, loadTypeKxD, vRegBuf,
3208+ preAccelRegBufferQxK, matrixC, matrixParamsV ,
3209+ matrixParamsKxQ, ldsTileBufferV, gemm1LDSBufferB ,
3210+ /* scaleA= */ nullptr , /* scaleB= */ nullptr ,
3211+ /* bufferScaleA =*/ nullptr ,
3212+ /* bufferScaleB=*/ nullptr , featuresAttr ,
3213+ op.getBlockSizeAttr (), gemm1TuningParams);
32163214
32173215 rock::YieldOp::create (rewriter, loc);
32183216 }
32193217
32203218 // Emit GEMM 1 PostProcess.
3221- auto postProcessStage =
3222- StageOp::create (rewriter, loc, " PostProcess" );
3219+ auto postProcessStage = StageOp::create (rewriter, loc, " PostProcess" );
32233220 {
32243221 PatternRewriter::InsertionGuard guard (rewriter);
32253222 rewriter.setInsertionPointToStart (
@@ -3232,8 +3229,8 @@ struct GridwiseAttentionAccelRewritePattern
32323229 if (!op.getEnableSoftmax () && gemm1MBlocks > 1 ) {
32333230 gemm1OutBufferPerG1MBlock = createSliceOfFirstDim (
32343231 rewriter, loc, gemm1OutBuffer, g1MBlockIdx);
3235- matrixC = createSliceOfFirstDim (rewriter, loc, matrixC,
3236- g1MBlockIdx);
3232+ matrixC =
3233+ createSliceOfFirstDim (rewriter, loc, matrixC, g1MBlockIdx);
32373234 }
32383235
32393236 accelEmitterPtrGemm1->computeOutputConversion (
@@ -3280,11 +3277,10 @@ struct GridwiseAttentionAccelRewritePattern
32803277 // Default load path expects rank-1, so allocate a separate buf.
32813278 Value peeledVRegBuf = preAccelRegBufferV;
32823279 if (doubleBuffering) {
3283- auto [peeledVForLoad, peeledVBuf] =
3284- createRegInterrimBufferForAccel (
3285- rewriter, loc, accelParamsGemm1.argTypeA ,
3286- accelParamsGemm1.kBasePerThread ,
3287- /* repeats=*/ 1 , directToLDS);
3280+ auto [peeledVForLoad, peeledVBuf] = createRegInterrimBufferForAccel (
3281+ rewriter, loc, accelParamsGemm1.argTypeA ,
3282+ accelParamsGemm1.kBasePerThread ,
3283+ /* repeats=*/ 1 , directToLDS);
32883284 peeledVRegBuf = peeledVBuf;
32893285 }
32903286 if (!doBypassLDSSecondGemm)
@@ -3298,25 +3294,23 @@ struct GridwiseAttentionAccelRewritePattern
32983294 if (gemm1MBlocks > 1 ) {
32993295 LDSBarrierOp::create (rewriter, loc);
33003296
3301- Value startG1M =
3302- rewriter.createOrFold <ConstantIndexOp>(loc, 1 );
3297+ Value startG1M = rewriter.createOrFold <ConstantIndexOp>(loc, 1 );
33033298 Value endG1MLoop =
33043299 rewriter.createOrFold <ConstantIndexOp>(loc, gemm1MBlocks);
33053300 Value oneVal =
33063301 rewriter.createOrFold <arith::ConstantIndexOp>(loc, 1 );
3307- scf::ForOp g1MLoopOp = scf::ForOp::create (
3308- rewriter, loc, startG1M, endG1MLoop, oneVal);
3302+ scf::ForOp g1MLoopOp =
3303+ scf::ForOp::create ( rewriter, loc, startG1M, endG1MLoop, oneVal);
33093304 // Only pipeline when >1 iteration remains; pipelining a
33103305 // single iteration causes barrier mismatches.
33113306 if (gemm1MBlocks > 2 ) {
33123307 bool g1DoubleBuffering =
33133308 loadType == GemmLoadTileType::DoubleBuffer ||
33143309 loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
33153310 int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2 ;
3316- g1MLoopOp->setAttr (
3317- PipelineAttr::getMnemonic (),
3318- rock::PipelineAttr::get (rewriter.getContext (),
3319- g1InitiationInterval));
3311+ g1MLoopOp->setAttr (PipelineAttr::getMnemonic (),
3312+ rock::PipelineAttr::get (rewriter.getContext (),
3313+ g1InitiationInterval));
33203314 }
33213315 {
33223316 OpBuilder::InsertionGuard guard (rewriter);
@@ -3916,11 +3910,10 @@ void RockGridwiseGemmToBlockwisePass::runOnOperation() {
39163910 ConversionTarget target (*ctx);
39173911 target.addIllegalOp <rock::GridwiseGemmOp, rock::GridwiseGemmAccelOp,
39183912 GridwiseAttentionAccelOp>();
3919- target.addLegalDialect <arith::ArithDialect, rock::RockDialect,
3920- memref::MemRefDialect, affine::AffineDialect,
3921- vector::VectorDialect, linalg::LinalgDialect,
3922- scf::SCFDialect, math::MathDialect,
3923- amdgpu::AMDGPUDialect>();
3913+ target.addLegalDialect <
3914+ arith::ArithDialect, rock::RockDialect, memref::MemRefDialect,
3915+ affine::AffineDialect, vector::VectorDialect, linalg::LinalgDialect,
3916+ scf::SCFDialect, math::MathDialect, amdgpu::AMDGPUDialect>();
39243917 target.addLegalOp <gpu::PrintfOp>();
39253918
39263919 RewritePatternSet patterns (ctx);
0 commit comments