Commit d4fa140

NefAI and cursoragent committed
fix: use explicit dim_order kwarg for _clone_dim_order output spec in SpecPropPass
For _clone_dim_order and _to_dim_order_copy, the output layout is determined by the op's dim_order argument (e.g. from clone(memory_format=channels_last)), not by the input tensor. SpecPropPass was propagating from the input's FakeTensor, so a contiguous input produced a contiguous output spec while the kernel received dim_order=[0,2,3,1], causing a runtime InvalidArgument error.

- dim_order_utils: add _is_dim_order_op_with_explicit_arg and get_explicit_output_dim_order(node) to read dim_order from kwargs
- spec_prop_pass: set the output spec dim_order from explicit kwargs when present; otherwise keep propagating from the primary input (format-preserving)

Fixes test_op_clone_dim_order_propagation and related memory format tests.

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent eb5a799 commit d4fa140

2 files changed: 62 additions & 24 deletions
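Note: a minimal eager-mode sketch of the mismatch described in the commit message above. It only illustrates the layouts involved; the _clone_dim_order lowering mentioned in the comment is paraphrased from the message, not executed here.

import torch

# Contiguous NCHW input: dim_order [0, 1, 2, 3].
x = torch.randn(1, 3, 8, 8)

# clone(memory_format=channels_last) lowers in the edge dialect to roughly
#   _clone_dim_order(x, dim_order=[0, 2, 3, 1]),
# so the output layout is dictated by that kwarg, not by x's layout.
y = x.clone(memory_format=torch.channels_last)

print(list(x.dim_order()))  # [0, 1, 2, 3] -- what the old pass copied into the output spec
print(list(y.dim_order()))  # [0, 2, 3, 1] -- what the kernel actually produces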

exir/passes/dim_order_utils.py

Lines changed: 35 additions & 0 deletions
@@ -18,6 +18,22 @@
 exir_ops = None  # type: ignore[assignment]


+def _is_dim_order_op_with_explicit_arg(op: object) -> bool:
+    """True if the op takes an explicit dim_order kwarg (e.g. _clone_dim_order, _to_dim_order_copy)."""
+    if exir_ops is None:
+        return False
+    return op in (
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    ) or (
+        hasattr(exir_ops.edge.dim_order_ops._clone_dim_order, "out")
+        and op == exir_ops.edge.dim_order_ops._clone_dim_order.out
+    ) or (
+        hasattr(exir_ops.edge.dim_order_ops._to_dim_order_copy, "out")
+        and op == exir_ops.edge.dim_order_ops._to_dim_order_copy.out
+    )
+
+
 def _format_preserving_ops() -> Set[object]:
     """Build set of format-preserving ops (aten and edge dialect)."""
     ops: Set[object] = {
@@ -61,6 +77,25 @@ def dim_order_from_fake_tensor(t: torch.Tensor) -> Optional[List[int]]:
     return None


+def get_explicit_output_dim_order(
+    node: "torch.fx.Node",
+) -> Optional[List[int]]:
+    """
+    If the node is a dim_order op (_clone_dim_order, _to_dim_order_copy) with
+    an explicit dim_order in kwargs, return it. Otherwise return None so the
+    caller can propagate from the primary input (format-preserving).
+    """
+    if not _is_dim_order_op_with_explicit_arg(node.target):
+        return None
+    dim_order_val = node.kwargs.get("dim_order") if node.kwargs else None
+    if dim_order_val is None:
+        return None
+    if isinstance(dim_order_val, (list, tuple)) and len(dim_order_val) > 0:
+        if all(isinstance(i, int) for i in dim_order_val):
+            return list(dim_order_val)
+    return None
+
+
 def should_propagate_dim_order(op: object) -> bool:
     """True if the op is format-preserving and we should propagate primary input dim_order to out."""
     return op in FORMAT_PRESERVING_OPS
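Note: for readers without ExecuTorch installed, a standalone sketch of the kwargs validation that get_explicit_output_dim_order performs. _extract_dim_order_kwarg is a hypothetical name used only here, and the op check from _is_dim_order_op_with_explicit_arg is omitted.

from typing import Any, List, Optional


def _extract_dim_order_kwarg(kwargs: Optional[dict]) -> Optional[List[int]]:
    # Accept only a non-empty list/tuple of ints; anything else returns None
    # so the caller falls back to format-preserving propagation from the input.
    dim_order_val: Any = kwargs.get("dim_order") if kwargs else None
    if dim_order_val is None:
        return None
    if isinstance(dim_order_val, (list, tuple)) and len(dim_order_val) > 0:
        if all(isinstance(i, int) for i in dim_order_val):
            return list(dim_order_val)
    return None


assert _extract_dim_order_kwarg({"dim_order": [0, 2, 3, 1]}) == [0, 2, 3, 1]
assert _extract_dim_order_kwarg({"dim_order": ()}) is None  # empty -> propagate instead
assert _extract_dim_order_kwarg({}) is None                 # no kwarg -> propagate instead
assert _extract_dim_order_kwarg(None) is None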

exir/passes/spec_prop_pass.py

Lines changed: 27 additions & 24 deletions
@@ -14,6 +14,7 @@
 from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.passes.dim_order_utils import (
     dim_order_from_fake_tensor,
+    get_explicit_output_dim_order,
     should_propagate_dim_order,
 )
 from executorch.exir.tensor import TensorSpec
@@ -85,33 +86,35 @@ def get_spec(x):
             and should_propagate_dim_order(node.target)
             and node.args
         ):
-            # Propagate primary input dim_order for format-preserving ops (Fix #16032).
-            # Handles both clone.out (out= kwarg) and clone.default (single output).
-            self_val = node.args[0].meta.get("val")
-            if self_val is not None:
-                src_dim_order = dim_order_from_fake_tensor(self_val)
-                if "out" in node.kwargs:
-                    out_arg = node.kwargs["out"]
-                    assert isinstance(
-                        out_arg, torch.fx.Node
-                    ), (
-                        f"Expected clone.out 'out' to be fx.Node, got {type(out_arg)}"
+            # Output dim_order: use explicit kwargs for dim_order ops
+            # (_clone_dim_order, _to_dim_order_copy), else propagate from input (Fix #16032).
+            if "out" in node.kwargs:
+                out_arg = node.kwargs["out"]
+                assert isinstance(
+                    out_arg, torch.fx.Node
+                ), (
+                    f"Expected clone.out 'out' to be fx.Node, got {type(out_arg)}"
+                )
+                out_spec = out_arg.meta.get("spec")
+            else:
+                out_spec = node.meta.get("spec")
+                if out_spec is None and meta_val is not None:
+                    node.meta["spec"] = pytree.tree_map(
+                        make_spec, meta_val
                     )
-                    out_spec = out_arg.meta.get("spec")
+                    out_spec = node.meta["spec"]
+            if out_spec is not None and hasattr(out_spec, "dim_order"):
+                explicit_dim_order = get_explicit_output_dim_order(node)
+                if explicit_dim_order is not None:
+                    out_spec.dim_order = tuple(explicit_dim_order)
                 else:
-                    # clone.default: ensure node has spec (ExportPass may not set it)
-                    out_spec = node.meta.get("spec")
-                    if out_spec is None and meta_val is not None:
-                        node.meta["spec"] = pytree.tree_map(
-                            make_spec, meta_val
+                    self_val = node.args[0].meta.get("val")
+                    if self_val is not None:
+                        src_dim_order = dim_order_from_fake_tensor(
+                            self_val
                         )
-                        out_spec = node.meta["spec"]
-            if (
-                out_spec is not None
-                and hasattr(out_spec, "dim_order")
-                and src_dim_order is not None
-            ):
-                out_spec.dim_order = tuple(src_dim_order)
+                        if src_dim_order is not None:
+                            out_spec.dim_order = tuple(src_dim_order)
         elif (
             node.op == "call_function"
             and node.target == executorch_call_delegate
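Note: a simplified restatement of the precedence the rewritten block applies when filling in a spec's dim_order. resolve_output_dim_order is a hypothetical helper written for illustration, not part of the patch.

from typing import List, Optional, Tuple


def resolve_output_dim_order(
    explicit_dim_order: Optional[List[int]],
    input_dim_order: Optional[List[int]],
) -> Optional[Tuple[int, ...]]:
    # An explicit dim_order kwarg on the node (e.g. _clone_dim_order) wins;
    # otherwise the op is format-preserving and the spec inherits the primary
    # input's dim_order; if neither is known, leave the spec untouched.
    if explicit_dim_order is not None:
        return tuple(explicit_dim_order)
    if input_dim_order is not None:
        return tuple(input_dim_order)
    return None


# clone(memory_format=channels_last) on a contiguous NCHW input:
assert resolve_output_dim_order([0, 2, 3, 1], [0, 1, 2, 3]) == (0, 2, 3, 1)
# plain format-preserving clone with no explicit dim_order:
assert resolve_output_dim_order(None, [0, 1, 2, 3]) == (0, 1, 2, 3)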
