@@ -49,6 +49,7 @@ namespace {
4949
5050// TODO: when stabilized, promote to a proper dialect attribute.
5151constexpr StringLiteral kSchedStageAttr = " sched.stage" ;
52+ constexpr StringLiteral kSchedRotateHeadAttr = " sched.rotate_head" ;
5253
5354// / Get the pipeline stage for an operation, defaulting to 0.
5455static int64_t getStage (Operation *op) {
@@ -560,6 +561,173 @@ static scf::ForOp emitKernel(scf::ForOp originalForOp,
560561 return kernelLoop;
561562}
562563
564+ // ===----------------------------------------------------------------------===//
565+ // Kernel rotation
566+ // ===----------------------------------------------------------------------===//
567+
568+ // / Rotate the kernel loop body so that the first op marked with
569+ // / `sched.rotate_head` appears first in the loop.
570+ // /
571+ // / The rotate_head attribute marks the first op of the "head" group.
572+ // / All ops from rotate_head to the yield are the head group; everything
573+ // / before is the rest group. Within the pipelined kernel body, cross-stage
574+ // / deps flow through iter_args, so the groups are SSA-independent.
575+ // /
576+ // / Returns the original loop unchanged if no rotate_head is found.
577+
578+ // / Find the rotation head and partition the loop body into rest (before)
579+ // / and rotate (from head onward) groups. Returns false if no rotation
580+ // / head is found or if either group is empty.
581+ static bool partitionAtRotateHead (Block *body, scf::YieldOp yield,
582+ SmallVectorImpl<Operation *> &restOps,
583+ SmallVectorImpl<Operation *> &rotateOps,
584+ DenseSet<Operation *> &rotateSet) {
585+ Operation *headOp = nullptr ;
586+ for (Operation &op : *body) {
587+ if (op.hasAttr (kSchedRotateHeadAttr )) {
588+ headOp = &op;
589+ break ;
590+ }
591+ }
592+ if (!headOp)
593+ return false ;
594+
595+ bool seenHead = false ;
596+ for (Operation &op : *body) {
597+ if (&op == yield)
598+ break ;
599+ if (&op == headOp)
600+ seenHead = true ;
601+ if (seenHead) {
602+ rotateOps.push_back (&op);
603+ rotateSet.insert (&op);
604+ } else {
605+ restOps.push_back (&op);
606+ }
607+ }
608+ return !restOps.empty () && !rotateOps.empty ();
609+ }
610+
611+ // / Build an IRMapping from an old loop's IV and region args to new values.
612+ static IRMapping buildLoopMapping (scf::ForOp oldLoop, Value newIV,
613+ ValueRange newRegionArgs) {
614+ IRMapping map;
615+ map.map (oldLoop.getInductionVar (), newIV);
616+ for (auto [oldArg, newArg] :
617+ llvm::zip (oldLoop.getRegionIterArgs (), newRegionArgs))
618+ map.map (oldArg, newArg);
619+ return map;
620+ }
621+
622+ // / Clone a list of ops using a mapping.
623+ static void cloneOpsWithMapping (OpBuilder &builder, ArrayRef<Operation *> ops,
624+ IRMapping &mapping) {
625+ for (Operation *op : ops)
626+ builder.clone (*op, mapping);
627+ }
628+
629+ // / Collect values defined in restOps and used by rotateOps (crossing values).
630+ static SmallVector<Value>
631+ findCrossingValues (ArrayRef<Operation *> rotateOps,
632+ const DenseSet<Operation *> &rotateSet, Block *body) {
633+ SmallVector<Value> result;
634+ DenseSet<Value> seen;
635+ for (Operation *op : rotateOps) {
636+ for (Value operand : op->getOperands ()) {
637+ auto *defOp = operand.getDefiningOp ();
638+ if (!defOp || defOp->getBlock () != body || rotateSet.contains (defOp))
639+ continue ;
640+ if (seen.insert (operand).second )
641+ result.push_back (operand);
642+ }
643+ }
644+ return result;
645+ }
646+
647+ static scf::ForOp rotateKernelBody (scf::ForOp kernelLoop, OpBuilder &builder) {
648+ Block *body = kernelLoop.getBody ();
649+ auto yield = cast<scf::YieldOp>(body->getTerminator ());
650+ Location loc = kernelLoop.getLoc ();
651+
652+ SmallVector<Operation *> restOps, rotateOps;
653+ DenseSet<Operation *> rotateSet;
654+ if (!partitionAtRotateHead (body, yield, restOps, rotateOps, rotateSet))
655+ return kernelLoop;
656+
657+ Value lb = kernelLoop.getLowerBound ();
658+ Value ub = kernelLoop.getUpperBound ();
659+ Value step = kernelLoop.getStep ();
660+ unsigned numOrig = kernelLoop.getInitArgs ().size ();
661+
662+ // Step 1: Peeled prologue: clone rest ops with IV = lb.
663+ builder.setInsertionPoint (kernelLoop);
664+ auto prologueMap = buildLoopMapping (kernelLoop, lb, kernelLoop.getInitArgs ());
665+ cloneOpsWithMapping (builder, restOps, prologueMap);
666+
667+ // Crossing values: rest-defined values consumed by rotate ops.
668+ auto crossingValues = findCrossingValues (rotateOps, rotateSet, body);
669+
670+ // New init args: original + crossing values from prologue.
671+ SmallVector<Value> newInits (kernelLoop.getInitArgs ());
672+ for (Value cv : crossingValues)
673+ newInits.push_back (prologueMap.lookupOrDefault (cv));
674+
675+ // Step 2: Rotated kernel: [lb, ub - step).
676+ Value ubMinusStep = arith::SubIOp::create (builder, loc, ub, step);
677+ auto newLoop =
678+ scf::ForOp::create (builder, loc, lb, ubMinusStep, step, newInits);
679+ builder.setInsertionPointToStart (newLoop.getBody ());
680+
681+ // Map old block args -> new, including crossing iter_args.
682+ auto rotMap =
683+ buildLoopMapping (kernelLoop, newLoop.getInductionVar (),
684+ newLoop.getRegionIterArgs ().take_front (numOrig));
685+ for (auto [cv, newArg] : llvm::zip (
686+ crossingValues, newLoop.getRegionIterArgs ().drop_front (numOrig)))
687+ rotMap.map (cv, newArg);
688+
689+ // Clone rotate ops first, then rest ops with shifted IV.
690+ cloneOpsWithMapping (builder, rotateOps, rotMap);
691+ Value kNext =
692+ arith::AddIOp::create (builder, loc, newLoop.getInductionVar (), step);
693+ IRMapping restMap (rotMap);
694+ restMap.map (kernelLoop.getInductionVar (), kNext );
695+ cloneOpsWithMapping (builder, restOps, restMap);
696+
697+ // Yield: pick from rotMap or restMap depending on which group defined it.
698+ SmallVector<Value> yieldVals;
699+ for (Value yv : yield.getOperands ()) {
700+ auto *defOp = yv.getDefiningOp ();
701+ yieldVals.push_back ((defOp && rotateSet.contains (defOp))
702+ ? rotMap.lookupOrDefault (yv)
703+ : restMap.lookupOrDefault (yv));
704+ }
705+ for (Value cv : crossingValues)
706+ yieldVals.push_back (restMap.lookupOrDefault (cv));
707+ scf::YieldOp::create (builder, loc, yieldVals);
708+
709+ // Step 3: Peeled epilogue: clone rotate ops using final loop results.
710+ builder.setInsertionPointAfter (newLoop);
711+ IRMapping epilogueMap;
712+ epilogueMap.map (kernelLoop.getInductionVar (), ubMinusStep);
713+ for (unsigned i = 0 ; i < numOrig; ++i)
714+ epilogueMap.map (kernelLoop.getRegionIterArgs ()[i], newLoop.getResult (i));
715+ for (auto [i, cv] : llvm::enumerate (crossingValues))
716+ epilogueMap.map (cv, newLoop.getResult (numOrig + i));
717+ cloneOpsWithMapping (builder, rotateOps, epilogueMap);
718+
719+ // Replace uses: rotate-defined results come from epilogue, rest from loop.
720+ for (unsigned i = 0 ; i < numOrig; ++i) {
721+ auto *defOp = yield.getOperand (i).getDefiningOp ();
722+ Value replacement = (defOp && rotateSet.contains (defOp))
723+ ? epilogueMap.lookupOrDefault (yield.getOperand (i))
724+ : newLoop.getResult (i);
725+ kernelLoop.getResult (i).replaceAllUsesWith (replacement);
726+ }
727+ kernelLoop.erase ();
728+ return newLoop;
729+ }
730+
563731// ===----------------------------------------------------------------------===//
564732// Epilogue
565733// ===----------------------------------------------------------------------===//
@@ -726,6 +894,13 @@ void SCFPipelineAsterSchedPass::runOnOperation() {
726894
727895 originalForOp.erase ();
728896
897+ // Step 5: (optional): Rotate the kernel body so the op marked
898+ // with sched.rotate_head fires first. Only triggers if the
899+ // attribute is present in the kernel body.
900+ if (rotateKernel) {
901+ kernelLoop = rotateKernelBody (kernelLoop, builder);
902+ }
903+
729904 if (lcmUnroll) {
730905 int64_t factor = computeStageLCM (info) *
731906 std::max<int64_t >(unrollFactorMultiplier, 1 );
0 commit comments