@@ -449,6 +449,97 @@ def arrange_graph_placeholders(
449449 return gm
450450
451451
452+ def arrange_graph_outputs (
453+ gm : torch .fx .GraphModule ,
454+ output_specs : List [OutputSpec ],
455+ call_module_node : torch .fx .Node ,
456+ ) -> torch .fx .GraphModule :
457+ """
458+ Reorders the output tuple of the graph so that buffer mutation outputs come
459+ before user outputs, matching the ordering that ExportedProgram's verifier
460+ expects: [buffer_mutations..., user_outputs...].
461+
462+ The partitioner may produce a submodule whose output tuple has buffer
463+ mutations and user outputs interleaved in arbitrary order. The verifier
464+ determines which outputs are mutations by position (first N outputs where
465+ N = number of mutation specs), so a misordered tuple causes a
466+ SpecViolationError.
467+
468+ This function builds a permutation from the output_specs (which
469+ _get_new_signature already classified correctly) and rewrites the graph's
470+ output node to match. It also remaps getitem indices on the parent
471+ graph's call_module_node so the parent continues to extract the correct
472+ outputs.
473+
474+ Args:
475+ gm: The graph module whose output ordering may need adjustment.
476+ output_specs: The output specs built by _get_new_signature, with
477+ correct kind annotations but potentially mismatched ordering
478+ relative to the graph's output tuple.
479+ call_module_node: The call_module node in the parent graph whose
480+ getitem users need index remapping.
481+
482+ Returns:
483+ The graph module with reordered outputs (modified in-place).
484+ """
485+ # Find the output node
486+ output_node = None
487+ for node in gm .graph .nodes :
488+ if node .op == "output" :
489+ output_node = node
490+ break
491+
492+ if output_node is None or not output_node .args [0 ]:
493+ return gm
494+
495+ old_outputs = list (output_node .args [0 ])
496+
497+ if len (old_outputs ) != len (output_specs ):
498+ raise RuntimeError (
499+ f"Mismatch between graph outputs ({ len (old_outputs )} ) and "
500+ f"output_specs ({ len (output_specs )} ). This indicates a bug in "
501+ "_get_new_signature."
502+ )
503+
504+ # Separate indices by kind: mutations first, then user outputs
505+ mutation_indices = []
506+ user_output_indices = []
507+ for i , spec in enumerate (output_specs ):
508+ if spec .kind in (OutputKind .BUFFER_MUTATION , OutputKind .USER_INPUT_MUTATION ):
509+ mutation_indices .append (i )
510+ else :
511+ user_output_indices .append (i )
512+
513+ new_order = mutation_indices + user_output_indices
514+
515+ # Check if already in correct order
516+ if new_order == list (range (len (old_outputs ))):
517+ return gm
518+
519+ # Build reverse mapping: old_index -> new_index
520+ old_to_new = {old_idx : new_idx for new_idx , old_idx in enumerate (new_order )}
521+
522+ # Reorder the output tuple in the submodule graph
523+ new_outputs = [old_outputs [i ] for i in new_order ]
524+ output_node .args = (tuple (new_outputs ),)
525+
526+ # Reorder the output_specs to match (in-place)
527+ reordered_specs = [output_specs [i ] for i in new_order ]
528+ output_specs .clear ()
529+ output_specs .extend (reordered_specs )
530+
531+ # Remap getitem indices in the parent graph
532+ for user in list (call_module_node .users .keys ()):
533+ if user .op == "call_function" and user .target == operator .getitem :
534+ old_idx = user .args [1 ]
535+ if isinstance (old_idx , int ) and old_idx in old_to_new :
536+ user .args = (user .args [0 ], old_to_new [old_idx ])
537+
538+ gm .graph .lint ()
539+
540+ return gm
541+
542+
452543# TODO Don't regenerate new signature manually.
453544def _get_new_signature ( # noqa: C901
454545 original_program : ExportedProgram ,
@@ -706,8 +797,6 @@ def create_exported_program_from_submodule(
706797 # Arrange the submodule's placeholders in order
707798 submodule = arrange_graph_placeholders (submodule , owning_program , tag )
708799
709- # TODO: we probably need to arrange the outputs wrt buffer mutations.
710-
711800 # Get updated graph signature
712801 (
713802 subgraph_signature ,
@@ -719,6 +808,11 @@ def create_exported_program_from_submodule(
719808 owning_program , submodule , call_module_node , tag , is_submodule
720809 )
721810
811+ # Reorder outputs: buffer mutations first, then user outputs.
812+ # The verifier expects this ordering but _get_new_signature produces
813+ # output_specs in graph order which may interleave the two kinds.
814+ arrange_graph_outputs (submodule , subgraph_signature .output_specs , call_module_node )
815+
722816 in_spec = pytree .tree_flatten ((tuple (subgraph_signature .user_inputs ), {}))[1 ]
723817 out_spec = pytree .tree_flatten (subgraph_signature .user_outputs )[1 ]
724818
0 commit comments