Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions exir/passes/memory_format_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import logging
from typing import List, Optional

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
Expand Down Expand Up @@ -40,22 +41,41 @@ def call_operator(self, op, args, kwargs, meta):
# new kwargs with dim_order, and no memory_format for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

# get the target memory format for the EdgeOp
mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
# Get the target memory format for the EdgeOp, defaulting to
# preserve_format (clone() with no memory_format kwarg preserves
# the input's layout instead of forcing contiguous).
mem_format = nkwargs.pop("memory_format", torch.preserve_format)

# can always get the shape, assuming rank is specialized
# Get input tensor and ndim
input_tensor: Optional[torch.Tensor] = None
if isinstance(args[0], ProxyValue) and args[0].is_tensor():
ndim = args[0].to_tensor().dim()
input_tensor = args[0].to_tensor()
ndim = input_tensor.dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
input_tensor = args[0]
ndim = input_tensor.dim()
elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
ndim = len(args[0])
else:
assert (
0
), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
# Derive dim_order based on memory format
dim_order: List[int]
if mem_format in (None, torch.preserve_format):
# preserve_format: inherit dim_order from input tensor
if input_tensor is not None:
dim_order = [int(d) for d in input_tensor.dim_order()]
else:
# Fallback to contiguous if no single input tensor is available
# (e.g. list inputs like torch.stack).
dim_order = list(range(ndim))
else:
# Explicit memory format (contiguous_format, channels_last, etc.)
dim_order = get_dim_order(mem_format, ndim)

nkwargs["dim_order"] = dim_order
logger.debug(
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
Expand Down
155 changes: 155 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,6 +2502,161 @@ def test_convert_constant_dim_order_to_contiguous(self):
)


class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase):
    """
    Exercises MemoryFormatOpsPass preserve_format semantics: ops exported
    without an explicit memory_format kwarg should inherit the input's
    dim_order rather than being forced to contiguous.
    """

    @staticmethod
    def _first_matching_node(edge, *substrings):
        # Return the first call_function node whose target name contains any
        # of the given substrings, or None when no such node exists.
        for n in edge.exported_program().graph_module.graph.nodes:
            if n.op == "call_function" and any(
                s in str(n.target) for s in substrings
            ):
                return n
        return None

    @staticmethod
    def _export_to_edge(module, inputs):
        # Export eagerly-traced module and lower to edge with dim_order enabled.
        ep = torch.export.export(module, inputs)
        return to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

    def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None:
        """
        clone() on a channels-last input, with no memory_format kwarg, must
        lower to a _clone_dim_order node carrying dim_order (0, 2, 3, 1).
        """

        class ConvClone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x):
                return self.conv(x).clone()

        module = ConvClone().to(memory_format=torch.channels_last)
        x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

        # Eager sanity check: the output keeps the channels-last layout when
        # clone() receives no memory_format kwarg.
        with torch.no_grad():
            y = module(x)
        self.assertTrue(
            y.is_contiguous(memory_format=torch.channels_last),
            f"clone() without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
        )

        edge = self._export_to_edge(module, (x,))

        clone_node = self._first_matching_node(edge, "_clone_dim_order")
        self.assertIsNotNone(
            clone_node, "Should find a _clone_dim_order node in the graph"
        )
        spec = clone_node.meta.get("val")
        self.assertIsNotNone(spec, "Clone node should have meta['val']")
        dim_order = tuple(spec.dim_order())
        self.assertEqual(
            dim_order,
            (0, 2, 3, 1),
            f"Clone should preserve channels-last dim_order, got {dim_order}",
        )

    def test_clone_contiguous_format_kwarg_stays_contiguous(self) -> None:
        """
        Regression guard: an explicit contiguous_format kwarg must still yield
        a contiguous dim_order.

        clone(memory_format=contiguous_format) on a channels-last input is a
        layout transform, so after export it may surface as _to_dim_order_copy
        rather than _clone_dim_order; both node kinds are accepted here.
        """

        class CloneContiguousModel(torch.nn.Module):
            def forward(self, x):
                return x.clone(memory_format=torch.contiguous_format)

        module = CloneContiguousModel()
        x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

        # Eager sanity check: the explicit kwarg forces a contiguous output
        # and drops the channels-last layout.
        with torch.no_grad():
            y = module(x)
        self.assertTrue(
            y.is_contiguous(),
            f"clone(memory_format=contiguous_format) should produce contiguous layout, got strides {y.stride()}",
        )
        self.assertFalse(
            y.is_contiguous(memory_format=torch.channels_last),
            "clone(memory_format=contiguous_format) should not preserve channels-last layout",
        )

        edge = self._export_to_edge(module, (x,))

        copy_node = self._first_matching_node(
            edge, "_clone_dim_order", "_to_dim_order_copy"
        )
        self.assertIsNotNone(
            copy_node, "Should find a _clone_dim_order or _to_dim_order_copy node"
        )
        spec = copy_node.meta.get("val")
        self.assertIsNotNone(spec, "Copy node should have meta['val']")
        dim_order = tuple(spec.dim_order())
        self.assertEqual(
            dim_order,
            (0, 1, 2, 3),
            f"Explicit contiguous clone should have contiguous dim_order, got {dim_order}",
        )

    def test_to_copy_no_kwarg_preserves_channels_last_dim_order(self) -> None:
        """
        tensor.to(dtype=...) with no memory_format kwarg must keep the input's
        dim_order (preserve_format semantics) — covers the _to_copy.default
        path in MemoryFormatOpsPass.
        """

        class ToCopyModel(torch.nn.Module):
            def forward(self, x):
                # No memory_format kwarg -> preserve_format semantics.
                return x.to(dtype=torch.float32)

        module = ToCopyModel()
        x = torch.randn(1, 3, 8, 8, dtype=torch.float16).to(
            memory_format=torch.channels_last
        )

        # Eager sanity check: the dtype-only conversion keeps channels-last.
        with torch.no_grad():
            y = module(x)
        self.assertTrue(
            y.is_contiguous(memory_format=torch.channels_last),
            f"to(dtype=...) without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
        )

        edge = self._export_to_edge(module, (x,))

        copy_node = self._first_matching_node(edge, "_to_dim_order_copy")
        self.assertIsNotNone(copy_node, "Should find a _to_dim_order_copy node")
        spec = copy_node.meta.get("val")
        self.assertIsNotNone(spec, "Copy node should have meta['val']")
        dim_order = tuple(spec.dim_order())
        self.assertEqual(
            dim_order,
            (0, 2, 3, 1),
            f"to(dtype=...) should preserve channels-last dim_order, got {dim_order}",
        )


class TestCSEPass(unittest.TestCase):
"""Tests for Common Subexpression Elimination pass."""

Expand Down
Loading