From 990608e960fa658d5cd81a9444ab66e066582781 Mon Sep 17 00:00:00 2001
From: NefAI
Date: Sat, 21 Feb 2026 10:16:51 +0100
Subject: [PATCH 1/4] fix(MemoryFormatOpsPass): preserve input dim_order for
 clone/to_copy with no memory_format kwarg

Issue #16032: clone() and _to_copy operations with no explicit
memory_format kwarg were defaulting to contiguous dim_order, causing
runtime assertion failures when cloning channels-last tensors.

Changes:
- Default memory_format to torch.preserve_format instead of
  torch.contiguous_format
- When preserve_format, derive dim_order from the input tensor's
  dim_order()
- Annotate dim_order as List[int]: it is assigned on every branch, so
  no Optional is needed

Tests:
- test_clone_no_kwarg_preserves_channels_last_dim_order: core repro case
- test_clone_contiguous_format_kwarg_stays_contiguous: regression guard
- test_to_copy_no_kwarg_preserves_channels_last_dim_order: _to_copy path
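A minimal eager-mode sketch of the pattern this fixes (illustrative
only; the variable names are not taken from the test suite):

    import torch

    x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
    y = x.clone()  # no memory_format kwarg -> preserve_format semantics
    # Eager mode keeps the channels-last layout on the clone:
    assert y.is_contiguous(memory_format=torch.channels_last)

Before this change, the pass lowered such a clone with a contiguous
dim_order (0, 1, 2, 3) instead, tripping the runtime layout assertion.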
f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}" diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 452f9694a8d..c1d442d0429 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -2499,3 +2499,133 @@ def test_convert_constant_dim_order_to_contiguous(self): modified_const.is_contiguous(), f"Constant should be contiguous after pass, got strides {modified_const.stride()}", ) + + +class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase): + """ + Tests for MemoryFormatOpsPass preserve_format semantics. + + Issue #16032: clone() with no memory_format kwarg should preserve the input's + dim_order, not default to contiguous. This caused runtime assertion failures + when cloning channels-last tensors. + """ + + def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None: + """ + Verify that clone() on a channels-last input with no memory_format kwarg + produces a _clone_dim_order node with channels-last dim_order (0,2,3,1). + + This is the core reproduction case for issue #16032. + """ + + class ConvClone(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1) + + def forward(self, x): + return self.conv(x).clone() + + model = ConvClone().to(memory_format=torch.channels_last) + x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last) + + ep = torch.export.export(model, (x,)) + edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Find the _clone_dim_order node and check its dim_order + found_clone = False + for node in edge.exported_program().graph_module.graph.nodes: + if node.op == "call_function" and "_clone_dim_order" in str(node.target): + found_clone = True + spec = node.meta.get("val") + self.assertIsNotNone(spec, "Clone node should have meta['val']") + dim_order = tuple(spec.dim_order()) + self.assertEqual( + dim_order, + (0, 2, 3, 1), + f"Clone should preserve channels-last dim_order, got {dim_order}", + ) + break + + self.assertTrue(found_clone, "Should find a _clone_dim_order node in the graph") + + def test_clone_contiguous_format_kwarg_stays_contiguous(self) -> None: + """ + Regression guard: explicit contiguous_format should produce contiguous dim_order. + + Note: When clone(memory_format=contiguous_format) is called on a channels-last + input, this is a layout-transforming operation. After export, this typically + lowers to _to_dim_order_copy (not _clone_dim_order) because it changes the + memory layout. We check for both node types to be robust. + """ + + class CloneContiguousModel(torch.nn.Module): + def forward(self, x): + return x.clone(memory_format=torch.contiguous_format) + + model = CloneContiguousModel() + x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last) + + ep = torch.export.export(model, (x,)) + edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False)) + + # Find the dim_order copy node and check its dim_order. + # This may be _to_dim_order_copy (layout transform) or _clone_dim_order. 
From 45ab83a6b3929353a660fc4b709f784907b251d7 Mon Sep 17 00:00:00 2001
From: NefAI
Date: Fri, 27 Feb 2026 21:15:15 +0100
Subject: [PATCH 2/4] fix: update MemoryFormatOpsPass comments and tests

Clarify the preserve_format behavior comments, and extend the
MemoryFormatOpsPass tests to also run the models eagerly and assert
that the output tensors have the expected memory format / dim order.
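The new eager checks follow this pattern (a standalone sketch; the
shapes match the tests, but the variable names are illustrative):

    import torch

    x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
    y = x.clone(memory_format=torch.contiguous_format)
    # An explicit contiguous_format kwarg must rewrite the layout:
    assert y.is_contiguous()
    assert not y.is_contiguous(memory_format=torch.channels_last)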
---
 exir/passes/memory_format_ops_pass.py |  6 ++---
 exir/tests/test_passes.py             | 37 ++++++++++++++++++++++-----
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py
index 9eff4b3e1df..f04a9a69261 100644
--- a/exir/passes/memory_format_ops_pass.py
+++ b/exir/passes/memory_format_ops_pass.py
@@ -41,9 +41,9 @@ def call_operator(self, op, args, kwargs, meta):
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
-        # get the target memory format for the EdgeOp
-        # Default to preserve_format: clone() with no memory_format kwarg should
-        # preserve the input's layout, not force contiguous. Issue #16032.
+        # Get the target memory format for the EdgeOp, defaulting to
+        # preserve_format (clone() with no memory_format kwarg preserves
+        # the input's layout instead of forcing contiguous).
         mem_format = nkwargs.pop("memory_format", torch.preserve_format)
 
         # Get input tensor and ndim
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index d47fb3137fd..6e85c4e5173 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -2504,18 +2504,12 @@ def test_convert_constant_dim_order_to_contiguous(self):
 class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase):
     """
     Tests for MemoryFormatOpsPass preserve_format semantics.
-
-    Issue #16032: clone() with no memory_format kwarg should preserve the input's
-    dim_order, not default to contiguous. This caused runtime assertion failures
-    when cloning channels-last tensors.
     """
 
     def test_clone_no_kwarg_preserves_channels_last_dim_order(self) -> None:
         """
         Verify that clone() on a channels-last input with no memory_format kwarg
         produces a _clone_dim_order node with channels-last dim_order (0,2,3,1).
-
-        This is the core reproduction case for issue #16032.
         """
 
         class ConvClone(torch.nn.Module):
             def __init__(self):
@@ -2529,6 +2523,15 @@ def forward(self, x):
         model = ConvClone().to(memory_format=torch.channels_last)
         x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
 
+        # Run the model and verify that the output tensor preserves channels-last
+        # layout when no memory_format kwarg is provided.
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(memory_format=torch.channels_last),
+            f"clone() without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
+        )
+
         ep = torch.export.export(model, (x,))
         edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
 
@@ -2566,6 +2569,19 @@ def forward(self, x):
         model = CloneContiguousModel()
         x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
 
+        # Run the model and verify that the explicit contiguous_format kwarg
+        # produces a contiguous output layout (not channels-last).
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(),
+            f"clone(memory_format=contiguous_format) should produce contiguous layout, got strides {y.stride()}",
+        )
+        self.assertFalse(
+            y.is_contiguous(memory_format=torch.channels_last),
+            "clone(memory_format=contiguous_format) should not preserve channels-last layout",
+        )
+
         ep = torch.export.export(model, (x,))
         edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
 
@@ -2610,6 +2626,15 @@ def forward(self, x):
             memory_format=torch.channels_last
         )
 
+        # Run the model and verify that tensor.to(dtype=...) with no memory_format
+        # kwarg preserves channels-last layout on the output tensor.
+        with torch.no_grad():
+            y = model(x)
+        self.assertTrue(
+            y.is_contiguous(memory_format=torch.channels_last),
+            f"to(dtype=...) without memory_format kwarg should preserve channels-last layout, got strides {y.stride()}",
+        )
+
         ep = torch.export.export(model, (x,))
         edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

From c1814d64ec5cb67bce6dc47ab112bc844270e702 Mon Sep 17 00:00:00 2001
From: NefAI
Date: Fri, 27 Feb 2026 21:34:57 +0100
Subject: [PATCH 3/4] fix: address lint and style for MemoryFormatOpsPass

Replace the generator-based dim_order construction with a list
comprehension to satisfy flake8 C400, and add the missing blank line
before the new test class to match PEP 8 spacing.
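For reference, a runnable sketch of the C400 change (the tensor here
is illustrative, not from the pass):

    import torch

    t = torch.empty(1, 3, 8, 8).to(memory_format=torch.channels_last)
    # C400 flags the generator-in-list() form:
    dim_order = list(int(d) for d in t.dim_order())
    # Preferred, equivalent comprehension:
    dim_order = [int(d) for d in t.dim_order()]
    assert dim_order == [0, 2, 3, 1]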
---
 exir/passes/memory_format_ops_pass.py | 2 +-
 exir/tests/test_passes.py             | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py
index f04a9a69261..4cb1087359c 100644
--- a/exir/passes/memory_format_ops_pass.py
+++ b/exir/passes/memory_format_ops_pass.py
@@ -66,7 +66,7 @@ def call_operator(self, op, args, kwargs, meta):
         if mem_format in (None, torch.preserve_format):
             # preserve_format: inherit dim_order from input tensor
             if input_tensor is not None:
-                dim_order = list(int(d) for d in input_tensor.dim_order())
+                dim_order = [int(d) for d in input_tensor.dim_order()]
             else:
                 # Fallback to contiguous if no input tensor available
                 dim_order = list(range(ndim))
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 6e85c4e5173..3bf526858dc 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -2501,6 +2501,7 @@ def test_convert_constant_dim_order_to_contiguous(self):
             f"Constant should be contiguous after pass, got strides {modified_const.stride()}",
         )
 
+
 class TestMemoryFormatOpsPassPreserveFormat(unittest.TestCase):
     """
     Tests for MemoryFormatOpsPass preserve_format semantics.

From 0ebd613f6bcd5da051426882adaf44bc6d612099 Mon Sep 17 00:00:00 2001
From: NefAI
Date: Fri, 27 Feb 2026 21:46:16 +0100
Subject: [PATCH 4/4] docs: clarify fallback comment and PEP 8 spacing before
 TestCSEPass

---
 exir/passes/memory_format_ops_pass.py | 3 ++-
 exir/tests/test_passes.py             | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py
index 4cb1087359c..13468dfd8d8 100644
--- a/exir/passes/memory_format_ops_pass.py
+++ b/exir/passes/memory_format_ops_pass.py
@@ -68,7 +68,8 @@ def call_operator(self, op, args, kwargs, meta):
             if input_tensor is not None:
                 dim_order = [int(d) for d in input_tensor.dim_order()]
             else:
-                # Fallback to contiguous if no input tensor available
+                # Fallback to contiguous if no single input tensor is available
+                # (e.g. list inputs like torch.stack).
                 dim_order = list(range(ndim))
         else:
             # Explicit memory format (contiguous_format, channels_last, etc.)
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 3bf526858dc..f683384f8f9 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -2656,6 +2656,7 @@ def forward(self, x):
 
         self.assertTrue(found_copy, "Should find a _to_dim_order_copy node")
 
+
 class TestCSEPass(unittest.TestCase):
     """Tests for Common Subexpression Elimination pass."""