import torch
from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.dim_order_utils import (
    dim_order_from_fake_tensor,
    should_propagate_dim_order,
)
from torch.export import export


-def _find_clone_out_nodes(graph_module):
-    """Return list of (node, self_node, out_node) for each aten.clone.out in graph."""
+# Clone ops that may appear in the graph: aten (pre-OpReplacePass) or edge (after to_edge).
+_CLONE_OPS = (
+    torch.ops.aten.clone.default,
+    torch.ops.aten.clone.out,
+    exir_ops.edge.aten.clone.default,
+    exir_ops.edge.dim_order_ops._clone_dim_order.default,
+)
+if hasattr(exir_ops.edge.dim_order_ops._clone_dim_order, "out"):
+    _CLONE_OPS = _CLONE_OPS + (exir_ops.edge.dim_order_ops._clone_dim_order.out,)
+
+
+def _find_clone_nodes(graph_module):
+    """
+    Return a list of (node, self_node, output_spec) for each clone node in the graph.
+    to_edge uses edge ops (edge.aten.clone or edge.dim_order_ops._clone_dim_order).
+    output_spec is node.meta["spec"] for single-output variants, or out_node.meta["spec"]
+    for the .out variant.
+    """
    result = []
    for node in graph_module.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.clone.out:
-            if node.args and "out" in node.kwargs:
-                self_node = node.args[0]
-                out_node = node.kwargs["out"]
-                result.append((node, self_node, out_node))
+        if node.op != "call_function":
+            continue
+        if node.target not in _CLONE_OPS:
+            continue
+        if not node.args:
+            continue
+        self_node = node.args[0]
+        if "out" in node.kwargs:
+            out_node = node.kwargs["out"]
+            output_spec = out_node.meta.get("spec") if isinstance(out_node, torch.fx.Node) else None
+        else:
+            output_spec = node.meta.get("spec")
+        if output_spec is not None:
+            result.append((node, self_node, output_spec))
    return result

@@ -72,22 +97,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        m = M().eval()
        example = (torch.randn(1, 3, 8, 8),)
        ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
        gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0, "graph should contain clone.out")
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0, "graph should contain clone")
+        for _node, self_node, output_spec in clone_nodes:
            self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
            self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                self_spec.dim_order,
                "out dim_order should match self (contiguous)",
            )
-            self.assertEqual(list(out_spec.dim_order), [0, 1, 2, 3])
+            self.assertEqual(list(output_spec.dim_order), [0, 1, 2, 3])

    def test_fp16_conv_clone_channels_last(self) -> None:
        class M(torch.nn.Module):
@@ -96,28 +121,29 @@ def __init__(self) -> None:
                self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return self.conv(x).clone()
+                # Explicit channels_last so the traced FakeTensor has channels_last strides.
+                return self.conv(x).to(memory_format=torch.channels_last).clone()

        m = M().to(torch.float16).eval()
        example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
        ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
        gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0)
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0)
+        for _node, self_node, output_spec in clone_nodes:
            self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
            self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                self_spec.dim_order,
                "out dim_order should match self (channels_last from conv)",
            )
            self.assertEqual(
-                list(out_spec.dim_order),
+                list(output_spec.dim_order),
                [0, 2, 3, 1],
                "conv output is channels_last",
            )
@@ -135,20 +161,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        m = M().to(torch.float16).eval()
        example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
        ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
        gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0)
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0)
+        for _node, self_node, output_spec in clone_nodes:
            self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
            self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                self_spec.dim_order,
-                "dim_order should propagate through relu to clone.out",
+                "dim_order should propagate through relu to clone",
            )

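For reference, the dim_order values asserted in these tests line up with plain PyTorch memory formats. Below is a minimal standalone sanity check (ordinary PyTorch, not part of this diff), under the assumption that dim_order is the permutation from logical dims to the order they are laid out in memory:

import torch

# channels_last NCHW strides are (C*H*W, 1, W*C, C), i.e. memory order N, H, W, C,
# which matches the dim_order [0, 2, 3, 1] asserted in the conv test.
x = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)
print(x.stride())  # (768, 1, 48, 3)
print(x.is_contiguous(memory_format=torch.channels_last))  # True

# A default contiguous tensor corresponds to dim_order [0, 1, 2, 3].
y = torch.randn(1, 3, 8, 8)
print(y.stride())  # (192, 64, 8, 1)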