
Commit 4c20ef1

nefainl and NefAI authored
fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg (#17611)
## Summary

Fixes #16032

This PR fixes `MemoryFormatOpsPass` to correctly handle `torch.preserve_format` semantics for `clone()` and `_to_copy.default` operations.

**Root cause:** When `clone()` or `_to_copy` is called without an explicit `memory_format` kwarg, the pass was defaulting to `torch.contiguous_format`, causing the output `dim_order` to be `[0,1,2,3]` (contiguous) even when the input was channels-last `[0,2,3,1]`. This caused runtime assertion failures:

```
Code=18 InvalidArgument: tensors_have_same_dim_order(self, out)
```

**Fix:** Change the default from `torch.contiguous_format` to `torch.preserve_format`, and derive `dim_order` from the input tensor's `dim_order()` when `preserve_format` is used. This is a minimal, focused fix following the guidance from @GregoryComer in the discussion on PR #17463.

## Changes

- **`exir/passes/memory_format_ops_pass.py`** (+29/-5 lines):
  - Default `memory_format` to `torch.preserve_format` instead of `torch.contiguous_format`
  - When `preserve_format` is used, derive `dim_order` from `input_tensor.dim_order()`
  - Fall back to contiguous if no input tensor is available (e.g., `empty()`)
- **`exir/tests/test_passes.py`** (+130 lines):
  - `test_clone_no_kwarg_preserves_channels_last_dim_order`: Core repro case for #16032
  - `test_clone_contiguous_format_kwarg_stays_contiguous`: Regression guard
  - `test_to_copy_no_kwarg_preserves_channels_last_dim_order`: Verifies the `_to_copy.default` path

## Standalone Reproduction

```python
import torch
from torch.export import export
from executorch.exir import to_edge, EdgeCompileConfig

class ConvClone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x).clone()

model = ConvClone().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

exported = export(model, (x,))
edge = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Before fix: clone node has dim_order=(0,1,2,3) - BUG
# After fix:  clone node has dim_order=(0,2,3,1) - CORRECT
for node in edge.exported_program().graph_module.graph.nodes:
    if "_clone_dim_order" in str(node.target):
        print(f"clone dim_order: {tuple(node.meta['val'].dim_order())}")
```

## Test Plan

- [x] All 3 new tests pass
- [x] Verified fix with standalone reproduction script
- [x] No changes to existing tests required

## Related

- Fixes #16032
- Supersedes #17463 (this is the minimal fix extracted from that PR per reviewer feedback)

---------

Co-authored-by: NefAI <info@nefai.nl>
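To make the preserve_format semantics concrete, here is a minimal eager-mode sketch (not from the commit; it assumes a PyTorch build that exposes the prototype `Tensor.dim_order()` API, which the reproduction above also uses) showing the behavior the pass is meant to match:

```python
import torch

# Channels-last 4D input: dim_order() reports (0, 2, 3, 1).
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
print(x.dim_order())  # (0, 2, 3, 1)

# Eager clone() with no memory_format kwarg defaults to preserve_format,
# so the channels-last layout survives the copy.
y = x.clone()
print(y.dim_order())  # (0, 2, 3, 1)

# An explicit contiguous clone transforms the layout instead; (0, 1, 2, 3)
# is the dim_order the pass previously assumed for every clone.
z = x.clone(memory_format=torch.contiguous_format)
print(z.dim_order())  # (0, 1, 2, 3)
```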
1 parent 584ef68 · commit 4c20ef1

2 files changed: 181 additions & 6 deletions

exir/passes/memory_format_ops_pass.py (26 additions & 6 deletions)

```diff
@@ -6,6 +6,7 @@
 
 import copy
 import logging
+from typing import List, Optional
 
 import torch
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -40,22 +41,41 @@ def call_operator(self, op, args, kwargs, meta):
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
-        # get the target memory format for the EdgeOp
-        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
+        # Get the target memory format for the EdgeOp, defaulting to
+        # preserve_format (clone() with no memory_format kwarg preserves
+        # the input's layout instead of forcing contiguous).
+        mem_format = nkwargs.pop("memory_format", torch.preserve_format)
 
-        # can always get the shape, assuming rank is specialized
+        # Get input tensor and ndim
+        input_tensor: Optional[torch.Tensor] = None
         if isinstance(args[0], ProxyValue) and args[0].is_tensor():
-            ndim = args[0].to_tensor().dim()
+            input_tensor = args[0].to_tensor()
+            ndim = input_tensor.dim()
         elif isinstance(args[0], torch.Tensor):
-            ndim = args[0].dim()
+            input_tensor = args[0]
+            ndim = input_tensor.dim()
         elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
             ndim = len(args[0])
         else:
             assert (
                 0
             ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"
 
-        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
+        # Derive dim_order based on memory format
+        dim_order: List[int]
+        if mem_format in (None, torch.preserve_format):
+            # preserve_format: inherit dim_order from input tensor
+            if input_tensor is not None:
+                dim_order = [int(d) for d in input_tensor.dim_order()]
+            else:
+                # Fallback to contiguous if no single input tensor is available
+                # (e.g. list inputs like torch.stack).
+                dim_order = list(range(ndim))
+        else:
+            # Explicit memory format (contiguous_format, channels_last, etc.)
+            dim_order = get_dim_order(mem_format, ndim)
+
+        nkwargs["dim_order"] = dim_order
         logger.debug(
             f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
             f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
```
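For readers skimming the diff, the new branch reduces to roughly the following standalone sketch. The `derive_dim_order` helper is hypothetical (named here for illustration only), and the pass's real `get_dim_order` utility is stubbed for the 4D contiguous/channels-last cases:

```python
from typing import List, Optional

import torch


def derive_dim_order(
    mem_format: Optional[torch.memory_format],
    ndim: int,
    input_tensor: Optional[torch.Tensor],
) -> List[int]:
    """Simplified sketch of the dim_order derivation added in this commit."""
    if mem_format in (None, torch.preserve_format):
        if input_tensor is not None:
            # Inherit the input's dim_order, e.g. (0, 2, 3, 1) for channels-last.
            # Requires the prototype Tensor.dim_order() API.
            return [int(d) for d in input_tensor.dim_order()]
        # No single input tensor (e.g. list inputs): fall back to contiguous.
        return list(range(ndim))
    # Explicit memory format: stub of the pass's get_dim_order() utility.
    if mem_format == torch.channels_last and ndim == 4:
        return [0, 2, 3, 1]
    return list(range(ndim))


x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
print(derive_dim_order(None, 4, x))                     # [0, 2, 3, 1]
print(derive_dim_order(torch.contiguous_format, 4, x))  # [0, 1, 2, 3]
```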

exir/tests/test_passes.py (155 additions & 0 deletions)

```diff
@@ -2502,6 +2502,161 @@ def test_convert_constant_dim_order_to_contiguous(self):
         )
 
 
+class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase):
+    """
+    Tests for MemoryFormatOpsPass preserve_format semantics.
+    """
+
+    def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None:
+        """
+        Verify that clone() on a channels-last input with no memory_format kwarg
+        produces a _clone_dim_order node with channels-last dim_order (0,2,3,1).
+        """
+
+        class ConvClone(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
+
+            def forward(self, x):
+                return self.conv(x).clone()
+
+        model = ConvClone().to(memory_format=torch.channels_last)
+        x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
+
+        # Run the model and verify that the output tensor preserves channels-last
+        # layout when no memory_format kwarg is provided.
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(memory_format=torch.channels_last),
+            f"clone() without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
+        )
+
+        ep = torch.export.export(model, (x,))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+
+        # Find the _clone_dim_order node and check its dim_order
+        found_clone = False
+        for node in edge.exported_program().graph_module.graph.nodes:
+            if node.op == "call_function" and "_clone_dim_order" in str(node.target):
+                found_clone = True
+                spec = node.meta.get("val")
+                self.assertIsNotNone(spec, "Clone node should have meta['val']")
+                dim_order = tuple(spec.dim_order())
+                self.assertEqual(
+                    dim_order,
+                    (0, 2, 3, 1),
+                    f"Clone should preserve channels-last dim_order, got {dim_order}",
+                )
+                break
+
+        self.assertTrue(found_clone, "Should find a _clone_dim_order node in the graph")
+
+    def test_clone_contiguous_format_kwarg_stays_contiguous(self) -> None:
+        """
+        Regression guard: explicit contiguous_format should produce contiguous dim_order.
+
+        Note: When clone(memory_format=contiguous_format) is called on a channels-last
+        input, this is a layout-transforming operation. After export, this typically
+        lowers to _to_dim_order_copy (not _clone_dim_order) because it changes the
+        memory layout. We check for both node types to be robust.
+        """
+
+        class CloneContiguousModel(torch.nn.Module):
+            def forward(self, x):
+                return x.clone(memory_format=torch.contiguous_format)
+
+        model = CloneContiguousModel()
+        x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
+
+        # Run the model and verify that the explicit contiguous_format kwarg
+        # produces a contiguous output layout (not channels-last).
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(),
+            f"clone(memory_format=contiguous_format) should produce contiguous layout, got strides {y.stride()}",
+        )
+        self.assertFalse(
+            y.is_contiguous(memory_format=torch.channels_last),
+            "clone(memory_format=contiguous_format) should not preserve channels-last layout",
+        )
+
+        ep = torch.export.export(model, (x,))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+
+        # Find the dim_order copy node and check its dim_order.
+        # This may be _to_dim_order_copy (layout transform) or _clone_dim_order.
+        found_copy = False
+        for node in edge.exported_program().graph_module.graph.nodes:
+            if node.op == "call_function" and (
+                "_clone_dim_order" in str(node.target)
+                or "_to_dim_order_copy" in str(node.target)
+            ):
+                found_copy = True
+                spec = node.meta.get("val")
+                self.assertIsNotNone(spec, "Copy node should have meta['val']")
+                dim_order = tuple(spec.dim_order())
+                self.assertEqual(
+                    dim_order,
+                    (0, 1, 2, 3),
+                    f"Explicit contiguous clone should have contiguous dim_order, got {dim_order}",
+                )
+                break
+
+        self.assertTrue(
+            found_copy, "Should find a _clone_dim_order or _to_dim_order_copy node"
+        )
+
+    def test_to_copy_no_kwarg_preserves_channels_last_dim_order(self) -> None:
+        """
+        Verify that tensor.to(dtype=...) with no memory_format kwarg preserves
+        the input's dim_order (preserve_format semantics).
+
+        This tests the _to_copy.default path in MemoryFormatOpsPass.
+        """
+
+        class ToCopyModel(torch.nn.Module):
+            def forward(self, x):
+                # .to(dtype=...) with no memory_format → preserve_format semantics
+                return x.to(dtype=torch.float32)
+
+        model = ToCopyModel()
+        x = torch.randn(1, 3, 8, 8, dtype=torch.float16).to(
+            memory_format=torch.channels_last
+        )
+
+        # Run the model and verify that tensor.to(dtype=...) with no memory_format
+        # kwarg preserves channels-last layout on the output tensor.
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(memory_format=torch.channels_last),
+            f"to(dtype=...) without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
+        )
+
+        ep = torch.export.export(model, (x,))
+        edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+
+        # Find the _to_dim_order_copy node and verify it preserves channels-last
+        found_copy = False
+        for node in edge.exported_program().graph_module.graph.nodes:
+            if node.op == "call_function" and "_to_dim_order_copy" in str(node.target):
+                found_copy = True
+                spec = node.meta.get("val")
+                self.assertIsNotNone(spec, "Copy node should have meta['val']")
+                dim_order = tuple(spec.dim_order())
+                self.assertEqual(
+                    dim_order,
+                    (0, 2, 3, 1),
+                    f"to(dtype=...) should preserve channels-last dim_order, got {dim_order}",
+                )
+                break
+
+        self.assertTrue(found_copy, "Should find a _to_dim_order_copy node")
+
+
 class TestCSEPass(unittest.TestCase):
     """Tests for Common Subexpression Elimination pass."""
 
```
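To run just the new test class locally, a sketch along these lines should work, assuming the tests are importable as `executorch.exir.tests.test_passes` (module path inferred from the repo layout, not confirmed by this page):

```python
import unittest

# Load and run only the preserve_format tests added by this commit.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.exir.tests.test_passes.TestMemoryFormatOpsPassPreserveFormat"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```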
