@@ -18,18 +18,21 @@ limitations under the License.
1818
1919#include " llvm/ADT/DenseSet.h"
2020#include " llvm/ADT/MapVector.h"
21+ #include " llvm/ADT/STLExtras.h"
2122#include " llvm/ADT/SmallVector.h"
2223#include " mlir/Dialect/Func/IR/FuncOps.h"
2324#include " mlir/IR/Builders.h"
2425#include " mlir/IR/BuiltinAttributes.h"
2526#include " mlir/IR/MLIRContext.h"
27+ #include " mlir/IR/PatternMatch.h"
2628#include " mlir/IR/Types.h"
2729#include " mlir/IR/Value.h"
2830#include " mlir/Support/LLVM.h"
2931#include " shardy/common/logging.h"
3032#include " shardy/dialect/mpmd/ir/dialect.h"
3133#include " shardy/dialect/mpmd/ir/utils.h"
3234#include " shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep
35+ #include " shardy/dialect/mpmd/transforms/common/utils.h"
3336#include " shardy/dialect/sdy/ir/dialect.h"
3437
3538namespace mlir ::mpmd {
@@ -41,17 +44,111 @@ namespace {
4144
4245using ValueToReturnIndices = llvm::MapVector<Value, SmallVector<int64_t >>;
4346
47+ bool CanMoveAfter (Operation* op_to_move, Operation* target_op) {
48+ if (op_to_move->getBlock () != target_op->getBlock ()) return false ;
49+ if (!op_to_move->isBeforeInBlock (target_op)) return false ;
50+
51+ Operation* current = op_to_move->getNextNode ();
52+ while (current) {
53+ for (Value result : op_to_move->getResults ()) {
54+ if (llvm::is_contained (current->getOperands (), result)) {
55+ return false ;
56+ }
57+ }
58+ for (Value operand : op_to_move->getOperands ()) {
59+ if (operand.getDefiningOp () == current) {
60+ return false ;
61+ }
62+ }
63+
64+ if (current == target_op) break ;
65+ current = current->getNextNode ();
66+ }
67+ return true ;
68+ }
69+
70+ // Tries to merge the newly created inferred fragment into an existing
71+ // same-mesh fragment in the block.
72+ void MergeInferredFragmentWithExisting (FragmentOp fragment_op,
73+ StringRef mesh_name,
74+ Operation* return_op,
75+ OpBuilder& builder) {
76+ // Try to find an existing same-mesh fragment to merge the newly created
77+ // inferred fragment into. We track two things:
78+ // - latest_operand_producer: the latest op that produces any operand of
79+ // the inferred fragment (needed for positioning constraints).
80+ // - merge_target: the latest same-mesh fragment we can merge into.
81+ FragmentOp merge_target = nullptr ;
82+ Operation* latest_operand_producer = nullptr ;
83+
84+ // Updates merge_target if `op` is a same-mesh fragment that appears later
85+ // in the block than the current candidate.
86+ auto updateMergeTarget = [&](Operation* op) {
87+ auto frag = dyn_cast<FragmentOp>(op);
88+ if (frag && frag.getMeshName () == mesh_name &&
89+ (!merge_target || merge_target->isBeforeInBlock (frag))) {
90+ merge_target = frag;
91+ }
92+ };
93+
94+ // First, scan the operand producers for a merge candidate.
95+ for (Value v : fragment_op.getOperands ()) {
96+ Operation* op = v.getDefiningOp ();
97+ if (!op) continue ;
98+ if (!latest_operand_producer ||
99+ latest_operand_producer->isBeforeInBlock (op)) {
100+ latest_operand_producer = op;
101+ }
102+ updateMergeTarget (op);
103+ }
104+
105+ // If no producer fragment on the same mesh was found among the operands,
106+ // look for any fragment on the same mesh in the block (sideways merge).
107+ // Only do this when there are actual producer ops (not just block arguments).
108+ if (!merge_target && latest_operand_producer) {
109+ for (Operation& op : *return_op->getBlock ()) {
110+ if (&op == return_op || &op == fragment_op) continue ;
111+ updateMergeTarget (&op);
112+ }
113+ }
114+
115+ // Give up if no merge candidate was found, or if the candidate can't be
116+ // moved after the latest operand producer (which would break dominance).
117+ if (!merge_target || (merge_target != latest_operand_producer &&
118+ !CanMoveAfter (merge_target, latest_operand_producer))) {
119+ return ;
120+ }
121+
122+ // Position the merge target after the latest operand producer so all
123+ // operands of the inferred fragment are available, then merge.
124+ if (merge_target != latest_operand_producer) {
125+ merge_target->moveAfter (latest_operand_producer);
126+ }
127+
128+ fragment_op->moveAfter (merge_target);
129+ IRRewriter rewriter (builder.getContext ());
130+ FragmentOp merged_fragment = MergeRegionOps (
131+ merge_target, fragment_op, rewriter,
132+ /* num_static_args=*/ 0 , /* replace_producer_use_in_consumer_block=*/
133+ [](OpOperand&, Value) {
134+ SDY_CHECK (false ) << " Fragment ops shouldn't have free variables" ;
135+ },
136+ GetFragmentOriginUnion (merge_target, fragment_op, rewriter),
137+ merge_target.getMeshNameAttr (),
138+ /* stage_id=*/ merge_target.getStageIdAttr ());
139+ SetInferredByAttr (merged_fragment, " uniquify" , builder);
140+ }
141+
44142void CreateReturnFragmentForMesh (StringRef mesh_name, Operation* return_op,
45143 ValueToReturnIndices& value_to_return_indices,
46144 OpBuilder& builder) {
47145 // We remove any entries that require no work, in order to avoid too many
48146 // checks.
49147 value_to_return_indices.remove_if ([](const auto & it) {
50148 if (it.second .size () == 1 ) {
51- Value v = it.first ;
52- return !isa<BlockArgument>(v);
149+ return !isa<BlockArgument>(it.first );
53150 }
54- return it. second . empty () ;
151+ return false ;
55152 });
56153
57154 if (value_to_return_indices.empty ()) {
@@ -98,6 +195,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
98195 }
99196 auto block_builder = OpBuilder::atBlockEnd (&fragment_block);
100197 ReturnOp::create (block_builder, loc, returned_values);
198+
199+ MergeInferredFragmentWithExisting (fragment_op, mesh_name, return_op, builder);
101200}
102201
103202// Replaces the return values of the function with transfer ops.
0 commit comments