
Commit 25bb325

NefAI and cursoragent committed
fix(#16032): handle clone.default and edge ops, fix tests and spec init
- dim_order_utils: add edge clone ops, guard aten.clone.memory_format
- spec_prop_pass: propagate dim_order for clone.default; ensure spec set when missing
- test_spec_prop_dim_order: use pass result graph, find edge/aten clone, channels_last test

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent d9d4c67 commit 25bb325

3 files changed

Lines changed: 112 additions & 56 deletions


exir/passes/dim_order_utils.py

Lines changed: 31 additions & 13 deletions
@@ -12,22 +12,40 @@
 
 from executorch.exir.tensor import dim_order_from_stride
 
+try:
+    from executorch.exir.dialects._ops import ops as exir_ops
+except ImportError:
+    exir_ops = None  # type: ignore[assignment]
+
+
+def _format_preserving_ops() -> Set[object]:
+    """Build set of format-preserving ops (aten and edge dialect)."""
+    ops: Set[object] = {
+        torch.ops.aten.clone.out,
+        torch.ops.aten.clone.default,
+        torch.ops.aten.copy_.default,
+        torch.ops.aten.contiguous.default,
+        torch.ops.aten.relu.default,
+        torch.ops.aten.silu.default,
+        torch.ops.aten.gelu.default,
+        torch.ops.aten.add.Tensor,
+        torch.ops.aten.mul.Tensor,
+        torch.ops.aten.div.Tensor,
+    }
+    if hasattr(torch.ops.aten.clone, "memory_format"):
+        ops.add(torch.ops.aten.clone.memory_format)
+    if exir_ops is not None:
+        ops.add(exir_ops.edge.aten.clone.default)
+        ops.add(exir_ops.edge.dim_order_ops._clone_dim_order.default)
+        if hasattr(exir_ops.edge.dim_order_ops._clone_dim_order, "out"):
+            ops.add(exir_ops.edge.dim_order_ops._clone_dim_order.out)
+    return ops
+
+
 # Format-preserving ops: output layout must match primary input. Include out-variants
 # because when SpecPropPass runs, OutVarPass has already converted e.g. clone.default
 # to clone.out.
-FORMAT_PRESERVING_OPS: Set[object] = {
-    torch.ops.aten.clone.out,
-    torch.ops.aten.clone.default,
-    torch.ops.aten.clone.memory_format,
-    torch.ops.aten.copy_.default,
-    torch.ops.aten.contiguous.default,
-    torch.ops.aten.relu.default,
-    torch.ops.aten.silu.default,
-    torch.ops.aten.gelu.default,
-    torch.ops.aten.add.Tensor,
-    torch.ops.aten.mul.Tensor,
-    torch.ops.aten.div.Tensor,
-}
+FORMAT_PRESERVING_OPS: Set[object] = _format_preserving_ops()
 
 
 def dim_order_from_fake_tensor(t: torch.Tensor) -> Optional[List[int]]:
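For context, a minimal, self-contained sketch (not part of this commit) of what "dim_order" means in these specs: the permutation of dimensions ordered from largest to smallest stride. A contiguous NCHW tensor yields (0, 1, 2, 3), while a channels_last one yields (0, 2, 3, 1), which is the value the channels_last test below asserts. The helper name dim_order_of is hypothetical and only illustrates the concept.

import torch

# Hypothetical helper, for illustration only: derive a tensor's dim_order by sorting
# dimensions from largest to smallest stride (ties broken by dimension index).
def dim_order_of(t: torch.Tensor) -> tuple:
    return tuple(sorted(range(t.dim()), key=lambda d: (-t.stride(d), d)))

x = torch.randn(1, 3, 8, 8)
print(dim_order_of(x))                                        # (0, 1, 2, 3) - contiguous
print(dim_order_of(x.to(memory_format=torch.channels_last)))  # (0, 2, 3, 1) - channels_last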

exir/passes/spec_prop_pass.py

Lines changed: 20 additions & 8 deletions
@@ -72,6 +72,9 @@ def get_spec(x):
         if isinstance(module, torch.fx.GraphModule):
             for node in module.graph.nodes:
                 meta_val = node.meta.get("val", None)
+                # Ensure every node with val has a spec (base ExportPass may not set it).
+                if "spec" not in node.meta and meta_val is not None:
+                    node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
                 if node.op == "output":
                     node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
                 elif node.op == "call_function" and node.target == operator.getitem:
@@ -80,26 +83,35 @@ def get_spec(x):
                 elif (
                     node.op == "call_function"
                     and should_propagate_dim_order(node.target)
-                    and "out" in node.kwargs
                     and node.args
                 ):
-                    # Propagate primary input dim_order to out TensorSpec for
-                    # format-preserving ops (Fix #16032).
+                    # Propagate primary input dim_order for format-preserving ops (Fix #16032).
+                    # Handles both clone.out (out= kwarg) and clone.default (single output).
                     self_val = node.args[0].meta.get("val")
                     if self_val is not None:
                         src_dim_order = dim_order_from_fake_tensor(self_val)
-                        if src_dim_order is not None and src_dim_order != list(
-                            range(len(src_dim_order))
-                        ):
+                        if "out" in node.kwargs:
                             out_arg = node.kwargs["out"]
                             assert isinstance(
                                 out_arg, torch.fx.Node
                             ), (
                                 f"Expected clone.out 'out' to be fx.Node, got {type(out_arg)}"
                             )
                             out_spec = out_arg.meta.get("spec")
-                            if out_spec is not None:
-                                out_spec.dim_order = tuple(src_dim_order)
+                        else:
+                            # clone.default: ensure node has spec (ExportPass may not set it)
+                            out_spec = node.meta.get("spec")
+                            if out_spec is None and meta_val is not None:
+                                node.meta["spec"] = pytree.tree_map(
+                                    make_spec, meta_val
+                                )
+                                out_spec = node.meta["spec"]
+                        if (
+                            out_spec is not None
+                            and hasattr(out_spec, "dim_order")
+                            and src_dim_order is not None
+                        ):
+                            out_spec.dim_order = tuple(src_dim_order)
                 elif (
                     node.op == "call_function"
                     and node.target == executorch_call_delegate
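The pass writes node.meta["spec"] through pytree.tree_map(make_spec, meta_val) because a node's "val" may be a single FakeTensor or a nested container of them (multi-output ops). A small runnable sketch of that mapping behaviour; toy_make_spec is a stand-in, not the pass's real spec constructor:

import torch
from torch.utils import _pytree as pytree

# Stand-in for make_spec, assumed for illustration: record only the shape of each leaf.
def toy_make_spec(v):
    return v.shape if isinstance(v, torch.Tensor) else v

single = torch.empty(2, 3)
multi = (torch.empty(2, 3), [torch.empty(4)])

print(pytree.tree_map(toy_make_spec, single))  # torch.Size([2, 3])
print(pytree.tree_map(toy_make_spec, multi))   # (torch.Size([2, 3]), [torch.Size([4])])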

exir/tests/test_spec_prop_dim_order.py

Lines changed: 61 additions & 35 deletions
@@ -13,6 +13,7 @@
 
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.passes.dim_order_utils import (
     dim_order_from_fake_tensor,
     should_propagate_dim_order,
@@ -21,15 +22,39 @@
 from torch.export import export
 
 
-def _find_clone_out_nodes(graph_module):
-    """Return list of (node, self_node, out_node) for each aten.clone.out in graph."""
+# Clone ops that may appear in the graph: aten (pre-OpReplacePass) or edge (after to_edge).
+_CLONE_OPS = (
+    torch.ops.aten.clone.default,
+    torch.ops.aten.clone.out,
+    exir_ops.edge.aten.clone.default,
+    exir_ops.edge.dim_order_ops._clone_dim_order.default,
+)
+if hasattr(exir_ops.edge.dim_order_ops._clone_dim_order, "out"):
+    _CLONE_OPS = _CLONE_OPS + (exir_ops.edge.dim_order_ops._clone_dim_order.out,)
+
+
+def _find_clone_nodes(graph_module):
+    """
+    Return list of (node, self_node, output_spec) for each clone in graph.
+    to_edge uses edge ops (edge.aten.clone or edge.dim_order_ops._clone_dim_order).
+    output_spec is node.meta['spec'] for single-output, or out_node.meta['spec'] for .out.
+    """
     result = []
     for node in graph_module.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.clone.out:
-            if node.args and "out" in node.kwargs:
-                self_node = node.args[0]
-                out_node = node.kwargs["out"]
-                result.append((node, self_node, out_node))
+        if node.op != "call_function":
+            continue
+        if node.target not in _CLONE_OPS:
+            continue
+        if not node.args:
+            continue
+        self_node = node.args[0]
+        if "out" in node.kwargs:
+            out_node = node.kwargs["out"]
+            output_spec = out_node.meta.get("spec") if isinstance(out_node, torch.fx.Node) else None
+        else:
+            output_spec = node.meta.get("spec")
+        if output_spec is not None:
+            result.append((node, self_node, output_spec))
     return result
 
 
@@ -72,22 +97,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         m = M().eval()
         example = (torch.randn(1, 3, 8, 8),)
         ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
         gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0, "graph should contain clone.out")
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0, "graph should contain clone")
+        for _node, self_node, output_spec in clone_nodes:
             self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
             self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
             self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                 self_spec.dim_order,
                 "out dim_order should match self (contiguous)",
             )
-            self.assertEqual(list(out_spec.dim_order), [0, 1, 2, 3])
+            self.assertEqual(list(output_spec.dim_order), [0, 1, 2, 3])
 
     def test_fp16_conv_clone_channels_last(self) -> None:
         class M(torch.nn.Module):
@@ -96,28 +121,29 @@ def __init__(self) -> None:
                 self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return self.conv(x).clone()
+                # Explicit channels_last so the traced FakeTensor has channels_last strides.
+                return self.conv(x).to(memory_format=torch.channels_last).clone()
 
         m = M().to(torch.float16).eval()
         example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
         ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
        gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0)
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0)
+        for _node, self_node, output_spec in clone_nodes:
             self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
             self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
             self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                 self_spec.dim_order,
                 "out dim_order should match self (channels_last from conv)",
             )
             self.assertEqual(
-                list(out_spec.dim_order),
+                list(output_spec.dim_order),
                 [0, 2, 3, 1],
                 "conv output is channels_last",
             )
@@ -135,20 +161,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         m = M().to(torch.float16).eval()
         example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
         ep = export(m, example)
-        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=True))
         gm = edge.exported_program().graph_module
-        SpecPropPass()(gm)
-        clone_outs = _find_clone_out_nodes(gm)
-        self.assertGreater(len(clone_outs), 0)
-        for _node, self_node, out_node in clone_outs:
+        pass_result = SpecPropPass()(gm)
+        gm = pass_result.graph_module
+        clone_nodes = _find_clone_nodes(gm)
+        self.assertGreater(len(clone_nodes), 0)
+        for _node, self_node, output_spec in clone_nodes:
             self_spec = self_node.meta.get("spec")
-            out_spec = out_node.meta.get("spec")
             self.assertIsNotNone(self_spec)
-            self.assertIsNotNone(out_spec)
+            self.assertIsNotNone(output_spec)
             self.assertEqual(
-                out_spec.dim_order,
+                output_spec.dim_order,
                 self_spec.dim_order,
-                "dim_order should propagate through relu to clone.out",
+                "dim_order should propagate through relu to clone",
            )
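The tests now read the graph from the pass result rather than assuming in-place mutation. A hedged sketch of the underlying torch.fx pass convention that motivates this; toy_pass is illustrative and is not SpecPropPass:

import torch
from torch.fx.passes.infra.pass_base import PassResult

# Illustrative pass (assumption, not SpecPropPass): tags every node, then returns a
# PassResult. A real pass may retrace and hand back a *different* GraphModule, so
# callers should always read node.meta from result.graph_module, as the tests above do.
def toy_pass(gm: torch.fx.GraphModule) -> PassResult:
    for node in gm.graph.nodes:
        node.meta["tagged"] = True
    return PassResult(gm, True)

gm = torch.fx.symbolic_trace(torch.nn.ReLU())
result = toy_pass(gm)
gm = result.graph_module  # use the module the pass returned
assert all("tagged" in node.meta for node in gm.graph.nodes)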
