@@ -34,6 +34,31 @@ def __init__(self, tfa_pass=False, *args, **kwargs):
3434 def _skip_pass (self , input_tensors : list ) -> bool :
3535 return False
3636
37+ def _get_input_tensors (self , node : torch .fx .Node ) -> list :
38+ input_tensors = []
39+ for arg in node .args :
40+ if hasattr (arg , "meta" ):
41+ input_tensors .append (arg .meta ["val" ]) # type: ignore[union-attr]
42+ elif isinstance (arg , int ):
43+ input_tensors .append (arg )
44+ return input_tensors
45+
46+ def _get_placeholder_map (
47+ self ,
48+ node : torch .fx .Node ,
49+ decomposed_module : torch .fx .GraphModule ,
50+ ) -> dict [str , torch .fx .Node ]:
51+ # Keep decomposed_module in the hook signature so subclasses can inspect
52+ # traced placeholder structure when the mapping is not one-to-one.
53+ name_to_input_tensor_map = {}
54+ for i , arg in enumerate (node .args ):
55+ name_to_input_tensor_map [f"arg{ i } _1" ] = arg
56+ return name_to_input_tensor_map # type: ignore[return-value]
57+
58+ def _get_output_node (self , output_node : torch .fx .Node ) -> torch .fx .Node :
59+ """Return the traced value node for graphs that emit output(node)."""
60+ return output_node .args [0 ] # type: ignore[return-value]
61+
3762 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult : # noqa: C901
3863 modified = False
3964 for node in graph_module .graph .nodes :
@@ -44,13 +69,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
4469 ):
4570 continue
4671
47- input_tensors = []
48- for arg in node .args :
49- if hasattr (arg , "meta" ):
50- input_tensors .append (arg .meta ["val" ])
51-
52- elif isinstance (arg , int ):
53- input_tensors .append (arg )
72+ input_tensors = self ._get_input_tensors (node )
5473
5574 if self ._skip_pass (input_tensors ):
5675 continue
@@ -70,22 +89,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
7089 )(* input_tensors )
7190
7291 with graph_module .graph .inserting_before (node ):
73- name_to_input_tensor_map = {}
74- for i , arg in enumerate ( node . args ):
75- name_to_input_tensor_map [ f"arg { i } _1" ] = arg
92+ name_to_input_tensor_map = self . _get_placeholder_map (
93+ node , decomposed_module
94+ )
7695
7796 decomposed_node_to_subgraph_node = {}
7897 last_decomposed_node = None
7998 # Create a mapping from input nodes in decomposed module to original nodes.
8099 # In decomposed module, there are only input tensors for placeholder op.
81100 for decomposed_node in decomposed_module .graph .nodes :
82101 if decomposed_node .op == "placeholder" :
102+ # Some ops, such as einsum, trace extra placeholders that do
103+ # not map back to original graph tensor inputs.
104+ if decomposed_node .name not in name_to_input_tensor_map :
105+ continue
83106 decomposed_node_to_subgraph_node [decomposed_node ] = (
84107 name_to_input_tensor_map [decomposed_node .name ]
85108 )
86109
87110 if decomposed_node .op == "output" :
88- last_decomposed_node = decomposed_node . args [ 0 ]
111+ last_decomposed_node = self . _get_output_node ( decomposed_node )
89112
90113 # Copy node from decompose graph module
91114 for decomposed_node in decomposed_module .graph .nodes :
0 commit comments