Skip to content

Commit 8139bdf

Browse files
committed
up
1 parent b5dbfb7 commit 8139bdf

2 files changed

Lines changed: 175 additions & 2 deletions

File tree

exir/backend/test/test_backends.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,85 @@ def inputs(self):
14461446
torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
14471447
)
14481448

1449+
def test_arrange_graph_outputs_reorders_mutations_before_user_outputs(self):
    """
    Exercise arrange_graph_outputs on a hand-built submodule/parent pair.

    The submodule emits (user, mutation, user). After the call we expect:
      * output_specs permuted so the BUFFER_MUTATION spec is first,
      * the submodule's output tuple permuted the same way,
      * every getitem hanging off the parent's call_module node pointing
        at the new position of the value it used to read.
    """
    from executorch.exir.lowered_backend_module import arrange_graph_outputs
    from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument

    aten = torch.ops.aten

    # Submodule graph. Output order is deliberately "wrong":
    #   [add (user), mul (buffer mutation), sub (user)]
    # and should become:
    #   [mul (buffer mutation), add (user), sub (user)]
    sub_graph = torch.fx.Graph()
    x = sub_graph.placeholder("x")
    buf = sub_graph.placeholder("buf")
    add_node = sub_graph.call_function(aten.add.Tensor, (x, x))
    mul_node = sub_graph.call_function(aten.mul.Tensor, (buf, x))
    sub_node = sub_graph.call_function(aten.sub.Tensor, (x, x))
    sub_graph.output((add_node, mul_node, sub_node))
    sub_gm = torch.fx.GraphModule({}, sub_graph)

    # One spec per output, listed in the same (interleaved) graph order.
    output_specs = [
        OutputSpec(kind=kind, arg=TensorArgument(name=name), target=target)
        for kind, name, target in (
            (OutputKind.USER_OUTPUT, "add", None),
            (OutputKind.BUFFER_MUTATION, "mul", "buf"),
            (OutputKind.USER_OUTPUT, "sub", None),
        )
    ]

    # Parent graph: one call_module node with a getitem per submodule output.
    parent_graph = torch.fx.Graph()
    px = parent_graph.placeholder("x")
    call_mod = parent_graph.call_module("sub_mod", (px,))
    gi0, gi1, gi2 = (
        parent_graph.call_function(operator.getitem, (call_mod, index))
        for index in range(3)
    )
    parent_graph.output((gi0, gi1, gi2))

    arrange_graph_outputs(sub_gm, output_specs, call_mod)

    # Specs: the mutation spec must now lead.
    expected_kinds = (
        OutputKind.BUFFER_MUTATION,
        OutputKind.USER_OUTPUT,
        OutputKind.USER_OUTPUT,
    )
    for position, kind in enumerate(expected_kinds):
        self.assertEqual(output_specs[position].kind, kind)
    self.assertEqual(output_specs[0].target, "buf")

    # Submodule output tuple: mutation first, user outputs keep their
    # relative order.
    output_node = next(
        (node for node in sub_gm.graph.nodes if node.op == "output"), None
    )
    reordered = list(output_node.args[0])
    for position, expected_node in enumerate((mul_node, add_node, sub_node)):
        self.assertIs(reordered[position], expected_node)

    # Parent getitem indices follow the permutation:
    #   old 0 (user)     -> new 1
    #   old 1 (mutation) -> new 0
    #   old 2 (user)     -> new 2 (unchanged)
    for getitem_node, new_index in ((gi0, 1), (gi1, 0), (gi2, 2)):
        self.assertEqual(getitem_node.args[1], new_index)
1527+
14491528
def test_prohibited_nested_backends(self):
14501529
class MyBackend(BackendDetails):
14511530
@staticmethod

exir/lowered_backend_module.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,97 @@ def arrange_graph_placeholders(
449449
return gm
450450

451451

452+
def arrange_graph_outputs(
453+
gm: torch.fx.GraphModule,
454+
output_specs: List[OutputSpec],
455+
call_module_node: torch.fx.Node,
456+
) -> torch.fx.GraphModule:
457+
"""
458+
Reorders the output tuple of the graph so that buffer mutation outputs come
459+
before user outputs, matching the ordering that ExportedProgram's verifier
460+
expects: [buffer_mutations..., user_outputs...].
461+
462+
The partitioner may produce a submodule whose output tuple has buffer
463+
mutations and user outputs interleaved in arbitrary order. The verifier
464+
determines which outputs are mutations by position (first N outputs where
465+
N = number of mutation specs), so a misordered tuple causes a
466+
SpecViolationError.
467+
468+
This function builds a permutation from the output_specs (which
469+
_get_new_signature already classified correctly) and rewrites the graph's
470+
output node to match. It also remaps getitem indices on the parent
471+
graph's call_module_node so the parent continues to extract the correct
472+
outputs.
473+
474+
Args:
475+
gm: The graph module whose output ordering may need adjustment.
476+
output_specs: The output specs built by _get_new_signature, with
477+
correct kind annotations but potentially mismatched ordering
478+
relative to the graph's output tuple.
479+
call_module_node: The call_module node in the parent graph whose
480+
getitem users need index remapping.
481+
482+
Returns:
483+
The graph module with reordered outputs (modified in-place).
484+
"""
485+
# Find the output node
486+
output_node = None
487+
for node in gm.graph.nodes:
488+
if node.op == "output":
489+
output_node = node
490+
break
491+
492+
if output_node is None or not output_node.args[0]:
493+
return gm
494+
495+
old_outputs = list(output_node.args[0])
496+
497+
if len(old_outputs) != len(output_specs):
498+
raise RuntimeError(
499+
f"Mismatch between graph outputs ({len(old_outputs)}) and "
500+
f"output_specs ({len(output_specs)}). This indicates a bug in "
501+
"_get_new_signature."
502+
)
503+
504+
# Separate indices by kind: mutations first, then user outputs
505+
mutation_indices = []
506+
user_output_indices = []
507+
for i, spec in enumerate(output_specs):
508+
if spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION):
509+
mutation_indices.append(i)
510+
else:
511+
user_output_indices.append(i)
512+
513+
new_order = mutation_indices + user_output_indices
514+
515+
# Check if already in correct order
516+
if new_order == list(range(len(old_outputs))):
517+
return gm
518+
519+
# Build reverse mapping: old_index -> new_index
520+
old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(new_order)}
521+
522+
# Reorder the output tuple in the submodule graph
523+
new_outputs = [old_outputs[i] for i in new_order]
524+
output_node.args = (tuple(new_outputs),)
525+
526+
# Reorder the output_specs to match (in-place)
527+
reordered_specs = [output_specs[i] for i in new_order]
528+
output_specs.clear()
529+
output_specs.extend(reordered_specs)
530+
531+
# Remap getitem indices in the parent graph
532+
for user in list(call_module_node.users.keys()):
533+
if user.op == "call_function" and user.target == operator.getitem:
534+
old_idx = user.args[1]
535+
if isinstance(old_idx, int) and old_idx in old_to_new:
536+
user.args = (user.args[0], old_to_new[old_idx])
537+
538+
gm.graph.lint()
539+
540+
return gm
541+
542+
452543
# TODO Don't regenerate new signature manually.
453544
def _get_new_signature( # noqa: C901
454545
original_program: ExportedProgram,
@@ -706,8 +797,6 @@ def create_exported_program_from_submodule(
706797
# Arrange the submodule's placeholders in order
707798
submodule = arrange_graph_placeholders(submodule, owning_program, tag)
708799

709-
# TODO: we probably need to arrange the outputs wrt buffer mutations.
710-
711800
# Get updated graph signature
712801
(
713802
subgraph_signature,
@@ -719,6 +808,11 @@ def create_exported_program_from_submodule(
719808
owning_program, submodule, call_module_node, tag, is_submodule
720809
)
721810

811+
# Reorder outputs: buffer mutations first, then user outputs.
812+
# The verifier expects this ordering but _get_new_signature produces
813+
# output_specs in graph order which may interleave the two kinds.
814+
arrange_graph_outputs(submodule, subgraph_signature.output_specs, call_module_node)
815+
722816
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
723817
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
724818

0 commit comments

Comments
 (0)