@@ -447,6 +447,97 @@ def arrange_graph_placeholders(
447447 return gm
448448
449449
450+ def arrange_graph_outputs (
451+ gm : torch .fx .GraphModule ,
452+ output_specs : List [OutputSpec ],
453+ call_module_node : torch .fx .Node ,
454+ ) -> torch .fx .GraphModule :
455+ """
456+ Reorders the output tuple of the graph so that buffer mutation outputs come
457+ before user outputs, matching the ordering that ExportedProgram's verifier
458+ expects: [buffer_mutations..., user_outputs...].
459+
460+ The partitioner may produce a submodule whose output tuple has buffer
461+ mutations and user outputs interleaved in arbitrary order. The verifier
462+ determines which outputs are mutations by position (first N outputs where
463+ N = number of mutation specs), so a misordered tuple causes a
464+ SpecViolationError.
465+
466+ This function builds a permutation from the output_specs (which
467+ _get_new_signature already classified correctly) and rewrites the graph's
468+ output node to match. It also remaps getitem indices on the parent
469+ graph's call_module_node so the parent continues to extract the correct
470+ outputs.
471+
472+ Args:
473+ gm: The graph module whose output ordering may need adjustment.
474+ output_specs: The output specs built by _get_new_signature, with
475+ correct kind annotations but potentially mismatched ordering
476+ relative to the graph's output tuple.
477+ call_module_node: The call_module node in the parent graph whose
478+ getitem users need index remapping.
479+
480+ Returns:
481+ The graph module with reordered outputs (modified in-place).
482+ """
483+ # Find the output node
484+ output_node = None
485+ for node in gm .graph .nodes :
486+ if node .op == "output" :
487+ output_node = node
488+ break
489+
490+ if output_node is None or not output_node .args [0 ]:
491+ return gm
492+
493+ old_outputs = list (output_node .args [0 ])
494+
495+ if len (old_outputs ) != len (output_specs ):
496+ raise RuntimeError (
497+ f"Mismatch between graph outputs ({ len (old_outputs )} ) and "
498+ f"output_specs ({ len (output_specs )} ). This indicates a bug in "
499+ "_get_new_signature."
500+ )
501+
502+ # Separate indices by kind: mutations first, then user outputs
503+ mutation_indices = []
504+ user_output_indices = []
505+ for i , spec in enumerate (output_specs ):
506+ if spec .kind in (OutputKind .BUFFER_MUTATION , OutputKind .USER_INPUT_MUTATION ):
507+ mutation_indices .append (i )
508+ else :
509+ user_output_indices .append (i )
510+
511+ new_order = mutation_indices + user_output_indices
512+
513+ # Check if already in correct order
514+ if new_order == list (range (len (old_outputs ))):
515+ return gm
516+
517+ # Build reverse mapping: old_index -> new_index
518+ old_to_new = {old_idx : new_idx for new_idx , old_idx in enumerate (new_order )}
519+
520+ # Reorder the output tuple in the submodule graph
521+ new_outputs = [old_outputs [i ] for i in new_order ]
522+ output_node .args = (tuple (new_outputs ),)
523+
524+ # Reorder the output_specs to match (in-place)
525+ reordered_specs = [output_specs [i ] for i in new_order ]
526+ output_specs .clear ()
527+ output_specs .extend (reordered_specs )
528+
529+ # Remap getitem indices in the parent graph
530+ for user in list (call_module_node .users .keys ()):
531+ if user .op == "call_function" and user .target == operator .getitem :
532+ old_idx = user .args [1 ]
533+ if isinstance (old_idx , int ) and old_idx in old_to_new :
534+ user .args = (user .args [0 ], old_to_new [old_idx ])
535+
536+ gm .graph .lint ()
537+
538+ return gm
539+
540+
450541# TODO Don't regenerate new signature manually.
451542def _get_new_signature ( # noqa: C901
452543 original_program : ExportedProgram ,
@@ -704,8 +795,6 @@ def create_exported_program_from_submodule(
704795 # Arrange the submodule's placeholders in order
705796 submodule = arrange_graph_placeholders (submodule , owning_program , tag )
706797
707- # TODO: we probably need to arrange the outputs wrt buffer mutations.
708-
709798 # Get updated graph signature
710799 (
711800 subgraph_signature ,
@@ -717,6 +806,11 @@ def create_exported_program_from_submodule(
717806 owning_program , submodule , call_module_node , tag , is_submodule
718807 )
719808
809+ # Reorder outputs: buffer mutations first, then user outputs.
810+ # The verifier expects this ordering but _get_new_signature produces
811+ # output_specs in graph order which may interleave the two kinds.
812+ arrange_graph_outputs (submodule , subgraph_signature .output_specs , call_module_node )
813+
720814 in_spec = pytree .tree_flatten ((tuple (subgraph_signature .user_inputs ), {}))[1 ]
721815 out_spec = pytree .tree_flatten (subgraph_signature .user_outputs )[1 ]
722816
0 commit comments