8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
@@ -464,13 +464,19 @@ def Rock_GemmLoadTileDirectToLDSDefault
: I32EnumAttrCase<"DirectToLDSDefault", 3>;
def Rock_GemmLoadTileDirectToLDSDoubleBuffer
: I32EnumAttrCase<"DirectToLDSDoubleBuffer", 4>;
def Rock_GemmLoadTileGlobalReadOnly : I32EnumAttrCase<"GlobalReadOnly", 5>;
def Rock_GemmLoadTileLDSWriteFromRegs : I32EnumAttrCase<"LDSWriteFromRegs", 6>;
def Rock_GemmLoadTileLDSReadOnly : I32EnumAttrCase<"LDSReadOnly", 7>;

def Rock_GemmLoadTileType
: Rock_I32Enum<"GemmLoadTileType", "GEMM load tile types",
[Rock_GemmLoadTileBypassLDS, Rock_GemmLoadTileDefault,
Rock_GemmLoadTileDoubleBuffer,
Rock_GemmLoadTileDirectToLDSDefault,
Rock_GemmLoadTileDirectToLDSDoubleBuffer]> {
Rock_GemmLoadTileDirectToLDSDoubleBuffer,
Rock_GemmLoadTileGlobalReadOnly,
Rock_GemmLoadTileLDSWriteFromRegs,
Rock_GemmLoadTileLDSReadOnly]> {
let cppNamespace = "::mlir::rock";
let genSpecializedAttr = 0;
}
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
@@ -1621,6 +1621,9 @@ def Rock_BlockwiseLoadTileOp
- DoubleBuffer: Creates three stages, (1) load from memory, (2) write to LDS, (3) load to registers.
- DirectToLDSDefault: Same as Default, but a single stage loads from memory and writes to LDS.
- DirectToLDSDoubleBuffer: Same as DoubleBuffer, but a single stage loads from memory and writes to LDS.
- GlobalReadOnly: Loads from global memory into flat registers.
- LDSWriteFromRegs: Writes from flat registers into LDS.
- LDSReadOnly: Reads from LDS into registers.

`isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes.
`elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs.
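The three split-phase types decompose the combined Default/DoubleBuffer paths into ops that each touch exactly one source and one destination. A minimal sketch of the resource footprint this documentation implies, under assumed helper names (`TileIO` and `expectedIO` are illustrative, not dialect API):

// Illustrative only: which resources each split-phase load type touches,
// per the op documentation above. TileIO/expectedIO are hypothetical names.
struct TileIO {
  bool readsGlobal, readsRegs, readsLDS, writesRegs, writesLDS;
};

static TileIO expectedIO(GemmLoadTileType t) {
  switch (t) {
  case GemmLoadTileType::GlobalReadOnly:   // global -> flat registers
    return {true, false, false, true, false};
  case GemmLoadTileType::LDSWriteFromRegs: // flat registers -> LDS
    return {false, true, false, false, true};
  case GemmLoadTileType::LDSReadOnly:      // LDS -> registers
    return {false, false, true, true, false};
  default: // combined paths: conservative over-approximation only
    return {true, true, true, true, true};
  }
}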
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Rock/Passes.td
@@ -107,7 +107,10 @@ def RockRegularizePass : Pass<"rock-regularize", "::mlir::func::FuncOp"> {

def RockGridwiseGemmToBlockwisePass : Pass<"rock-gridwise-gemm-to-blockwise", "::mlir::func::FuncOp"> {
let summary = "expand gridwise gemm into blockwise copy, blockwise gemm, and threadwise copy";
let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
let dependentDialects = ["rock::RockDialect", "affine::AffineDialect",
"gpu::GPUDialect", "vector::VectorDialect",
"memref::MemRefDialect", "linalg::LinalgDialect",
"scf::SCFDialect", "amdgpu::AMDGPUDialect"];
}

def RockLinalgAlignPass : Pass<"rock-linalg-align", "::mlir::func::FuncOp"> {
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
@@ -2482,13 +2482,17 @@ void BlockwiseLoadTileOp::getEffects(
loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
bool singleBuffer = loadType == GemmLoadTileType::Default ||
loadType == GemmLoadTileType::DirectToLDSDefault;
bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly;

effects.emplace_back(read, &getSourceMutable());
// LDSReadOnly does not read from global source.
if (!ldsReadOnly)
effects.emplace_back(read, &getSourceMutable());
if (loadType != GemmLoadTileType::BypassLDS) {
assert(getDestLDS() != nullptr);
Comment on lines +2487 to +2491

Copilot AI Mar 25, 2026

BlockwiseLoadTileOp::getEffects doesn't account for the new split-phase load types. For GemmLoadTileType::GlobalReadOnly the op does not write to destLDS, and for GemmLoadTileType::LDSWriteFromRegs the op should not read from source at all (it should read from destRegisters and write to destLDS). As written, MemoryEffects will incorrectly report global/LDS accesses, which can mislead scheduling and optimization passes that rely on effects. Please add explicit cases for GlobalReadOnly/LDSWriteFromRegs (and ensure LDSReadOnly remains LDS-read + regs-write only).
effects.emplace_back(write, &getDestLDSMutable()[0]);
// DoubleBuffer means we write to LDS and then, load from it
if (doubleBuffer)
// LDSReadOnly only reads from LDS; it does not write to it.
if (!ldsReadOnly)
effects.emplace_back(write, &getDestLDSMutable()[0]);
if (doubleBuffer || ldsReadOnly)
effects.emplace_back(read, &getDestLDSMutable()[0]);
}
if (!singleBuffer) {
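Following the Copilot comment above, a hedged sketch of how the explicit cases could look. It reuses only the accessors visible in this hunk; any register-side effect would need an accessor (the comment's "destRegisters") whose real name is not shown here and is therefore left out:

// Sketch, not PR code: explicit per-type effects for the split phases.
// The combined Default/DoubleBuffer/DirectToLDS paths keep the existing
// logic; only the three new cases are shown.
switch (loadType) {
case GemmLoadTileType::GlobalReadOnly:
  // Reads global memory, writes flat registers; no LDS access.
  effects.emplace_back(read, &getSourceMutable());
  break;
case GemmLoadTileType::LDSWriteFromRegs:
  // Reads flat registers, writes LDS; no global read.
  effects.emplace_back(write, &getDestLDSMutable()[0]);
  break;
case GemmLoadTileType::LDSReadOnly:
  // Reads LDS, writes registers; no global read, no LDS write.
  effects.emplace_back(read, &getDestLDSMutable()[0]);
  break;
default:
  break;
}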
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
@@ -240,6 +240,9 @@ void rock::buildKernelPipeline(OpPassManager &pm,
funcPm.addPass(rock::createRockThreadwiseGemmLoweringPass());
funcPm.addPass(rock::createRockAnalyzeMemoryUsePass());
funcPm.addPass(rock::createRockSugarToLoopsPass());
// Re-run the pipeline pass to remove back-to-back LDS barriers
// that may appear after SugarToLoops unrolls TransformingForOps.
funcPm.addPass(rock::createRockPipelinePass());
Comment on lines +243 to +245

Member

I know we add LDSBarriers in gridwiseToBlockwise conservatively, expecting rock-pipeline to take care of them.
Is it possible to add barriers such that we don't need to run the rock-pipeline pass again?
Is it possible to enhance the logic for backToBackBarriers? Unrolling will create back-to-back barriers in a case like this, I think:

LDSBarrier (1)
scf.for {
  LDSBarrier (2)
  ...
  LDSBarrier (3)
}
LDSBarrier (4)

(1) This barrier may not be necessary if the loop body starts with a barrier and there is a barrier after exiting the loop.

(2) Possibly needed for loop-carried dependencies.

(3) Can be eliminated if there is a barrier at the exit of the loop.

(4) Exit barrier.
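
A minimal sketch of rules (1) and (3), assuming a plain walk over the function body; this is not the existing rock-pipeline logic, and erasing ops during the walk would need more care in a real pass:

// Sketch only: erase a barrier right before an scf.for whose body begins
// with a barrier (case 1), and the trailing in-body barrier when a barrier
// follows the loop (case 3). Assumes rock::LDSBarrierOp and scf::ForOp.
static void eraseRedundantLoopBarriers(func::FuncOp func) {
  func.walk([](scf::ForOp forOp) {
    Block *body = forOp.getBody();
    bool startsWithBarrier = isa<rock::LDSBarrierOp>(&body->front());
    Operation *beforeYield = body->getTerminator()->getPrevNode();
    Operation *prev = forOp->getPrevNode();
    Operation *next = forOp->getNextNode();
    // Case (1): the pre-loop barrier is covered by the body's leading
    // barrier and by the post-loop barrier (zero-trip-count case).
    if (prev && next && isa<rock::LDSBarrierOp>(prev) &&
        isa<rock::LDSBarrierOp>(next) && startsWithBarrier)
      prev->erase();
    // Case (3): the trailing in-body barrier is covered by the post-loop
    // barrier; the backedge path is covered by the leading barrier (2).
    if (next && beforeYield && isa<rock::LDSBarrierOp>(next) &&
        isa<rock::LDSBarrierOp>(beforeYield) && startsWithBarrier)
      beforeYield->erase();
  });
}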

funcPm.addPass(rock::createRockCleanMathPass());
math::MathExtendToSupportedTypesOptions extendToLLVMTypesOptions;
extendToLLVMTypesOptions.extraTypeStrs = {"f16"};
145 changes: 88 additions & 57 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
@@ -928,8 +928,13 @@ struct BlockwiseReduceRewritePattern
}
} else {
if (rMethod == ReduceMethod::Sum) {
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
Comment on lines +931 to +935

Member

If this PR takes too long to merge, move these changes into a separate PR, and also create tests to make sure it doesn't generate v_add_f32 v, 0, v.

Comment on lines +931 to +935

Copilot AI Mar 31, 2026

The new comment claims -0.0 is an additive identity "including NaN", but IEEE-754 defines (-0.0) + NaN = NaN (same for +0.0). Consider adjusting the wording to avoid stating the identity property holds for NaNs; the optimization rationale about LLVM folding fadd -0.0, x -> x can still stand without that claim.

Suggested change

// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
// Use -0.0 (negative zero) instead of +0.0. LLVM can fold
// `fadd -0.0, x → x`, eliminating the redundant
// `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd. (Note: IEEE 754 still propagates NaNs,
// i.e., x + NaN = NaN for any x.)
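
The sign of zero is the concrete reason the fold needs -0.0: with +0.0 as the initial value, `0.0 + x` is not equal to `x` when `x` is `-0.0`, so the compiler cannot elide the add without fast-math. A standalone C++ illustration (not PR code):

// Illustration only: +0.0f loses the sign of a -0.0f input; -0.0f does not.
#include <cassert>
#include <cmath>

int main() {
  float x = -0.0f;
  assert(std::signbit(-0.0f + x));  // -0.0 + -0.0 == -0.0: sign preserved
  assert(!std::signbit(0.0f + x));  // +0.0 + -0.0 == +0.0: identity broken
  return 0;
}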
return createConstantFloatOp(rewriter, op.getLoc(), elementType,
elementType, 0.0);
elementType, -0.0f);
} else {
// Op verifier guarantees this.
assert(rMethod == ReduceMethod::Max);
@@ -964,7 +969,7 @@ struct BlockwiseReduceRewritePattern
kind = vector::CombiningKind::MAXNUMF;
}
}
input = vector::ReductionOp::create(builder, loc, kind, input);
return vector::ReductionOp::create(builder, loc, kind, input, acc);
}

if (rMethod == ReduceMethod::Sum) {
@@ -1410,67 +1415,93 @@
}
}

// This RAII scope would do the following :
// LDS[rtid] = reduce(LDS[rtid], LDS[rtid + offset])
// where offset is a power of 2.
// Initial it starts with power = ceil(|rtid|, power of 2) / 2
// Then keep on reducing the power.
// Branchless reduction: each thread reads all rTidCount partial

Member

This also seems like an independent change compared to scheduling VTile.


// values from LDS and reduces locally in registers. This avoids
// creating conditional branches (scf.if) that split softmax into
// multiple basic blocks.
// Trade-off: every thread does rTidCount LDS reads (instead of
// log2(rTidCount) conditional reads in the tree reduction). For
// typical attention configs where rTidCount is small (e.g., 4),
// this is negligible overhead.
// TODO: We may have to use a heuristic to determine whether or not to
// use this depending on the size of rTidCount.
Comment on lines +1426 to +1427

Contributor

Agree this needs a heuristic. The branchless approach is O(N) LDS reads vs O(log N) in the tree, so for small rTidCount (2-4) it's clearly better, but for larger values (8, 16) the extra LDS reads may outweigh the branch elimination benefit.

Suggestion: benchmark both approaches for representative configs with rTidCount = 2, 4, 8, 16 on target architectures to find the empirical crossover point, then add a threshold and keep the old tree path as a fallback.
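
A sketch of the kind of threshold being suggested; the crossover constant below is a placeholder pending those benchmarks, not a measured value:

// Placeholder heuristic, not PR code: pick the branchless path only when
// the linear LDS-read count is small enough that avoiding scf.if wins.
static bool useBranchlessReduction(int64_t rTidCount) {
  // Branchless: O(rTidCount) LDS reads, no branches. Tree: O(log2) reads
  // plus conditional branches. Tune kCrossover per architecture.
  constexpr int64_t kCrossover = 4;
  return rTidCount <= kCrossover;
}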

{
int64_t ceilPowerOf2 =
llvm::PowerOf2Ceil(threadViewShape[rTidDim]) / 2;
int64_t maxActiveReductionThreads = threadViewShape[rTidDim];
for (int64_t offset = ceilPowerOf2; offset >= 1;
offset = offset >> 1) {
Value offsetVal =
arith::ConstantIndexOp::create(rewriter, loc, offset);
Value rtidPlusOffsetVal =
arith::AddIOp::create(rewriter, loc, rtid, offsetVal);
Value maxActiveReductionThreadsVal = arith::ConstantIndexOp::create(
rewriter, loc, maxActiveReductionThreads);
maxActiveReductionThreads =
llvm::PowerOf2Ceil(maxActiveReductionThreads) >> 1;
Value isValid = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::slt, rtidPlusOffsetVal,
maxActiveReductionThreadsVal);
scf::IfOp ifb = scf::IfOp::create(rewriter, loc, isValid,
/*withElseRegion=*/false);
int64_t rTidCount = threadViewShape[rTidDim];

// Accumulator for the full reduction.
auto accRegType = MemRefType::get({1}, elemType, AffineMap{},
privateMemoryAddressSpace);
Value accReg = GpuAllocOp::create(rewriter, loc, accRegType);
FillOp::create(rewriter, loc, accReg, initVal);
Copilot AI Mar 31, 2026

initVal is used to initialize accReg in the branchless reduction path, but it is declared inside the preceding if (threadViewShape[rIterDim] > 1) block. As written, this won't compile (and even conceptually, the branchless reduction should be able to run when rIterDim <= 1). Move the initVal definition outside the conditional (or recompute it in the branchless block), or remove the FillOp entirely since the i==0 iteration overwrites the accumulator.

Suggested change

FillOp::create(rewriter, loc, accReg, initVal);

// Read all rTidCount partial values from LDS and reduce.
// Every thread with the same nrtid computes the identical
// fully-reduced value.
for (int64_t i = 0; i < rTidCount; i++) {
Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
SmallVector<Value, 3> readInits{nrtid, iVal, zeroConstantOp};
SmallVector<int64_t> bounds{1, 1, 1};
SmallVector<int64_t> strides{1, 1, 1};

TransformingForOp readLoop = TransformingForOp::create(
rewriter, loc, ArrayRef<ValueRange>{readInits},
ArrayRef<Attribute>{threadToLDSViewTrs},
ArrayRef<int64_t>(bounds), ArrayRef<int64_t>(strides),
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
OpBuilder thenb = ifb.getThenBodyBuilder();
SmallVector<Value, 4> firstInits{nrtid, rtid, zeroConstantOp};
SmallVector<Value, 4> secondInits{nrtid, rtidPlusOffsetVal,
zeroConstantOp};
SmallVector<int64_t> bounds{1, 1, 1};
SmallVector<int64_t> strides{1, 1, 1};

TransformingForOp reductionLoop = TransformingForOp::create(
thenb, loc, ArrayRef<ValueRange>{firstInits, secondInits},
ArrayRef<Attribute>{threadToLDSViewTrs, threadToLDSViewTrs},
ArrayRef<int64_t>(bounds), ArrayRef<int64_t>(strides),
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
PatternRewriter::InsertionGuard guard(thenb);
thenb.setInsertionPointToStart(reductionLoop.getBody());
Block::BlockArgListType firstLDSLoadCoords =
reductionLoop.getLowerCoords(/*domain=*/0);
Value firstLoadVal = InBoundsLoadOp::create(
thenb, loc, elemType, workspaceLDSBuffer,
firstLDSLoadCoords);
Block::BlockArgListType secondLDSLoadCoords =
reductionLoop.getLowerCoords(/*domain=*/1);
Value secondLoadVal = InBoundsLoadOp::create(
thenb, loc, elemType, workspaceLDSBuffer,
secondLDSLoadCoords);
Value reduced =
createReducingOp(op, firstLoadVal, secondLoadVal, thenb);
InBoundsStoreOp::create(thenb, loc, reduced, workspaceLDSBuffer,
firstLDSLoadCoords);
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(readLoop.getBody());
Block::BlockArgListType ldsCoords =
readLoop.getLowerCoords(/*domain=*/0);
Value ldVal = InBoundsLoadOp::create(
rewriter, loc, elemType, workspaceLDSBuffer, ldsCoords);
if (i == 0) {
// First iteration: store the loaded value directly to the
// accumulator. This avoids a redundant reduction with the
// identity element (e.g., `0.0 + x` for sum, `max(-inf, x)`
// for max).
InBoundsStoreOp::create(rewriter, loc, ldVal, accReg,
zeroConstantOp);
} else {
Value accVal = InBoundsLoadOp::create(rewriter, loc, elemType,
accReg, zeroConstantOp);
Value reduced = createReducingOp(op, ldVal, accVal, rewriter);
InBoundsStoreOp::create(rewriter, loc, reduced, accReg,
zeroConstantOp);
}
}
LDSBarrierOp::create(rewriter, loc);
}

// Write the fully reduced value back to LDS at [nrtid, 0].
// All threads with the same nrtid compute the same value,
// so concurrent writes to the same location are safe.
{
Value reducedVal = InBoundsLoadOp::create(rewriter, loc, elemType,
accReg, zeroConstantOp);
SmallVector<Value, 3> writeInits{nrtid, zeroConstantOp,
zeroConstantOp};
SmallVector<int64_t> writeBounds{1, 1, 1};
SmallVector<int64_t> writeStrides{1, 1, 1};

TransformingForOp writeLoop = TransformingForOp::create(
rewriter, loc, ArrayRef<ValueRange>{writeInits},
ArrayRef<Attribute>{threadToLDSViewTrs},
ArrayRef<int64_t>(writeBounds), ArrayRef<int64_t>(writeStrides),
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(writeLoop.getBody());
Block::BlockArgListType ldsCoords =
writeLoop.getLowerCoords(/*domain=*/0);
InBoundsStoreOp::create(rewriter, loc, reducedVal,
workspaceLDSBuffer, ldsCoords);
}
}

LDSBarrierOp::create(rewriter, loc);
ArrayAttr reducedldsViewArrayAttr = createLDSWorkspaceView(
loc, rewriter, inputViewArrayAttr, axis, /*makeRDimZero-*/ true,
partialRegTensorShape[rDim]);
loc, rewriter, inputViewArrayAttr, axis,
/*makeRDimZero=*/true, partialRegTensorShape[rDim]);
ThreadwiseReadIntoOp::create(rewriter, loc, workspaceLDSBuffer,
outputReg, reducedldsViewArrayAttr,
/*extraIndices=*/ValueRange{tid}, true,