Skip to content

Commit 17994ba

Browse files
petebucopybara-github
authored andcommitted
[mpmd] Add UniquifyAndMergeReturnsV2Pass: create-then-merge approach.
Alternative to UniquifyAndMergeReturnsPass that creates per-value identity fragments and immediately merges them using MergeRegionOps. Simpler than V1 since it reuses existing merge infrastructure. PiperOrigin-RevId: 907826215
1 parent ede550b commit 17994ba

5 files changed

Lines changed: 428 additions & 6 deletions

File tree

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ cc_library(
4242
"rule_based_merge.cc",
4343
"scheduler_preprocess.cc",
4444
"split_bwd_fragments.cc",
45+
"uniquify_and_merge_returns.cc",
4546
"uniquify_function_inputs_outputs.cc",
4647
"unroll_for_loops.cc",
4748
],

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,25 @@ def UniquifyFunctionInputsOutputsPass :
461461
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
462462
}
463463

464+
def UniquifyAndMergeReturnsPass :
465+
PassBase<"mpmd-uniquify-and-merge-returns", "DistributedFunctionPass"> {
466+
let summary = "Uniquifies return values by creating and immediately merging "
467+
"identity fragments.";
468+
let description = [{
469+
Like `mpmd-uniquify-function-inputs-outputs`, ensures each return operand is
470+
unique. Instead of creating new inferred fragments, this pass creates a tiny
471+
identity fragment per value that needs
472+
uniquification and immediately merges it into its producer (or a same-mesh
473+
fragment for block arguments) using MergeRegionOps.
474+
475+
The identity fragment is never persisted in the IR — it is created and merged
476+
in a single step, producing the same result as running uniquify followed by
477+
merge-inferred-fragments.
478+
}];
479+
480+
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
481+
}
482+
464483
def SchedulingUnitVerifierPass :
465484
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
466485
let summary = "Verifies if the program contains the required scheduling units.";
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// RUN: mpmd_opt %s -mpmd-uniquify-and-merge-returns -split-input-file 2>&1 | FileCheck %s
2+
3+
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
4+
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>>
5+
6+
// CHECK-LABEL: func @no_work_needed
7+
func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_2_tensor) attributes {
8+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
9+
} {
10+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
11+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
12+
// CHECK: return %[[F1]], %[[F2]]
13+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
14+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
15+
mpmd.return %1 : tensor<4xf32>
16+
} : (!mesh_1_tensor) -> !mesh_1_tensor
17+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
18+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
19+
mpmd.return %1 : tensor<4xf32>
20+
} : (!mesh_2_tensor) -> !mesh_2_tensor
21+
return %0, %1 : !mesh_1_tensor, !mesh_2_tensor
22+
}
23+
24+
25+
// Test: single mesh, one return operand used multiple times.
26+
// The identity fragment for the extra copies merges into the producing fragment.
27+
// CHECK-LABEL: func @single_mesh_duplicate_return
28+
func.func @single_mesh_duplicate_return(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
29+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
30+
} {
31+
// f1 originally returns 1 result. After merge it returns 2 (original + copy).
32+
// CHECK: %[[F1:.*]]:2 = mpmd.fragment<mesh="m1", origin=["f1"]>
33+
// CHECK-SAME: (%arg0) (%arg1: tensor<4xf32>) {
34+
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 : tensor<4xf32>
35+
// CHECK-NEXT: mpmd.return %[[ADD]], %[[ADD]] : tensor<4xf32>, tensor<4xf32>
36+
// CHECK-NEXT: }
37+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
38+
// CHECK: return %[[F2]], %[[F1]]#0, %[[F1]]#1
39+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
40+
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
41+
mpmd.return %1 : tensor<4xf32>
42+
} : (!mesh_1_tensor) -> !mesh_1_tensor
43+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
44+
mpmd.return %arg1 : tensor<4xf32>
45+
} : (!mesh_1_tensor) -> !mesh_1_tensor
46+
return %1, %0, %0 : !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
47+
}
48+
49+
// -----
50+
51+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
52+
53+
// Test: block argument returned directly -> merged into an existing fragment
54+
// on the same mesh.
55+
// CHECK-LABEL: func @block_arg_passthrough
56+
func.func @block_arg_passthrough(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
57+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
58+
{
59+
// The block arg %arg0 is used at return[0] and return[2].
60+
// The identity fragment for %arg0 merges into fragment "f".
61+
// CHECK: %[[F:.*]]:3 = mpmd.fragment<mesh="m", origin=["f"]> (%arg0)
62+
// CHECK-SAME: (%arg1: tensor<4xui32>) {
63+
// CHECK-NEXT: mpmd.return %arg1, %arg1, %arg1 : tensor<4xui32>, tensor<4xui32>, tensor<4xui32>
64+
// CHECK-NEXT: }
65+
// CHECK: return %[[F]]#1, %[[F]]#0, %[[F]]#2
66+
%0 = mpmd.fragment<mesh="m", origin=["f"]>(%arg0) (%arg1: tensor<4xui32>) {
67+
mpmd.return %arg1 : tensor<4xui32>
68+
} : (!mesh_tensor) -> !mesh_tensor
69+
func.return %arg0, %0, %arg0 : !mesh_tensor, !mesh_tensor, !mesh_tensor
70+
}
71+
72+
// -----
73+
74+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
75+
76+
// Test: identity function with no existing fragment -> must create a fallback.
77+
// CHECK-LABEL: func @identity_function
78+
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
79+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
80+
{
81+
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
82+
// CHECK-NEXT: mpmd.return %arg1 : tensor<4xui32>
83+
// CHECK-NEXT: }
84+
// CHECK-NEXT: return %[[F]]
85+
func.return %arg0 : !mesh_tensor
86+
}
87+
88+
// -----
89+
90+
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
91+
92+
// Regression test: multi-result fragment with multiple duplicated return values.
93+
// This exercises the case where MergeRegionOps erases the producing fragment.
94+
// Processing the first result merges and erases the original fragment; the
95+
// second result must be handled from the new merged fragment, not the stale one.
96+
// CHECK-LABEL: func @multi_result_duplicate_returns
97+
func.func @multi_result_duplicate_returns(%arg0: !mesh_1_tensor) -> (
98+
!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
99+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
100+
} {
101+
// f1 produces two results, each returned twice.
102+
// After uniquification, f1 returns 4 values: a, b, a_copy, b_copy.
103+
// CHECK: %[[F:.*]]:4 = mpmd.fragment<mesh="m1", origin=["f1"]>
104+
// CHECK-SAME: (%arg0) (%arg1: tensor<4xf32>) {
105+
// CHECK-NEXT: %[[A:.*]] = stablehlo.add %arg1, %arg1
106+
// CHECK-NEXT: %[[B:.*]] = stablehlo.multiply %arg1, %arg1
107+
// CHECK-NEXT: mpmd.return %[[A]], %[[B]], %[[A]], %[[B]]
108+
// CHECK-NEXT: }
109+
// CHECK: return %[[F]]#0, %[[F]]#2, %[[F]]#1, %[[F]]#3
110+
%0:2 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
111+
%a = stablehlo.add %arg1, %arg1 : tensor<4xf32>
112+
%b = stablehlo.multiply %arg1, %arg1 : tensor<4xf32>
113+
mpmd.return %a, %b : tensor<4xf32>, tensor<4xf32>
114+
} : (!mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor)
115+
return %0#0, %0#0, %0#1, %0#1 : !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
116+
}
117+
118+
// -----
119+
120+
!mesh_1_tensor_v2 = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
121+
122+
// Regression test: chained fragments where both producers have duplicated
123+
// return values. F1 feeds into F2, and both F1's result and F2's result
124+
// appear twice in the return. The pass must uniquify all duplicates without
125+
// leaving any behind.
126+
// CHECK-LABEL: func @chained_fragments_duplicate_returns
127+
func.func @chained_fragments_duplicate_returns(
128+
%arg0: !mesh_1_tensor_v2) -> (
129+
!mesh_1_tensor_v2, !mesh_1_tensor_v2, !mesh_1_tensor_v2, !mesh_1_tensor_v2
130+
) attributes {
131+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
132+
} {
133+
// CHECK: mpmd.fragment<mesh="m1"
134+
// CHECK: mpmd.fragment<mesh="m1"
135+
// CHECK: return
136+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
137+
%a = stablehlo.add %arg1, %arg1 : tensor<4xf32>
138+
mpmd.return %a : tensor<4xf32>
139+
} : (!mesh_1_tensor_v2) -> !mesh_1_tensor_v2
140+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
141+
%b = stablehlo.multiply %arg1, %arg1 : tensor<4xf32>
142+
mpmd.return %b : tensor<4xf32>
143+
} : (!mesh_1_tensor_v2) -> !mesh_1_tensor_v2
144+
return %0, %0, %1, %1 : !mesh_1_tensor_v2, !mesh_1_tensor_v2, !mesh_1_tensor_v2, !mesh_1_tensor_v2
145+
}

0 commit comments

Comments
 (0)