Skip to content

Commit eef10e4

Browse files
Add rotation support to scf.pipeliner and retire dedicated pass
The dedicated pass creates pass dependencies at a distance though attributes, which is brittle. Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
1 parent 3c2b466 commit eef10e4

6 files changed

Lines changed: 248 additions & 328 deletions

File tree

include/aster/Transforms/Passes.td

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def SCFPipelineAsterSched : Pass<"aster-scf-pipeline"> {
232232
"Multiplier for the LCM unroll factor (effective = lcm * multiplier)">,
233233
Option<"epiloguePeeling", "epilogue-peeling", "bool", /*default=*/"false",
234234
"Fully unroll the cleanup epilogue loop after LCM unrolling">,
235+
Option<"rotateKernel", "rotate-kernel", "bool", /*default=*/"true",
236+
"Peel the kernel loop at the first sched.rotate_head">,
235237
];
236238
let dependentDialects = [
237239
"mlir::affine::AffineDialect",
@@ -241,52 +243,6 @@ def SCFPipelineAsterSched : Pass<"aster-scf-pipeline"> {
241243
];
242244
}
243245

244-
//===----------------------------------------------------------------------===//
245-
// SCFRotate
246-
//===----------------------------------------------------------------------===//
247-
248-
def SCFRotate : Pass<"aster-scf-rotate"> {
249-
let summary = "Rotate loop body around the most dominating insertion point";
250-
let description = [{
251-
This pass moves operations marked with sched.rotate_head -- together
252-
with their transitive in-block dependencies -- to the front of scf.for
253-
loop bodies, then strips the attribute.
254-
255-
The "most dominating insertion point" is the earliest transitive
256-
dependency of any rotate_head op. After rotation, that dependency
257-
becomes the first op in the loop body, followed by the rest of the
258-
dependency subgraph and the rotate_head ops, all preserving their
259-
original relative order.
260-
261-
Designed to run after aster-scf-pipeline. The input IR places
262-
sched.rotate_head on ops that should execute first in the loop body
263-
(e.g., MFMA to hide latency).
264-
265-
Example (deps already iter_args):
266-
Before: for k iter_args(C, ds_carry) {
267-
alloc; load; ds_write; barrier; ds_read;
268-
MFMA(C, ds_carry) {sched.rotate_head};
269-
}
270-
After: for k iter_args(C, ds_carry) {
271-
MFMA(C, ds_carry);
272-
alloc; load; ds_write; barrier; ds_read;
273-
}
274-
275-
Example (in-block dep pulled along):
276-
Before: for k iter_args(C) {
277-
other_work;
278-
ds_read -> val;
279-
MFMA(C, val) {sched.rotate_head};
280-
}
281-
After: for k iter_args(C) {
282-
ds_read -> val;
283-
MFMA(C, val);
284-
other_work;
285-
}
286-
}];
287-
let dependentDialects = [];
288-
}
289-
290246
//===----------------------------------------------------------------------===//
291247
// SimplifyAllocaIterArgs
292248
//===----------------------------------------------------------------------===//

lib/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ add_mlir_library(AsterTransforms
1515
OptimizeArith.cpp
1616
ReplaceConstantGPUDims.cpp
1717
SCFPipelineAsterSched.cpp
18-
SCFRotate.cpp
1918
SimplifyAllocaIterArgs.cpp
2019
ToIntArith.cpp
2120
Utils.cpp

lib/Transforms/SCFPipelineAsterSched.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace {
4949

5050
// TODO: when stabilized, promote to a proper dialect attribute.
5151
constexpr StringLiteral kSchedStageAttr = "sched.stage";
52+
constexpr StringLiteral kSchedRotateHeadAttr = "sched.rotate_head";
5253

5354
/// Get the pipeline stage for an operation, defaulting to 0.
5455
static int64_t getStage(Operation *op) {
@@ -560,6 +561,173 @@ static scf::ForOp emitKernel(scf::ForOp originalForOp,
560561
return kernelLoop;
561562
}
562563

564+
//===----------------------------------------------------------------------===//
565+
// Kernel rotation
566+
//===----------------------------------------------------------------------===//
567+
568+
/// Rotate the kernel loop body so that the first op marked with
569+
/// `sched.rotate_head` appears first in the loop.
570+
///
571+
/// The rotate_head attribute marks the first op of the "head" group.
572+
/// All ops from rotate_head to the yield are the head group; everything
573+
/// before is the rest group. Within the pipelined kernel body, cross-stage
574+
/// deps flow through iter_args, so the groups are SSA-independent.
575+
///
576+
/// Returns the original loop unchanged if no rotate_head is found.
577+
578+
/// Find the rotation head and partition the loop body into rest (before)
579+
/// and rotate (from head onward) groups. Returns false if no rotation
580+
/// head is found or if either group is empty.
581+
static bool partitionAtRotateHead(Block *body, scf::YieldOp yield,
582+
SmallVectorImpl<Operation *> &restOps,
583+
SmallVectorImpl<Operation *> &rotateOps,
584+
DenseSet<Operation *> &rotateSet) {
585+
Operation *headOp = nullptr;
586+
for (Operation &op : *body) {
587+
if (op.hasAttr(kSchedRotateHeadAttr)) {
588+
headOp = &op;
589+
break;
590+
}
591+
}
592+
if (!headOp)
593+
return false;
594+
595+
bool seenHead = false;
596+
for (Operation &op : *body) {
597+
if (&op == yield)
598+
break;
599+
if (&op == headOp)
600+
seenHead = true;
601+
if (seenHead) {
602+
rotateOps.push_back(&op);
603+
rotateSet.insert(&op);
604+
} else {
605+
restOps.push_back(&op);
606+
}
607+
}
608+
return !restOps.empty() && !rotateOps.empty();
609+
}
610+
611+
/// Build an IRMapping from an old loop's IV and region args to new values.
612+
static IRMapping buildLoopMapping(scf::ForOp oldLoop, Value newIV,
613+
ValueRange newRegionArgs) {
614+
IRMapping map;
615+
map.map(oldLoop.getInductionVar(), newIV);
616+
for (auto [oldArg, newArg] :
617+
llvm::zip(oldLoop.getRegionIterArgs(), newRegionArgs))
618+
map.map(oldArg, newArg);
619+
return map;
620+
}
621+
622+
/// Clone a list of ops using a mapping.
623+
static void cloneOpsWithMapping(OpBuilder &builder, ArrayRef<Operation *> ops,
624+
IRMapping &mapping) {
625+
for (Operation *op : ops)
626+
builder.clone(*op, mapping);
627+
}
628+
629+
/// Collect values defined in restOps and used by rotateOps (crossing values).
630+
static SmallVector<Value>
631+
findCrossingValues(ArrayRef<Operation *> rotateOps,
632+
const DenseSet<Operation *> &rotateSet, Block *body) {
633+
SmallVector<Value> result;
634+
DenseSet<Value> seen;
635+
for (Operation *op : rotateOps) {
636+
for (Value operand : op->getOperands()) {
637+
auto *defOp = operand.getDefiningOp();
638+
if (!defOp || defOp->getBlock() != body || rotateSet.contains(defOp))
639+
continue;
640+
if (seen.insert(operand).second)
641+
result.push_back(operand);
642+
}
643+
}
644+
return result;
645+
}
646+
647+
static scf::ForOp rotateKernelBody(scf::ForOp kernelLoop, OpBuilder &builder) {
648+
Block *body = kernelLoop.getBody();
649+
auto yield = cast<scf::YieldOp>(body->getTerminator());
650+
Location loc = kernelLoop.getLoc();
651+
652+
SmallVector<Operation *> restOps, rotateOps;
653+
DenseSet<Operation *> rotateSet;
654+
if (!partitionAtRotateHead(body, yield, restOps, rotateOps, rotateSet))
655+
return kernelLoop;
656+
657+
Value lb = kernelLoop.getLowerBound();
658+
Value ub = kernelLoop.getUpperBound();
659+
Value step = kernelLoop.getStep();
660+
unsigned numOrig = kernelLoop.getInitArgs().size();
661+
662+
// Step 1: Peeled prologue: clone rest ops with IV = lb.
663+
builder.setInsertionPoint(kernelLoop);
664+
auto prologueMap = buildLoopMapping(kernelLoop, lb, kernelLoop.getInitArgs());
665+
cloneOpsWithMapping(builder, restOps, prologueMap);
666+
667+
// Crossing values: rest-defined values consumed by rotate ops.
668+
auto crossingValues = findCrossingValues(rotateOps, rotateSet, body);
669+
670+
// New init args: original + crossing values from prologue.
671+
SmallVector<Value> newInits(kernelLoop.getInitArgs());
672+
for (Value cv : crossingValues)
673+
newInits.push_back(prologueMap.lookupOrDefault(cv));
674+
675+
// Step 2: Rotated kernel: [lb, ub - step).
676+
Value ubMinusStep = arith::SubIOp::create(builder, loc, ub, step);
677+
auto newLoop =
678+
scf::ForOp::create(builder, loc, lb, ubMinusStep, step, newInits);
679+
builder.setInsertionPointToStart(newLoop.getBody());
680+
681+
// Map old block args -> new, including crossing iter_args.
682+
auto rotMap =
683+
buildLoopMapping(kernelLoop, newLoop.getInductionVar(),
684+
newLoop.getRegionIterArgs().take_front(numOrig));
685+
for (auto [cv, newArg] : llvm::zip(
686+
crossingValues, newLoop.getRegionIterArgs().drop_front(numOrig)))
687+
rotMap.map(cv, newArg);
688+
689+
// Clone rotate ops first, then rest ops with shifted IV.
690+
cloneOpsWithMapping(builder, rotateOps, rotMap);
691+
Value kNext =
692+
arith::AddIOp::create(builder, loc, newLoop.getInductionVar(), step);
693+
IRMapping restMap(rotMap);
694+
restMap.map(kernelLoop.getInductionVar(), kNext);
695+
cloneOpsWithMapping(builder, restOps, restMap);
696+
697+
// Yield: pick from rotMap or restMap depending on which group defined it.
698+
SmallVector<Value> yieldVals;
699+
for (Value yv : yield.getOperands()) {
700+
auto *defOp = yv.getDefiningOp();
701+
yieldVals.push_back((defOp && rotateSet.contains(defOp))
702+
? rotMap.lookupOrDefault(yv)
703+
: restMap.lookupOrDefault(yv));
704+
}
705+
for (Value cv : crossingValues)
706+
yieldVals.push_back(restMap.lookupOrDefault(cv));
707+
scf::YieldOp::create(builder, loc, yieldVals);
708+
709+
// Step 3: Peeled epilogue: clone rotate ops using final loop results.
710+
builder.setInsertionPointAfter(newLoop);
711+
IRMapping epilogueMap;
712+
epilogueMap.map(kernelLoop.getInductionVar(), ubMinusStep);
713+
for (unsigned i = 0; i < numOrig; ++i)
714+
epilogueMap.map(kernelLoop.getRegionIterArgs()[i], newLoop.getResult(i));
715+
for (auto [i, cv] : llvm::enumerate(crossingValues))
716+
epilogueMap.map(cv, newLoop.getResult(numOrig + i));
717+
cloneOpsWithMapping(builder, rotateOps, epilogueMap);
718+
719+
// Replace uses: rotate-defined results come from epilogue, rest from loop.
720+
for (unsigned i = 0; i < numOrig; ++i) {
721+
auto *defOp = yield.getOperand(i).getDefiningOp();
722+
Value replacement = (defOp && rotateSet.contains(defOp))
723+
? epilogueMap.lookupOrDefault(yield.getOperand(i))
724+
: newLoop.getResult(i);
725+
kernelLoop.getResult(i).replaceAllUsesWith(replacement);
726+
}
727+
kernelLoop.erase();
728+
return newLoop;
729+
}
730+
563731
//===----------------------------------------------------------------------===//
564732
// Epilogue
565733
//===----------------------------------------------------------------------===//
@@ -726,6 +894,13 @@ void SCFPipelineAsterSchedPass::runOnOperation() {
726894

727895
originalForOp.erase();
728896

897+
// Step 5: (optional): Rotate the kernel body so the op marked
898+
// with sched.rotate_head fires first. Only triggers if the
899+
// attribute is present in the kernel body.
900+
if (rotateKernel) {
901+
kernelLoop = rotateKernelBody(kernelLoop, builder);
902+
}
903+
729904
if (lcmUnroll) {
730905
int64_t factor = computeStageLCM(info) *
731906
std::max<int64_t>(unrollFactorMultiplier, 1);

lib/Transforms/SCFRotate.cpp

Lines changed: 0 additions & 107 deletions
This file was deleted.

0 commit comments

Comments
 (0)