
Commit 79dd5db

Clang-format
1 parent 98c5c85 commit 79dd5db

7 files changed

Lines changed: 74 additions & 82 deletions

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 2 additions & 4 deletions
@@ -464,10 +464,8 @@ def Rock_GemmLoadTileDirectToLDSDefault
 : I32EnumAttrCase<"DirectToLDSDefault", 3>;
 def Rock_GemmLoadTileDirectToLDSDoubleBuffer
 : I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>;
-def Rock_GemmLoadTileGlobalReadOnly
-: I32EnumAttrCase<"GlobalReadOnly", 5>;
-def Rock_GemmLoadTileLDSWriteFromRegs
-: I32EnumAttrCase<"LDSWriteFromRegs", 6>;
+def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>;
+def Rock_GemmLoadTileLDSWriteFromRegs : I32EnumAttrCase<"LDSWriteFromRegs", 6>;

 def Rock_GemmLoadTileType
 : Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types",

mlir/include/mlir/Dialect/Rock/Passes.td

Lines changed: 4 additions & 1 deletion
@@ -107,7 +107,10 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> {

 def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> {
 let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy";
-let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect", "amdgpu::AMDGPUDialect"];
+let dependentDialects = ["rock::RockDialect", "affine::AffineDialect",
+"gpu::GPUDialect", "vector::VectorDialect",
+"memref::MemRefDialect", "linalg::LinalgDialect",
+"scf::SCFDialect", "amdgpu::AMDGPUDialect"];
 }

 def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> {

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 8 additions & 9 deletions
@@ -1418,7 +1418,7 @@ struct BlockwiseReduceRewritePattern
 // Branchless reduction: each thread reads all rTidDim partial
 // values from LDS and reduces locally in registers. This avoids
 // creating conditional branches (scf.if) that split softmax into
-// multiple basic blocks.
+// multiple basic blocks.
 // Trade-off: every thread does rTidCount LDS reads (instead of
 // log2(rTidCount) conditional reads in the tree reduction). For
 // typical attention configs where rTidCount is small (e.g., 4),
@@ -1429,8 +1429,8 @@ struct BlockwiseReduceRewritePattern
 int64_t rTidCount = threadViewShape[rTidDim];

 // Accumulator for the full reduction.
-auto accRegType = MemRefType::get(
-{1}, elemType, AffineMap{}, privateMemoryAddressSpace);
+auto accRegType = MemRefType::get({1}, elemType, AffineMap{},
+privateMemoryAddressSpace);
 Value accReg = GpuAllocOp::create(rewriter, loc, accRegType);
 FillOp::create(rewriter, loc, accReg, initVal);

@@ -1463,8 +1463,8 @@ struct BlockwiseReduceRewritePattern
 InBoundsStoreOp::create(rewriter, loc, ldVal, accReg,
 zeroConstantOp);
 } else {
-Value accVal = InBoundsLoadOp::create(
-rewriter, loc, elemType, accReg, zeroConstantOp);
+Value accVal = InBoundsLoadOp::create(rewriter, loc, elemType,
+accReg, zeroConstantOp);
 Value reduced = createReducingOp(op, ldVal, accVal, rewriter);
 InBoundsStoreOp::create(rewriter, loc, reduced, accReg,
 zeroConstantOp);
@@ -1476,8 +1476,8 @@ struct BlockwiseReduceRewritePattern
 // All threads with the same nrtid compute the same value,
 // so concurrent writes to the same location are safe.
 {
-Value reducedVal = InBoundsLoadOp::create(
-rewriter, loc, elemType, accReg, zeroConstantOp);
+Value reducedVal = InBoundsLoadOp::create(rewriter, loc, elemType,
+accReg, zeroConstantOp);
 SmallVector<Value, 3> writeInits{nrtid, zeroConstantOp,
 zeroConstantOp};
 SmallVector<int64_t> writeBounds{1, 1, 1};
@@ -1486,8 +1486,7 @@ struct BlockwiseReduceRewritePattern
 TransformingForOp writeLoop = TransformingForOp::create(
 rewriter, loc, ArrayRef<ValueRange>{writeInits},
 ArrayRef<Attribute>{threadToLDSViewTrs},
-ArrayRef<int64_t>(writeBounds),
-ArrayRef<int64_t>(writeStrides),
+ArrayRef<int64_t>(writeBounds), ArrayRef<int64_t>(writeStrides),
 /*forceUnroll=*/true, /*useIndexDiffs=*/true);
 {
 PatternRewriter::InsertionGuard guard(rewriter);
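
Note on the code being reformatted here: the comments in this hunk describe a branchless reduction in which every thread reads all rTidCount partial values from LDS and folds them into a private accumulator, doing rTidCount unconditional reads instead of the log2(rTidCount) conditional reads of a tree reduction (for rTidCount = 4, that is 4 reads instead of 2, but with no scf.if control flow). Below is a minimal standalone C++ sketch of that pattern, assuming a max reduction and a plain array standing in for LDS; the names are illustrative only, not the Rock lowering itself.

#include <algorithm>
#include <cstdint>

// Branchless reduction sketch: fold all rTidCount partial values into a
// per-thread accumulator; no conditional branches are needed.
float branchlessReduce(const float *ldsPartials, int64_t rTidCount,
                       float initVal) {
  float acc = initVal; // register accumulator, cf. accReg in the hunk above
  for (int64_t i = 0; i < rTidCount; ++i)
    acc = std::max(acc, ldsPartials[i]); // rTidCount unconditional LDS reads
  return acc;
}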

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 11 additions & 13 deletions
@@ -244,10 +244,8 @@ class LoweringBlockwiseLoadTileOp final
 else
 b.setInsertionPoint(op);

-bool globalReadOnly =
-loadType == GemmLoadTileType::GlobalReadOnly;
-bool ldsWriteFromRegs =
-loadType == GemmLoadTileType::LDSWriteFromRegs;
+bool globalReadOnly = loadType == GemmLoadTileType::GlobalReadOnly;
+bool ldsWriteFromRegs = loadType == GemmLoadTileType::LDSWriteFromRegs;

 Value loadBuffer, storeBuffer;
 if (globalReadOnly || ldsWriteFromRegs) {
@@ -258,8 +256,8 @@ class LoweringBlockwiseLoadTileOp final
 "destRegisters must be set for split-phase load types");
 loadBuffer = destRegisters;
 if (ldsWriteFromRegs) {
-storeBuffer = gpuAlloc(b, loc, copyPerThread, elementType,
-AddressSpace::Private);
+storeBuffer =
+gpuAlloc(b, loc, copyPerThread, elementType, AddressSpace::Private);
 }
 } else if (loadType == GemmLoadTileType::BypassLDS) {
 auto privateMemoryAddressSpace = b.getAttr<gpu::AddressSpaceAttr>(
@@ -320,13 +318,13 @@ class LoweringBlockwiseLoadTileOp final
 Value wrappedSource =
 transform(b, source, maybeBufferViews->gridSubTile);

-ThreadwiseReadIntoOp::create(
-b, loc, vectorOfBoolShapedLike(loadBuffer), wrappedSource,
-loadBuffer,
-/*dynamicValidities=*/ValueRange{},
-/*extraViews=*/b.getArrayAttr({}),
-/*extraIndices=*/indices, forceUnroll, true,
-/*ldsTransposeConfig=*/nullptr);
+ThreadwiseReadIntoOp::create(b, loc, vectorOfBoolShapedLike(loadBuffer),
+wrappedSource, loadBuffer,
+/*dynamicValidities=*/ValueRange{},
+/*extraViews=*/b.getArrayAttr({}),
+/*extraIndices=*/indices, forceUnroll,
+true,
+/*ldsTransposeConfig=*/nullptr);

 if (!globalReadOnly && rock::isGlobalPrefetchSupported(arch)) {
 // add one to k_loop to prefetch next iteration
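
The split-phase load types referenced above pair up the two enum cases added in RockAttrDefs.td: GemmLoadTileType::GlobalReadOnly fills a per-thread register buffer from global memory without touching LDS, and GemmLoadTileType::LDSWriteFromRegs later writes that same register buffer into LDS. Below is a schematic C++ sketch of that two-phase flow, assuming plain vectors for global memory, the register buffer, and the LDS tile; all names are illustrative, not the rocMLIR API.

#include <cstddef>
#include <vector>

enum class LoadPhase { GlobalReadOnly, LDSWriteFromRegs };

// Phase 1 copies this thread's slice of global memory into registers;
// phase 2 writes the previously loaded registers into the shared LDS tile.
void loadTilePhase(LoadPhase phase, const std::vector<float> &global,
                   std::vector<float> &regs, std::vector<float> &lds,
                   std::size_t threadOffset) {
  if (phase == LoadPhase::GlobalReadOnly) {
    for (std::size_t i = 0; i < regs.size(); ++i)
      regs[i] = global[threadOffset + i];
  } else { // LDSWriteFromRegs
    for (std::size_t i = 0; i < regs.size(); ++i)
      lds[threadOffset + i] = regs[i];
  }
}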

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 46 additions & 53 deletions
@@ -33,8 +33,8 @@
 #include "mlir/Dialect/Rock/utility/math.h"
 #include "mlir/Dialect/Rock/utility/transformMapUtils.h"

-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -2823,23 +2823,21 @@ struct GridwiseAttentionAccelRewritePattern
 << "V prefetch Phase 2 hoist decision: "
 << (hoistVPhase2 ? "HOIST" : "DEFER")
 << " (hoistedTotal=" << hoistedTotal << ", max=" << maxLDS
-<< ", sumWS=" << sumWSBytes
-<< ", gemm0=" << gemm0PeakBytes
+<< ", sumWS=" << sumWSBytes << ", gemm0=" << gemm0PeakBytes
 << ", gemm1=" << gemm1PeakBytes << ")\n");
 }

 if (prefetchFirstVTile) {
 // Set up grid coordinates for the first V tile.
 gridCoordsGemm1 = layout::makeGxNGridLayout(
-rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch,
-numChiplets, splitKVConst);
+rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch, numChiplets,
+splitKVConst);
 gridCoordsGemm1.m_block = zero; // First V tile (block index 0)

 // Allocate a flat register buffer shared between the GlobalReadOnly
 // and LDSWriteFromRegs phases. Size must match what the lowering
 // computes: copyPerThread = (kPerBlock * dPerBlock) / blockSize.
-int64_t vCopyPerThread =
-(gemm1KPerBlock * gemm1MPerBlock) / blockSize;
+int64_t vCopyPerThread = (gemm1KPerBlock * gemm1MPerBlock) / blockSize;
 vPrefetchRegs = gpuAlloc(rewriter, loc, vCopyPerThread, elemTypeV,
 gpu::AddressSpace::Private);

@@ -2851,15 +2849,15 @@ struct GridwiseAttentionAccelRewritePattern
 rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);
 loadAndStoreGemmInputTile(
 rewriter, loc, inV,
-/*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS,
-vPrefetchRegs, GemmLoadTileType::GlobalReadOnly, "m", blockSize,
-elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr,
-matrixParamsV, matrixParamsKxQ);
+/*kIter=*/mLoopIV, tid, gridCoordsGemm1, dummyLDS, vPrefetchRegs,
+GemmLoadTileType::GlobalReadOnly, "m", blockSize, elemTypeV,
+elemTypeVLoad, gemm1TuningParams, featuresAttr, matrixParamsV,
+matrixParamsKxQ);

 // Insert a scheduling barrier to prevent the LLVM backend scheduler
 // from sinking the V global loads past the softmax computation.
-amdgpu::SchedBarrierOp::create(
-rewriter, loc, amdgpu::sched_barrier_opt_enum::none);
+amdgpu::SchedBarrierOp::create(rewriter, loc,
+amdgpu::sched_barrier_opt_enum::none);
 }

 int64_t prePadG0M = gemm0M;
@@ -3034,9 +3032,9 @@ struct GridwiseAttentionAccelRewritePattern
 loadAndStoreGemmInputTile(
 rewriter, loc, inV,
 /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV,
-vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m",
-blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
-featuresAttr, matrixParamsV, matrixParamsKxQ);
+vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize,
+elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr,
+matrixParamsV, matrixParamsKxQ);
 }

 // Softmax sum reduction
@@ -3074,9 +3072,9 @@ struct GridwiseAttentionAccelRewritePattern
 loadAndStoreGemmInputTile(
 rewriter, loc, inV,
 /*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV,
-vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m",
-blockSize, elemTypeV, elemTypeVLoad, gemm1TuningParams,
-featuresAttr, matrixParamsV, matrixParamsKxQ);
+vPrefetchRegs, GemmLoadTileType::LDSWriteFromRegs, "m", blockSize,
+elemTypeV, elemTypeVLoad, gemm1TuningParams, featuresAttr,
+matrixParamsV, matrixParamsKxQ);
 LDSBarrierOp::create(rewriter, loc);
 }
 }
@@ -3140,9 +3138,9 @@ struct GridwiseAttentionAccelRewritePattern
 }

 // Helper lambda: emit GEMM1 MMA + PostProcess for a single V tile.
-auto emitGemm1Compute =
-[&](Value g1MBlockIdx, GemmLoadTileType vLoadType,
-Value vRegBuf) -> LogicalResult {
+auto emitGemm1Compute = [&](Value g1MBlockIdx,
+GemmLoadTileType vLoadType,
+Value vRegBuf) -> LogicalResult {
 // Emit GEMM 1 MMA.
 auto computeStage = StageOp::create(rewriter, loc, "MMA");
 {
@@ -3155,8 +3153,8 @@ struct GridwiseAttentionAccelRewritePattern
 zeroAccBuffer(rewriter, loc, matrixC);
 } else {
 if (gemm1MBlocks > 1) {
-matrixC = createSliceOfFirstDim(rewriter, loc, matrixC,
-g1MBlockIdx);
+matrixC =
+createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx);
 }
 }

@@ -3206,20 +3204,19 @@ struct GridwiseAttentionAccelRewritePattern
 auto loadTypeKxD = doBypassLDSSecondGemm
 ? GemmLoadTileType::BypassLDS
 : GemmLoadTileType::Default;
-blockwiseGemmAccel(
-rewriter, loc, vLoadType, loadTypeKxD, vRegBuf,
-preAccelRegBufferQxK, matrixC, matrixParamsV, matrixParamsKxQ,
-ldsTileBufferV, gemm1LDSBufferB,
-/*scaleA=*/nullptr, /*scaleB=*/nullptr,
-/*bufferScaleA=*/nullptr, /*bufferScaleB=*/nullptr,
-featuresAttr, op.getBlockSizeAttr(), gemm1TuningParams);
+blockwiseGemmAccel(rewriter, loc, vLoadType, loadTypeKxD, vRegBuf,
+preAccelRegBufferQxK, matrixC, matrixParamsV,
+matrixParamsKxQ, ldsTileBufferV, gemm1LDSBufferB,
+/*scaleA=*/nullptr, /*scaleB=*/nullptr,
+/*bufferScaleA=*/nullptr,
+/*bufferScaleB=*/nullptr, featuresAttr,
+op.getBlockSizeAttr(), gemm1TuningParams);

 rock::YieldOp::create(rewriter, loc);
 }

 // Emit GEMM 1 PostProcess.
-auto postProcessStage =
-StageOp::create(rewriter, loc, "PostProcess");
+auto postProcessStage = StageOp::create(rewriter, loc, "PostProcess");
 {
 PatternRewriter::InsertionGuard guard(rewriter);
 rewriter.setInsertionPointToStart(
@@ -3232,8 +3229,8 @@ struct GridwiseAttentionAccelRewritePattern
 if (!op.getEnableSoftmax() && gemm1MBlocks > 1) {
 gemm1OutBufferPerG1MBlock = createSliceOfFirstDim(
 rewriter, loc, gemm1OutBuffer, g1MBlockIdx);
-matrixC = createSliceOfFirstDim(rewriter, loc, matrixC,
-g1MBlockIdx);
+matrixC =
+createSliceOfFirstDim(rewriter, loc, matrixC, g1MBlockIdx);
 }

 accelEmitterPtrGemm1->computeOutputConversion(
@@ -3280,11 +3277,10 @@ struct GridwiseAttentionAccelRewritePattern
 // Default load path expects rank-1, so allocate a separate buf.
 Value peeledVRegBuf = preAccelRegBufferV;
 if (doubleBuffering) {
-auto [peeledVForLoad, peeledVBuf] =
-createRegInterrimBufferForAccel(
-rewriter, loc, accelParamsGemm1.argTypeA,
-accelParamsGemm1.kBasePerThread,
-/*repeats=*/1, directToLDS);
+auto [peeledVForLoad, peeledVBuf] = createRegInterrimBufferForAccel(
+rewriter, loc, accelParamsGemm1.argTypeA,
+accelParamsGemm1.kBasePerThread,
+/*repeats=*/1, directToLDS);
 peeledVRegBuf = peeledVBuf;
 }
 if (!doBypassLDSSecondGemm)
@@ -3298,25 +3294,23 @@ struct GridwiseAttentionAccelRewritePattern
 if (gemm1MBlocks > 1) {
 LDSBarrierOp::create(rewriter, loc);

-Value startG1M =
-rewriter.createOrFold<ConstantIndexOp>(loc, 1);
+Value startG1M = rewriter.createOrFold<ConstantIndexOp>(loc, 1);
 Value endG1MLoop =
 rewriter.createOrFold<ConstantIndexOp>(loc, gemm1MBlocks);
 Value oneVal =
 rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
-scf::ForOp g1MLoopOp = scf::ForOp::create(
-rewriter, loc, startG1M, endG1MLoop, oneVal);
+scf::ForOp g1MLoopOp =
+scf::ForOp::create(rewriter, loc, startG1M, endG1MLoop, oneVal);
 // Only pipeline when >1 iteration remains; pipelining a
 // single iteration causes barrier mismatches.
 if (gemm1MBlocks > 2) {
 bool g1DoubleBuffering =
 loadType == GemmLoadTileType::DoubleBuffer ||
 loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
 int64_t g1InitiationInterval = g1DoubleBuffering ? 1 : 2;
-g1MLoopOp->setAttr(
-PipelineAttr::getMnemonic(),
-rock::PipelineAttr::get(rewriter.getContext(),
-g1InitiationInterval));
+g1MLoopOp->setAttr(PipelineAttr::getMnemonic(),
+rock::PipelineAttr::get(rewriter.getContext(),
+g1InitiationInterval));
 }
 {
 OpBuilder::InsertionGuard guard(rewriter);
@@ -3916,11 +3910,10 @@ void RockGridwiseGemmToBlockwisePass::runOnOperation() {
 ConversionTarget target(*ctx);
 target.addIllegalOp<rock::GridwiseGemmOp, rock::GridwiseGemmAccelOp,
 GridwiseAttentionAccelOp>();
-target.addLegalDialect<arith::ArithDialect, rock::RockDialect,
-memref::MemRefDialect, affine::AffineDialect,
-vector::VectorDialect, linalg::LinalgDialect,
-scf::SCFDialect, math::MathDialect,
-amdgpu::AMDGPUDialect>();
+target.addLegalDialect<
+arith::ArithDialect, rock::RockDialect, memref::MemRefDialect,
+affine::AffineDialect, vector::VectorDialect, linalg::LinalgDialect,
+scf::SCFDialect, math::MathDialect, amdgpu::AMDGPUDialect>();
 target.addLegalOp<gpu::PrintfOp>();

 RewritePatternSet patterns(ctx);

mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,12 +22,12 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Rock/IR/Rock.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Rock/Passes.h"
 #include "mlir/Dialect/Rock/Transforms/RockMultibuffer.h"
 #include "mlir/Dialect/Rock/utility/loweringUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Pass/PassManager.h"

mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp

Lines changed: 2 additions & 1 deletion
@@ -823,7 +823,8 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite(
 // may have fewer dimensions (dstRank). The last dstRank elements of the
 // domain-1 coords correspond to the dest buffer dimensions.
 int64_t dstRank = dstBufferType.getRank();
-Block::BlockArgListType allDestCoords = loadLoop.getLowerCoords(/*domain=*/1);
+Block::BlockArgListType allDestCoords =
+loadLoop.getLowerCoords(/*domain=*/1);
 size_t dropCount = allDestCoords.size() - dstRank;
 SmallVector<Value> destCoords(allDestCoords.begin() + dropCount,
 allDestCoords.end());
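
The surrounding lines implement the comment above them: only the last dstRank of the domain-1 lower coordinates index the destination buffer, so the first size() - dstRank coordinates are dropped. A tiny self-contained C++ sketch of that trimming, assuming plain ints in place of the loop's block arguments (illustrative only, not the Rock code):

#include <cassert>
#include <cstddef>
#include <vector>

// Keep only the trailing dstRank coordinates; they address the dest buffer.
std::vector<int> keepDestCoords(const std::vector<int> &allCoords,
                                std::size_t dstRank) {
  assert(allCoords.size() >= dstRank);
  std::size_t dropCount = allCoords.size() - dstRank;
  return std::vector<int>(allCoords.begin() + dropCount, allCoords.end());
}
// e.g. keepDestCoords({7, 8, 3, 4}, 2) returns {3, 4}.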
