Skip to content

Commit 1c69817

Browse files
committed
up
1 parent e0e10cc commit 1c69817

2 files changed

Lines changed: 176 additions & 2 deletions

File tree

exir/backend/test/test_lowered_backend_module.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import operator
78
import unittest
89

910
import executorch.exir.tests.models as models
@@ -205,3 +206,82 @@ def forward(self, *args):
205206

206207
program = nested_lowered_model.program()
207208
self.validate_lowered_module_program(program)
209+
210+
def test_arrange_graph_outputs_reorders_mutations_before_user_outputs(self):
    """
    Exercise arrange_graph_outputs directly: a submodule whose output
    tuple interleaves USER_OUTPUT and BUFFER_MUTATION entries must be
    rewritten so every mutation precedes every user output, and the
    parent graph's getitem indices must be remapped to match.
    """
    from executorch.exir.lowered_backend_module import arrange_graph_outputs
    from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument

    # Submodule graph emitting [user, mutation, user]; the expected
    # post-pass ordering is [mutation, user, user].
    graph = torch.fx.Graph()
    inp = graph.placeholder("x")
    buf_in = graph.placeholder("buf")
    user_a = graph.call_function(torch.ops.aten.add.Tensor, (inp, inp))
    mutated = graph.call_function(torch.ops.aten.mul.Tensor, (buf_in, inp))
    user_b = graph.call_function(torch.ops.aten.sub.Tensor, (inp, inp))
    graph.output((user_a, mutated, user_b))
    lowered_gm = torch.fx.GraphModule({}, graph)

    specs = [
        OutputSpec(
            kind=OutputKind.USER_OUTPUT,
            arg=TensorArgument(name="add"),
            target=None,
        ),
        OutputSpec(
            kind=OutputKind.BUFFER_MUTATION,
            arg=TensorArgument(name="mul"),
            target="buf",
        ),
        OutputSpec(
            kind=OutputKind.USER_OUTPUT,
            arg=TensorArgument(name="sub"),
            target=None,
        ),
    ]

    # Parent graph: one call_module feeding three getitem extractions.
    outer = torch.fx.Graph()
    outer_x = outer.placeholder("x")
    submod_call = outer.call_module("sub_mod", (outer_x,))
    get0 = outer.call_function(operator.getitem, (submod_call, 0))
    get1 = outer.call_function(operator.getitem, (submod_call, 1))
    get2 = outer.call_function(operator.getitem, (submod_call, 2))
    outer.output((get0, get1, get2))

    arrange_graph_outputs(lowered_gm, specs, submod_call)

    # Specs: the mutation is hoisted to the front, relative user order kept.
    self.assertEqual(specs[0].kind, OutputKind.BUFFER_MUTATION)
    self.assertEqual(specs[1].kind, OutputKind.USER_OUTPUT)
    self.assertEqual(specs[2].kind, OutputKind.USER_OUTPUT)
    self.assertEqual(specs[0].target, "buf")

    # The submodule output tuple mirrors the same permutation.
    out_node = next(n for n in lowered_gm.graph.nodes if n.op == "output")
    reordered = list(out_node.args[0])
    self.assertIs(reordered[0], mutated)  # buffer mutation first
    self.assertIs(reordered[1], user_a)  # then user outputs, in order
    self.assertIs(reordered[2], user_b)

    # Parent getitem indices follow their values to the new positions:
    # old 0 (user) -> 1, old 1 (mutation) -> 0, old 2 (user) -> 2.
    self.assertEqual(get0.args[1], 1)
    self.assertEqual(get1.args[1], 0)
    self.assertEqual(get2.args[1], 2)

exir/lowered_backend_module.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,97 @@ def arrange_graph_placeholders(
447447
return gm
448448

449449

450+
def arrange_graph_outputs(
451+
gm: torch.fx.GraphModule,
452+
output_specs: List[OutputSpec],
453+
call_module_node: torch.fx.Node,
454+
) -> torch.fx.GraphModule:
455+
"""
456+
Reorders the output tuple of the graph so that buffer mutation outputs come
457+
before user outputs, matching the ordering that ExportedProgram's verifier
458+
expects: [buffer_mutations..., user_outputs...].
459+
460+
The partitioner may produce a submodule whose output tuple has buffer
461+
mutations and user outputs interleaved in arbitrary order. The verifier
462+
determines which outputs are mutations by position (first N outputs where
463+
N = number of mutation specs), so a misordered tuple causes a
464+
SpecViolationError.
465+
466+
This function builds a permutation from the output_specs (which
467+
_get_new_signature already classified correctly) and rewrites the graph's
468+
output node to match. It also remaps getitem indices on the parent
469+
graph's call_module_node so the parent continues to extract the correct
470+
outputs.
471+
472+
Args:
473+
gm: The graph module whose output ordering may need adjustment.
474+
output_specs: The output specs built by _get_new_signature, with
475+
correct kind annotations but potentially mismatched ordering
476+
relative to the graph's output tuple.
477+
call_module_node: The call_module node in the parent graph whose
478+
getitem users need index remapping.
479+
480+
Returns:
481+
The graph module with reordered outputs (modified in-place).
482+
"""
483+
# Find the output node
484+
output_node = None
485+
for node in gm.graph.nodes:
486+
if node.op == "output":
487+
output_node = node
488+
break
489+
490+
if output_node is None or not output_node.args[0]:
491+
return gm
492+
493+
old_outputs = list(output_node.args[0])
494+
495+
if len(old_outputs) != len(output_specs):
496+
raise RuntimeError(
497+
f"Mismatch between graph outputs ({len(old_outputs)}) and "
498+
f"output_specs ({len(output_specs)}). This indicates a bug in "
499+
"_get_new_signature."
500+
)
501+
502+
# Separate indices by kind: mutations first, then user outputs
503+
mutation_indices = []
504+
user_output_indices = []
505+
for i, spec in enumerate(output_specs):
506+
if spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION):
507+
mutation_indices.append(i)
508+
else:
509+
user_output_indices.append(i)
510+
511+
new_order = mutation_indices + user_output_indices
512+
513+
# Check if already in correct order
514+
if new_order == list(range(len(old_outputs))):
515+
return gm
516+
517+
# Build reverse mapping: old_index -> new_index
518+
old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(new_order)}
519+
520+
# Reorder the output tuple in the submodule graph
521+
new_outputs = [old_outputs[i] for i in new_order]
522+
output_node.args = (tuple(new_outputs),)
523+
524+
# Reorder the output_specs to match (in-place)
525+
reordered_specs = [output_specs[i] for i in new_order]
526+
output_specs.clear()
527+
output_specs.extend(reordered_specs)
528+
529+
# Remap getitem indices in the parent graph
530+
for user in list(call_module_node.users.keys()):
531+
if user.op == "call_function" and user.target == operator.getitem:
532+
old_idx = user.args[1]
533+
if isinstance(old_idx, int) and old_idx in old_to_new:
534+
user.args = (user.args[0], old_to_new[old_idx])
535+
536+
gm.graph.lint()
537+
538+
return gm
539+
540+
450541
# TODO Don't regenerate new signature manually.
451542
def _get_new_signature( # noqa: C901
452543
original_program: ExportedProgram,
@@ -704,8 +795,6 @@ def create_exported_program_from_submodule(
704795
# Arrange the submodule's placeholders in order
705796
submodule = arrange_graph_placeholders(submodule, owning_program, tag)
706797

707-
# TODO: we probably need to arrange the outputs wrt buffer mutations.
708-
709798
# Get updated graph signature
710799
(
711800
subgraph_signature,
@@ -717,6 +806,11 @@ def create_exported_program_from_submodule(
717806
owning_program, submodule, call_module_node, tag, is_submodule
718807
)
719808

809+
# Reorder outputs: buffer mutations first, then user outputs.
810+
# The verifier expects this ordering but _get_new_signature produces
811+
# output_specs in graph order which may interleave the two kinds.
812+
arrange_graph_outputs(submodule, subgraph_signature.output_specs, call_module_node)
813+
720814
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
721815
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
722816

0 commit comments

Comments
 (0)