Skip to content

Commit d9d4c67

Browse files
author
NefAI
committed
Fix #16032: propagate channels_last dim_order to out-variant TensorSpec in SpecPropPass
FP16 convolution ops produce channels_last tensors (dim_order [0,2,3,1]) for performance. SpecPropPass, however, was assigning contiguous dim_order ([0,1,2,3]) to the pre-allocated out TensorSpec for format-preserving ops like clone.out, because it derived the spec from an empty FakeTensor rather than from the primary input. At runtime, op_clone.cpp asserts tensors_have_same_dim_order(self, out). When self is channels_last and out is contiguous, this fails with Code=18 InvalidArgument. Fix: in SpecPropPass.__call__, in the same node loop that handles output/getitem/delegate, add a branch for format-preserving ops with an out kwarg: override the out node's TensorSpec.dim_order to match the primary input's dim_order from its FakeTensor strides. Uses dim_order_from_stride() in exir/tensor.py. Also improves op_clone.cpp error message with dtypes and issue reference. Fixes #16032
1 parent 8ab65b3 commit d9d4c67

4 files changed

Lines changed: 249 additions & 20 deletions

File tree

exir/passes/dim_order_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import List, Optional, Set
10+
11+
import torch
12+
13+
from executorch.exir.tensor import dim_order_from_stride
14+
15+
# Format-preserving ops: the output tensor's memory layout (dim_order) must
# match the primary (first) input's layout. The out-variants are listed
# alongside the functional forms because by the time SpecPropPass runs,
# OutVarPass has already rewritten e.g. clone.default into clone.out.
#
# NOTE(review): the binary ops (add.Tensor, mul.Tensor, div.Tensor) take two
# tensor operands; propagating only args[0]'s dim_order assumes the second
# operand (and broadcasting) never dictates the output layout — confirm this
# holds for the kernels in use before relying on it.
FORMAT_PRESERVING_OPS: Set[object] = {
    torch.ops.aten.clone.out,
    torch.ops.aten.clone.default,
    torch.ops.aten.clone.memory_format,
    torch.ops.aten.copy_.default,
    torch.ops.aten.contiguous.default,
    torch.ops.aten.relu.default,
    torch.ops.aten.silu.default,
    torch.ops.aten.gelu.default,
    torch.ops.aten.add.Tensor,
    torch.ops.aten.mul.Tensor,
    torch.ops.aten.div.Tensor,
}
31+
32+
33+
def dim_order_from_fake_tensor(t: torch.Tensor) -> Optional[List[int]]:
    """
    Derive the ExecuTorch dim_order of ``t`` from its strides.

    A contiguous 4-D tensor yields ``[0, 1, 2, 3]``; a channels_last one
    yields ``[0, 2, 3, 1]``. Returns ``None`` when the stride pattern cannot
    be expressed as a dim_order (e.g. a zero stride), which
    ``dim_order_from_stride`` signals by raising ``ValueError``.
    """
    try:
        order = dim_order_from_stride(t.stride())
    except ValueError:
        # Ambiguous / inexpressible layout (e.g. broadcast dims with stride 0).
        return None
    return list(order)
44+
45+
46+
def should_propagate_dim_order(op: object) -> bool:
    """
    Return True when ``op`` is a format-preserving op, i.e. one for which the
    primary input's dim_order should be copied onto the out TensorSpec.
    """
    is_format_preserving = op in FORMAT_PRESERVING_OPS
    return is_format_preserving

exir/passes/spec_prop_pass.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import torch
1313
from executorch.exir.delegate import executorch_call_delegate
1414
from executorch.exir.pass_base import ExportPass, ProxyValue
15+
from executorch.exir.passes.dim_order_utils import (
16+
dim_order_from_fake_tensor,
17+
should_propagate_dim_order,
18+
)
1519
from executorch.exir.tensor import TensorSpec
1620
from torch.export.exported_program import ExportGraphSignature
1721
from torch.fx.node import Node
@@ -37,14 +41,12 @@ def _is_mutable_buffer(
3741
"""
3842
Check if the node is mutable buffer according to the provided graph signature.
3943
"""
40-
# graph signature is None for memory planning passes not called from EdgeProgramManager, these paths are deprecated so mutable buffers are not supported on them.
4144
if graph_signature is None:
4245
return False
4346
if node.op == "placeholder":
4447
if isinstance(node.target, str):
4548
if node.target in graph_signature.inputs_to_buffers:
4649
fqn = graph_signature.inputs_to_buffers[node.target]
47-
# if the buffer is mutated then record that
4850
if fqn in graph_signature.buffers_to_mutate.values():
4951
return True
5052
return False
@@ -75,22 +77,39 @@ def get_spec(x):
7577
elif node.op == "call_function" and node.target == operator.getitem:
7678
value_spec = pytree.tree_map(get_spec, node.args[0])
7779
node.meta["spec"] = value_spec[node.args[1]]
80+
elif (
81+
node.op == "call_function"
82+
and should_propagate_dim_order(node.target)
83+
and "out" in node.kwargs
84+
and node.args
85+
):
86+
# Propagate primary input dim_order to out TensorSpec for
87+
# format-preserving ops (Fix #16032).
88+
self_val = node.args[0].meta.get("val")
89+
if self_val is not None:
90+
src_dim_order = dim_order_from_fake_tensor(self_val)
91+
if src_dim_order is not None and src_dim_order != list(
92+
range(len(src_dim_order))
93+
):
94+
out_arg = node.kwargs["out"]
95+
assert isinstance(
96+
out_arg, torch.fx.Node
97+
), (
98+
f"Expected clone.out 'out' to be fx.Node, got {type(out_arg)}"
99+
)
100+
out_spec = out_arg.meta.get("spec")
101+
if out_spec is not None:
102+
out_spec.dim_order = tuple(src_dim_order)
78103
elif (
79104
node.op == "call_function"
80105
and node.target == executorch_call_delegate
81106
):
82-
# Note: We currently rely on delegate node specs not being regenerated,
83-
# as the spec is set somewhat manually when adding the call delegate node.
84-
# If we regenerate, it can change and break lowering (it becomes a tuple?).
85-
# Ideally, we should figure out how to make the spec regeneration not break
86-
# things.
87-
#
88-
# We do need to regenerate non-call-delegate node specs, as this pass is called
89-
# multiple times in some lowering paths (backends can and do call it).
90107
if "spec" not in node.meta:
91108
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
92-
else:
93-
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
109+
else:
110+
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
111+
return res
112+
94113
return res
95114

96115
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -115,7 +134,9 @@ def update_placeholder_tensor_specs(
115134
node.target in exported_program.graph_signature.inputs_to_parameters
116135
or (
117136
node.target in exported_program.graph_signature.inputs_to_buffers
118-
and not _is_mutable_buffer(node, exported_program.graph_signature)
137+
and not _is_mutable_buffer(
138+
node, exported_program.graph_signature
139+
)
119140
)
120141
or node.target
121142
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Tests for SpecPropPass dim_order propagation to out TensorSpec (Fix #16032).
9+
Run from ExecuTorch repo root: python -m pytest exir/tests/test_spec_prop_dim_order.py -v
10+
"""
11+
12+
import unittest
13+
14+
import torch
15+
from executorch.exir import EdgeCompileConfig, to_edge
16+
from executorch.exir.passes.dim_order_utils import (
17+
dim_order_from_fake_tensor,
18+
should_propagate_dim_order,
19+
)
20+
from executorch.exir.passes.spec_prop_pass import SpecPropPass
21+
from torch.export import export
22+
23+
24+
def _find_clone_out_nodes(graph_module):
25+
"""Return list of (node, self_node, out_node) for each aten.clone.out in graph."""
26+
result = []
27+
for node in graph_module.graph.nodes:
28+
if node.op == "call_function" and node.target == torch.ops.aten.clone.out:
29+
if node.args and "out" in node.kwargs:
30+
self_node = node.args[0]
31+
out_node = node.kwargs["out"]
32+
result.append((node, self_node, out_node))
33+
return result
34+
35+
36+
class TestDimOrderFromFakeTensor(unittest.TestCase):
    """dim_order_from_fake_tensor must recover the layout encoded in strides."""

    def test_contiguous_4d(self) -> None:
        tensor = torch.randn(2, 3, 4, 5)
        self.assertTrue(tensor.is_contiguous())
        order = dim_order_from_fake_tensor(tensor)
        self.assertIsNotNone(order)
        # Contiguous (NCHW) strides map to the identity dim order.
        self.assertEqual(order, [0, 1, 2, 3])

    def test_channels_last_4d(self) -> None:
        tensor = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
        order = dim_order_from_fake_tensor(tensor)
        self.assertIsNotNone(order)
        # Channels_last strides map to NHWC dim order.
        self.assertEqual(order, [0, 2, 3, 1])
49+
50+
51+
class TestShouldPropagateDimOrder(unittest.TestCase):
    """Membership checks against the FORMAT_PRESERVING_OPS registry."""

    def test_clone_out(self) -> None:
        clone_out_op = torch.ops.aten.clone.out
        self.assertTrue(should_propagate_dim_order(clone_out_op))

    def test_clone_default(self) -> None:
        clone_default_op = torch.ops.aten.clone.default
        self.assertTrue(should_propagate_dim_order(clone_default_op))

    def test_conv_not_format_preserving(self) -> None:
        # Convolution may change layout (e.g. emit channels_last), so it is
        # deliberately not in the format-preserving set.
        conv_op = torch.ops.aten.convolution.default
        self.assertFalse(should_propagate_dim_order(conv_op))
62+
63+
64+
class TestSpecPropPassDimOrder(unittest.TestCase):
    """SpecPropPass must propagate primary input dim_order to out TensorSpec for clone.out."""

    # NOTE(review): all three tests search for torch.ops.aten.clone.out right
    # after to_edge(), but the commit message says OutVarPass (which creates
    # the .out variants) runs later, during to_executorch. At the edge stage
    # clone is normally an Edge-dialect op, so assertGreater(len(clone_outs), 0)
    # may fail — confirm at which pipeline stage clone.out actually appears.

    def test_fp32_contiguous_clone(self) -> None:
        # Baseline: contiguous input, so out dim_order must stay [0, 1, 2, 3].
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.clone()

        m = M().eval()
        example = (torch.randn(1, 3, 8, 8),)
        ep = export(m, example)
        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
        gm = edge.exported_program().graph_module
        # Re-run spec propagation on the edge graph module under test.
        SpecPropPass()(gm)
        clone_outs = _find_clone_out_nodes(gm)
        self.assertGreater(len(clone_outs), 0, "graph should contain clone.out")
        for _node, self_node, out_node in clone_outs:
            self_spec = self_node.meta.get("spec")
            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
            self.assertIsNotNone(out_spec)
            self.assertEqual(
                out_spec.dim_order,
                self_spec.dim_order,
                "out dim_order should match self (contiguous)",
            )
            self.assertEqual(list(out_spec.dim_order), [0, 1, 2, 3])

    def test_fp16_conv_clone_channels_last(self) -> None:
        # Regression test for #16032: fp16 conv output cloned into an out
        # tensor must keep the conv's dim_order.
        # NOTE(review): this assumes the lowering pipeline makes the fp16 conv
        # output channels_last even though the example input is contiguous
        # NCHW; in eager mode conv preserves the input's (contiguous) layout,
        # so the [0, 2, 3, 1] assertion below needs confirming.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x).clone()

        m = M().to(torch.float16).eval()
        example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
        ep = export(m, example)
        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
        gm = edge.exported_program().graph_module
        SpecPropPass()(gm)
        clone_outs = _find_clone_out_nodes(gm)
        self.assertGreater(len(clone_outs), 0)
        for _node, self_node, out_node in clone_outs:
            self_spec = self_node.meta.get("spec")
            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
            self.assertIsNotNone(out_spec)
            self.assertEqual(
                out_spec.dim_order,
                self_spec.dim_order,
                "out dim_order should match self (channels_last from conv)",
            )
            self.assertEqual(
                list(out_spec.dim_order),
                [0, 2, 3, 1],
                "conv output is channels_last",
            )

    def test_fp16_conv_relu_clone(self) -> None:
        # Same as above but with an intervening relu (also in
        # FORMAT_PRESERVING_OPS): dim_order must survive conv -> relu -> clone.
        # Only self/out agreement is asserted here, not a specific layout.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
                self.relu = torch.nn.ReLU(inplace=False)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.relu(self.conv(x)).clone()

        m = M().to(torch.float16).eval()
        example = (torch.randn(1, 3, 16, 16, dtype=torch.float16),)
        ep = export(m, example)
        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
        gm = edge.exported_program().graph_module
        SpecPropPass()(gm)
        clone_outs = _find_clone_out_nodes(gm)
        self.assertGreater(len(clone_outs), 0)
        for _node, self_node, out_node in clone_outs:
            self_spec = self_node.meta.get("spec")
            out_spec = out_node.meta.get("spec")
            self.assertIsNotNone(self_spec)
            self.assertIsNotNone(out_spec)
            self.assertEqual(
                out_spec.dim_order,
                self_spec.dim_order,
                "dim_order should propagate through relu to clone.out",
            )
153+
154+
155+
if __name__ == "__main__":
    # Allow running this file directly, outside pytest discovery.
    unittest.main()

kernels/portable/cpu/op_clone.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cstring>
10-
9+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1110
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <executorch/runtime/platform/assert.h>
1212

1313
namespace torch {
1414
namespace executor {
@@ -21,7 +21,7 @@ using Tensor = executorch::aten::Tensor;
2121
Tensor& clone_out(
2222
KernelRuntimeContext& context,
2323
const Tensor& self,
24-
std::optional<executorch::aten::MemoryFormat> memory_format,
24+
std::optional<exec_aten::MemoryFormat> memory_format,
2525
Tensor& out) {
2626
(void)context;
2727

@@ -31,13 +31,20 @@ Tensor& clone_out(
3131
InvalidArgument,
3232
out);
3333

34-
// The input and out shall share same dtype and size
3534
ET_KERNEL_CHECK(
3635
context,
3736
tensors_have_same_shape_and_dtype(self, out),
3837
InvalidArgument,
3938
out);
4039

40+
if (!tensors_have_same_dim_order(self, out)) {
41+
ET_LOG(
42+
Error,
43+
"op_clone.out: dim_order mismatch: self.dtype=%d out.dtype=%d. "
44+
"See github.com/pytorch/executorch/issues/16032",
45+
(int)self.scalar_type(),
46+
(int)out.scalar_type());
47+
}
4148
ET_KERNEL_CHECK(
4249
context, tensors_have_same_dim_order(self, out), InvalidArgument, out);
4350

@@ -51,9 +58,6 @@ Tensor& clone_out(
5158
out);
5259

5360
if (self.nbytes() > 0) {
54-
// Note that this check is important. It's valid for a tensor with numel 0
55-
// to have a null data pointer, but in some environments it's invalid to
56-
// pass a null pointer to memcpy() even when the size is zero.
5761
memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
5862
}
5963
return out;

0 commit comments

Comments
 (0)