Skip to content

Commit f40dce3

Browse files
petebu authored and copybara-github committed
[mpmd] Merge inferred fragments inline in UniquifyFunctionInputsOutputsPass.
Instead of running a separate MergeInferredFragmentsPass after uniquify, merge each newly created inferred fragment into an existing same-mesh fragment directly within the pass. Removes the separate pass from the export pipeline.
PiperOrigin-RevId: 908184894
1 parent b6cf4a1 commit f40dce3

5 files changed

Lines changed: 152 additions & 27 deletions

File tree

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,10 @@ def UniquifyFunctionInputsOutputsPass :
450450
Similarly, if a function returns a block argument, this pass creates an
451451
identity fragment for that block argument, guaranteeing that values are
452452
passed by value to the function, not by reference.
453+
454+
Additionally, when not using transfers, the pass will attempt to merge
455+
each newly created inferred fragment into an existing same-mesh fragment
456+
to reduce the total number of fragments.
453457
}];
454458

455459
let options = [

shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@ func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mes
2626
func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
2727
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
2828
} {
29-
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
30-
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
31-
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
32-
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
33-
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
29+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
30+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
31+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
3432
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
3533
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
3634
mpmd.return %1 : tensor<4xf32>
@@ -48,11 +46,8 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1:
4846
} {
4947
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
5048
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
51-
// CHECK: %[[F3:.*]] = mpmd.fragment<mesh="m1", origin=["f3"]>
52-
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
53-
// CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]]
54-
// CHECK-NEXT: }
55-
// CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4
49+
// CHECK: %[[F3:.*]]:5 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
50+
// CHECK: return %[[F2]], %[[F3]]#0, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3, %[[F3]]#4
5651
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
5752
mpmd.return %arg2 : tensor<4xf32>
5853
} : (!mesh_1_tensor) -> !mesh_1_tensor
@@ -70,9 +65,10 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te
7065
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
7166
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
7267
} {
73-
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
74-
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
75-
// CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3
68+
// CHECK: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
69+
// CHECK: %[[F2:.*]]:2 = mpmd.fragment<mesh="m2", origin=["f2"]>
70+
// CHECK: %[[F3:.*]]:4 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
71+
// CHECK: return %[[F3]]#0, %[[F2]]#0, %[[F2]]#1, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3
7672
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
7773
mpmd.return %arg2 : tensor<4xf32>
7874
} : (!mesh_1_tensor) -> !mesh_1_tensor
@@ -95,11 +91,9 @@ module {
9591
func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_tensor) -> (!dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor) attributes {
9692
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
9793
} {
98-
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
99-
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
100-
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
101-
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
102-
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
94+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
95+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
96+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
10397
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
10498
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
10599
mpmd.return %1 : tensor<4xf32>

shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,21 @@ limitations under the License.
1818

1919
#include "llvm/ADT/DenseSet.h"
2020
#include "llvm/ADT/MapVector.h"
21+
#include "llvm/ADT/STLExtras.h"
2122
#include "llvm/ADT/SmallVector.h"
2223
#include "mlir/Dialect/Func/IR/FuncOps.h"
2324
#include "mlir/IR/Builders.h"
2425
#include "mlir/IR/BuiltinAttributes.h"
2526
#include "mlir/IR/MLIRContext.h"
27+
#include "mlir/IR/PatternMatch.h"
2628
#include "mlir/IR/Types.h"
2729
#include "mlir/IR/Value.h"
2830
#include "mlir/Support/LLVM.h"
2931
#include "shardy/common/logging.h"
3032
#include "shardy/dialect/mpmd/ir/dialect.h"
3133
#include "shardy/dialect/mpmd/ir/utils.h"
3234
#include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep
35+
#include "shardy/dialect/mpmd/transforms/common/utils.h"
3336
#include "shardy/dialect/sdy/ir/dialect.h"
3437

3538
namespace mlir::mpmd {
@@ -41,17 +44,111 @@ namespace {
4144

4245
using ValueToReturnIndices = llvm::MapVector<Value, SmallVector<int64_t>>;
4346

47+
bool CanMoveAfter(Operation* op_to_move, Operation* target_op) {
48+
if (op_to_move->getBlock() != target_op->getBlock()) return false;
49+
if (!op_to_move->isBeforeInBlock(target_op)) return false;
50+
51+
Operation* current = op_to_move->getNextNode();
52+
while (current) {
53+
for (Value result : op_to_move->getResults()) {
54+
if (llvm::is_contained(current->getOperands(), result)) {
55+
return false;
56+
}
57+
}
58+
for (Value operand : op_to_move->getOperands()) {
59+
if (operand.getDefiningOp() == current) {
60+
return false;
61+
}
62+
}
63+
64+
if (current == target_op) break;
65+
current = current->getNextNode();
66+
}
67+
return true;
68+
}
69+
70+
// Tries to merge the newly created inferred fragment into an existing
71+
// same-mesh fragment in the block.
72+
void MergeInferredFragmentWithExisting(FragmentOp fragment_op,
73+
StringRef mesh_name,
74+
Operation* return_op,
75+
OpBuilder& builder) {
76+
// Try to find an existing same-mesh fragment to merge the newly created
77+
// inferred fragment into. We track two things:
78+
// - latest_operand_producer: the latest op that produces any operand of
79+
// the inferred fragment (needed for positioning constraints).
80+
// - merge_target: the latest same-mesh fragment we can merge into.
81+
FragmentOp merge_target = nullptr;
82+
Operation* latest_operand_producer = nullptr;
83+
84+
// Updates merge_target if `op` is a same-mesh fragment that appears later
85+
// in the block than the current candidate.
86+
auto updateMergeTarget = [&](Operation* op) {
87+
auto frag = dyn_cast<FragmentOp>(op);
88+
if (frag && frag.getMeshName() == mesh_name &&
89+
(!merge_target || merge_target->isBeforeInBlock(frag))) {
90+
merge_target = frag;
91+
}
92+
};
93+
94+
// First, scan the operand producers for a merge candidate.
95+
for (Value v : fragment_op.getOperands()) {
96+
Operation* op = v.getDefiningOp();
97+
if (!op) continue;
98+
if (!latest_operand_producer ||
99+
latest_operand_producer->isBeforeInBlock(op)) {
100+
latest_operand_producer = op;
101+
}
102+
updateMergeTarget(op);
103+
}
104+
105+
// If no producer fragment on the same mesh was found among the operands,
106+
// look for any fragment on the same mesh in the block (sideways merge).
107+
// Only do this when there are actual producer ops (not just block arguments).
108+
if (!merge_target && latest_operand_producer) {
109+
for (Operation& op : *return_op->getBlock()) {
110+
if (&op == return_op || &op == fragment_op) continue;
111+
updateMergeTarget(&op);
112+
}
113+
}
114+
115+
// Give up if no merge candidate was found, or if the candidate can't be
116+
// moved after the latest operand producer (which would break dominance).
117+
if (!merge_target || (merge_target != latest_operand_producer &&
118+
!CanMoveAfter(merge_target, latest_operand_producer))) {
119+
return;
120+
}
121+
122+
// Position the merge target after the latest operand producer so all
123+
// operands of the inferred fragment are available, then merge.
124+
if (merge_target != latest_operand_producer) {
125+
merge_target->moveAfter(latest_operand_producer);
126+
}
127+
128+
fragment_op->moveAfter(merge_target);
129+
IRRewriter rewriter(builder.getContext());
130+
FragmentOp merged_fragment = MergeRegionOps(
131+
merge_target, fragment_op, rewriter,
132+
/*num_static_args=*/0, /*replace_producer_use_in_consumer_block=*/
133+
[](OpOperand&, Value) {
134+
SDY_CHECK(false) << "Fragment ops shouldn't have free variables";
135+
},
136+
GetFragmentOriginUnion(merge_target, fragment_op, rewriter),
137+
merge_target.getMeshNameAttr(),
138+
/*stage_id=*/merge_target.getStageIdAttr());
139+
SetInferredByAttr(merged_fragment, "uniquify", builder);
140+
}
141+
44142
void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
45143
ValueToReturnIndices& value_to_return_indices,
46144
OpBuilder& builder) {
47145
// We remove any entries that require no work, in order to avoid too many
48146
// checks.
49147
value_to_return_indices.remove_if([](const auto& it) {
50148
if (it.second.size() == 1) {
51-
Value v = it.first;
52-
return !isa<BlockArgument>(v);
149+
return !isa<BlockArgument>(it.first);
53150
}
54-
return it.second.empty();
151+
return false;
55152
});
56153

57154
if (value_to_return_indices.empty()) {
@@ -98,6 +195,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
98195
}
99196
auto block_builder = OpBuilder::atBlockEnd(&fragment_block);
100197
ReturnOp::create(block_builder, loc, returned_values);
198+
199+
MergeInferredFragmentWithExisting(fragment_op, mesh_name, return_op, builder);
101200
}
102201

103202
// Replaces the return values of the function with transfer ops.

shardy/dialect/mpmd/transforms/export/export_pipeline.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,6 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) {
8686
// identity fragments, which would be canonicalized away.
8787
pm.addNestedPass<FuncOp>(createUniquifyFunctionInputsOutputsPass());
8888

89-
// The fragments created by the pass above maybe slowdown compilation (more
90-
// fragments to compile) and may cause performance regressions. Thus, we merge
91-
// them with other fragments.
92-
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
93-
9489
// Mark each fragment with the inputs and outputs which are offloaded to host
9590
// memory.
9691
pm.addNestedPass<FuncOp>(createMarkOffloadedInputOutputPass());

shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s
1+
// RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 2>&1 | FileCheck %s
22

33
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
44

@@ -17,3 +17,36 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar
1717
} : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32)
1818
func.return %0 : !mesh_1_tensor_4_8_f32
1919
}
20+
21+
// -----
22+
23+
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
24+
!mesh_2_tensor_4_8_f32 = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>>
25+
26+
// This test verifies that an explicit fragment and an inferred fragment
27+
// (created by the UniquifyFunctionInputsOutputsPass for the duplicated return
28+
// of the transfer result) are merged sideways. Without sideways merge, the
29+
// transfer result would produce a separate inferred fragment call on m1.
30+
// The function-level returns remain unique SSA values (%[[RES]]#0, #1, #2),
31+
// preserving the invariant established by the uniquify pass, even though the
32+
// fragment body internally returns the same value in multiple positions.
33+
// CHECK-LABEL: func.func @test_sideways_merge
34+
func.func @test_sideways_merge(%arg0: !mesh_1_tensor_4_8_f32, %arg1: !mesh_2_tensor_4_8_f32)
35+
-> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) attributes {
36+
"topology"=#mpmd.topology<
37+
<"m1": <["x"=2]>>,
38+
<"m2": <["x"=2]>>
39+
>} {
40+
// CHECK: %[[RES:.*]]:3 = mpmd.fragment_call<mesh="m1", origin=["f1"]> @[[CALLEE_M1:.*]]
41+
// CHECK-NOT: mpmd.fragment_call<mesh="m1"
42+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2
43+
44+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4x8xf32>) {
45+
%4 = stablehlo.add %arg2, %arg2 : tensor<4x8xf32>
46+
mpmd.return %4 : tensor<4x8xf32>
47+
} : (!mesh_1_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32
48+
49+
%1 = mpmd.transfer %arg1 : (!mesh_2_tensor_4_8_f32) -> !mesh_1_tensor_4_8_f32
50+
51+
func.return %0, %1, %1 : !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32
52+
}

0 commit comments

Comments
 (0)