29 changes: 24 additions & 5 deletions exir/passes/memory_format_ops_pass.py
@@ -6,6 +6,7 @@
 
 import copy
 import logging
+from typing import List, Optional
 
 import torch
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -41,21 +42,39 @@
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
         # get the target memory format for the EdgeOp
-        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
+        # Default to preserve_format: clone() with no memory_format kwarg should
+        # preserve the input's layout, not force contiguous. Issue #16032.
+        mem_format = nkwargs.pop("memory_format", torch.preserve_format)
 
-        # can always get the shape, assuming rank is specialized
+        # Get input tensor and ndim
+        input_tensor: Optional[torch.Tensor] = None
         if isinstance(args[0], ProxyValue) and args[0].is_tensor():
-            ndim = args[0].to_tensor().dim()
+            input_tensor = args[0].to_tensor()
+            ndim = input_tensor.dim()
         elif isinstance(args[0], torch.Tensor):
-            ndim = args[0].dim()
+            input_tensor = args[0]
+            ndim = input_tensor.dim()
         elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
             ndim = len(args[0])
         else:
             assert (
                 0
             ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"
 
-        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
+        # Derive dim_order based on memory format
+        dim_order: List[int]
+        if mem_format in (None, torch.preserve_format):
+            # preserve_format: inherit dim_order from input tensor
+            if input_tensor is not None:
+                dim_order = [int(d) for d in input_tensor.dim_order()]
+            else:
+                # Fallback to contiguous if no input tensor available
+                dim_order = list(range(ndim))
+        else:
+            # Explicit memory format (contiguous_format, channels_last, etc.)
+            dim_order = get_dim_order(mem_format, ndim)
+
+        nkwargs["dim_order"] = dim_order
         logger.debug(
             f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
             f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
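To make the derivation above concrete, here is a minimal standalone sketch of the dim_order logic this hunk introduces. Hedged: `derive_dim_order` is a hypothetical helper that mirrors the pass, not part of the ExecuTorch API, and it assumes a 4-D input for the channels_last case.

```python
import torch


def derive_dim_order(t: torch.Tensor, mem_format=None) -> list:
    # preserve_format (or no memory_format kwarg): inherit the input's dim_order.
    if mem_format in (None, torch.preserve_format):
        return [int(d) for d in t.dim_order()]
    # channels_last on a 4-D tensor corresponds to dim_order (0, 2, 3, 1).
    if mem_format == torch.channels_last:
        return [0, 2, 3, 1]
    # contiguous_format: identity dim_order.
    return list(range(t.dim()))


x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
print(derive_dim_order(x))                           # [0, 2, 3, 1]
print(derive_dim_order(x, torch.contiguous_format))  # [0, 1, 2, 3]
```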
130 changes: 130 additions & 0 deletions exir/tests/test_passes.py
@@ -2499,3 +2499,133 @@ def test_convert_constant_dim_order_to_contiguous(self):
             modified_const.is_contiguous(),
             f"Constant should be contiguous after pass, got strides {modified_const.stride()}",
         )
+
+
+class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase):
+    """
+    Tests for MemoryFormatOpsPass preserve_format semantics.
+
+    Issue #16032: clone() with no memory_format kwarg should preserve the input's
+    dim_order, not default to contiguous. This caused runtime assertion failures
+    when cloning channels-last tensors.
+    """
+
+    def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None:
+        """
+        Verify that clone() on a channels-last input with no memory_format kwarg
+        produces a _clone_dim_order node with channels-last dim_order (0,2,3,1).
+
+        This is the core reproduction case for issue #16032.
"""

class ConvClone(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

def forward(self, x):
return self.conv(x).clone()

model = ConvClone().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

ep = torch.export.export(model, (x,))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Find the _clone_dim_order node and check its dim_order
found_clone = False
for node in edge.exported_program().graph_module.graph.nodes:
if node.op == "call_function" and "_clone_dim_order" in str(node.target):
found_clone = True
spec = node.meta.get("val")
self.assertIsNotNone(spec, "Clone node should have meta['val']")
dim_order = tuple(spec.dim_order())
self.assertEqual(
dim_order,
(0, 2, 3, 1),
f"Clone should preserve channels-last dim_order, got {dim_order}",
)
break

self.assertTrue(found_clone, "Should find a _clone_dim_order node in the graph")

def test_clone_contiguous_format_kwarg_stays_contiguous(self) -> None:
"""
Regression guard: explicit contiguous_format should produce contiguous dim_order.

Note: When clone(memory_format=contiguous_format) is called on a channels-last
input, this is a layout-transforming operation. After export, this typically
lowers to _to_dim_order_copy (not _clone_dim_order) because it changes the
memory layout. We check for both node types to be robust.
"""

class CloneContiguousModel(torch.nn.Module):
def forward(self, x):
return x.clone(memory_format=torch.contiguous_format)

model = CloneContiguousModel()
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

ep = torch.export.export(model, (x,))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Find the dim_order copy node and check its dim_order.
# This may be _to_dim_order_copy (layout transform) or _clone_dim_order.
found_copy = False
for node in edge.exported_program().graph_module.graph.nodes:
if node.op == "call_function" and (
"_clone_dim_order" in str(node.target)
or "_to_dim_order_copy" in str(node.target)
):
found_copy = True
spec = node.meta.get("val")
self.assertIsNotNone(spec, "Copy node should have meta['val']")
dim_order = tuple(spec.dim_order())
self.assertEqual(
dim_order,
(0, 1, 2, 3),
f"Explicit contiguous clone should have contiguous dim_order, got {dim_order}",
)
break

self.assertTrue(
found_copy, "Should find a _clone_dim_order or _to_dim_order_copy node"
)

def test_to_copy_no_kwarg_preserves_channels_last_dim_order(self) -> None:
"""
Verify that tensor.to(dtype=...) with no memory_format kwarg preserves
the input's dim_order (preserve_format semantics).

This tests the _to_copy.default path in MemoryFormatOpsPass.
"""

class ToCopyModel(torch.nn.Module):
def forward(self, x):
# .to(dtype=...) with no memory_format → preserve_format semantics
return x.to(dtype=torch.float32)

model = ToCopyModel()
x = torch.randn(1, 3, 8, 8, dtype=torch.float16).to(
memory_format=torch.channels_last
)

ep = torch.export.export(model, (x,))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Find the _to_dim_order_copy node and verify it preserves channels-last
found_copy = False
for node in edge.exported_program().graph_module.graph.nodes:
if node.op == "call_function" and "_to_dim_order_copy" in str(node.target):
found_copy = True
spec = node.meta.get("val")
self.assertIsNotNone(spec, "Copy node should have meta['val']")
dim_order = tuple(spec.dim_order())
self.assertEqual(
dim_order,
(0, 2, 3, 1),
f"to(dtype=...) should preserve channels-last dim_order, got {dim_order}",
)
break

self.assertTrue(found_copy, "Should find a _to_dim_order_copy node")
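As a quick cross-check of the dim_order values these tests assert: eager-mode PyTorch already gives clone() preserve_format semantics, which is exactly the behavior the pass now matches. A small sketch in plain PyTorch (no ExecuTorch needed; requires a PyTorch version with Tensor.dim_order()):

```python
import torch

x = torch.randn(1, 3, 8, 8)
print(x.dim_order())  # (0, 1, 2, 3) -- contiguous

cl = x.to(memory_format=torch.channels_last)
print(cl.dim_order())  # (0, 2, 3, 1) -- channels-last

# Eager clone() defaults to preserve_format, so the layout carries over;
# the pass change makes the edge dialect agree with this.
print(cl.clone().dim_order())  # (0, 2, 3, 1)
print(cl.clone(memory_format=torch.contiguous_format).dim_order())  # (0, 1, 2, 3)
```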