Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from .decompose_var_pass import DecomposeVarPass # noqa
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeClampArgumentsPass,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
DecomposeVarPass,
DecomposeWhereScalarOtherPass,
DecorateFp32toInt32CastingPass,
EnsureUniqueOutputNodesPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConsecutiveConcatShapesPass,
Expand Down Expand Up @@ -544,6 +545,7 @@ def _tosa_pipeline(
FuseEqualPlaceholdersPass(exported_program),
FuseConsecutiveConcatShapesPass(),
ToTosaMemoryFormatPass(exported_program),
EnsureUniqueOutputNodesPass(),
RemoveNoopPass(),
InsertRescalePass(),
InsertDataLayoutCastsPass(),
Expand Down
82 changes: 82 additions & 0 deletions backends/arm/_passes/ensure_unique_output_nodes_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import Counter
from typing import Any, Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class EnsureUniqueOutputNodesPass(ArmPass):
    """Ensure each graph output leaf references a unique producer node.

    If the same node appears multiple times in the output structure, insert a
    ``tosa.IDENTITY`` node for each occurrence and replace the repeated output
    entries with those identity nodes, so every output slot has a distinct
    producer.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def __init__(self, exported_program: Any = None) -> None:
        # Some harnesses construct every pass as ``PassClass(exported_program)``.
        # Without an explicit __init__, that positional argument would be bound
        # to ArmPass's ``tfa_pass`` flag and silently turn this into a
        # transform-for-annotation pass. Accept and ignore the argument, and
        # always forward ``tfa_pass=False``.
        del exported_program  # intentionally unused
        super().__init__(tfa_pass=False)

    @staticmethod
    def _collect_output_nodes(
        output_value: Any, counts: Counter[torch.fx.Node]
    ) -> None:
        """Recursively count occurrences of each fx.Node in the (possibly
        nested list/tuple) output structure."""
        if isinstance(output_value, torch.fx.Node):
            counts[output_value] += 1
            return
        if isinstance(output_value, (list, tuple)):
            for value in output_value:
                EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Deduplicate repeated output leaves by inserting identity nodes.

        Returns:
            PassResult whose ``modified`` flag is True only when at least one
            identity node was inserted.
        """
        graph = graph_module.graph
        output_node = graph.output_node()
        output_value = output_node.args[0]

        counts: Counter[torch.fx.Node] = Counter()
        self._collect_output_nodes(output_value, counts)
        repeated_nodes = {node for node, count in counts.items() if count > 1}
        if not repeated_nodes:
            # Nothing repeated: leave the graph untouched.
            return PassResult(graph_module, False)

        modified = False

        def _replace_repeated_outputs(value: Any) -> Any:
            nonlocal modified
            if isinstance(value, torch.fx.Node):
                if value not in repeated_nodes:
                    return value
                # Every occurrence of a repeated node (including the first)
                # gets its own IDENTITY so each output slot is unique.
                with graph.inserting_before(output_node):
                    identity_node = create_node(
                        graph,
                        exir_ops.backend.tosa.IDENTITY.default,
                        args=(value,),
                        from_node=value,
                    )
                modified = True
                return identity_node

            if isinstance(value, tuple):
                return tuple(_replace_repeated_outputs(v) for v in value)

            if isinstance(value, list):
                return [_replace_repeated_outputs(v) for v in value]

            # Non-node leaves (e.g. None) pass through unchanged.
            return value

        new_output_value = _replace_repeated_outputs(output_value)
        if modified:
            output_node.args = (new_output_value,)
            graph.eliminate_dead_code()
            graph.lint()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
10 changes: 5 additions & 5 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module):
"""Convert a producer's output to channels-last by appending a backend
`TRANSPOSE` node and rewiring its users.
"""

rank = len(get_first_fake_tensor(node).size())
spatial_rank = node.meta["tosa_spatial_rank"]
mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
Expand Down Expand Up @@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
if output_dim_orders is None:
raise RuntimeError(f"{output_dim_orders=} is not supported.")

transposed_output_inputs: set[torch.fx.Node] = set()
for output_node_input, output_dim_order in zip(
outputs, output_dim_orders, strict=True
):
if output_dim_order in (
NCHW_ORDER,
NNCHW_ORDER,
NNNCHW_ORDER,
if (
output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER)
and output_node_input not in transposed_output_inputs
):
self.insert_input_transpose(
output_node, output_node_input, graph_module
)
transposed_output_inputs.add(output_node_input)

def remove_dim_order_kwargs(
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
op_tosa_custom,
op_tosa_depthwise_conv2d,
op_tosa_gather,
op_tosa_identity,
op_tosa_matmul,
op_tosa_max_pool2d,
op_tosa_pad,
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/operators/op_tosa_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class IdentityVisitor(NodeVisitor):
    """Serialize ``tosa.IDENTITY`` nodes into the TOSA graph."""

    target = "tosa.IDENTITY.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Validate the single input/output pair and emit a TOSA IDENTITY op."""
        validate_num_inputs(self.target, inputs, 1)

        tensors = [inputs[0], output]
        validate_same_dtype(self.target, tensors, ts)
        # NOTE(review): this fixed dtype list is passed for every TOSA spec;
        # confirm that INT16/BF16 availability is gated by the active
        # profile/extensions elsewhere.
        supported_dtypes = [
            ts.DType.BOOL,
            ts.DType.INT8,
            ts.DType.INT16,
            ts.DType.INT32,
            ts.DType.FP16,
            ts.DType.FP32,
            ts.DType.BF16,
        ]
        validate_valid_dtype(
            self.target,
            tensors,
            supported_dtypes,
            self.tosa_spec,
        )

        attr = ts.TosaSerializerAttribute()
        attr.IdentityAttribute()
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.IDENTITY,
            [inputs[0].name],
            [output.name],
            attr,
        )
25 changes: 25 additions & 0 deletions backends/arm/test/misc/test_tosa_dialect_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import torch
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode


def test_identity_tosa_FP() -> None:
    """IDENTITY must preserve both the dtype and the shape of its input."""
    x = torch.randn((1, 2, 3, 4), dtype=torch.float32)

    spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
    with TosaLoweringContext(spec), FakeTensorMode() as mode:
        result = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(x))

    assert result.dtype == x.dtype
    assert tuple(result.shape) == tuple(x.shape)
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.test.harness.stages import StageType
from executorch.exir.dialects._ops import ops as exir_ops


class DuplicateOutputModule(torch.nn.Module):
    """Module that returns the very same tensor object in both output slots."""

    def forward(self, x: torch.Tensor):
        result = x + 1.0
        return result, result


class UniqueOutputModule(torch.nn.Module):
    """Module that returns two distinct tensors derived from the input."""

    def forward(self, x: torch.Tensor):
        first = x + 1.0
        second = x + 2.0
        return first, second


def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> (
    None
):
    """A node repeated in the outputs gets one IDENTITY per occurrence."""
    pipeline = PassPipeline[tuple[torch.Tensor]](
        DuplicateOutputModule(),
        (torch.rand(2, 2),),
        quantize=False,
        pass_list=[EnsureUniqueOutputNodesPass],
        ops_after_pass={
            "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
        },
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    artifact = pipeline.tester.get_artifact(StageType.RUN_PASSES)
    gm = artifact.exported_program().graph_module
    first, second = list(gm.graph.output_node().args[0])

    # Both slots must now be distinct identity nodes fed by the same producer.
    assert first is not second
    for leaf in (first, second):
        assert leaf.target == exir_ops.backend.tosa.IDENTITY.default
    assert first.args[0] is second.args[0]


def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None:
    """Already-unique outputs must not trigger any IDENTITY insertion."""
    module = UniqueOutputModule()
    sample_inputs = (torch.rand(2, 2),)
    pipeline = PassPipeline[tuple[torch.Tensor]](
        module,
        sample_inputs,
        quantize=False,
        pass_list=[EnsureUniqueOutputNodesPass],
        ops_not_after_pass=[
            "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
        ],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()
59 changes: 59 additions & 0 deletions backends/arm/test/passes/test_to_tosa_memory_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
import torch
from executorch.backends.arm._passes import (
AnnotateOutputDimOrderPass,
EnsureUniqueOutputNodesPass,
FuseEqualPlaceholdersPass,
ToTosaMemoryFormatPass,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
PassPipeline,
TosaPipelineFP,
TosaPipelineINT,
)
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.exir.dialects._ops import ops as exir_ops

input_t = Tuple[torch.Tensor] # Input x

Expand Down Expand Up @@ -177,6 +181,26 @@ def get_inputs(self) -> input_t:
return (torch.rand(4, 4, 4, 4),)


class DuplicateConstantOutputs(torch.nn.Module):
    """Returns two equal-valued constant buffers plus the input, unchanged."""

    def __init__(self) -> None:
        super().__init__()
        # Two separate (but equal-valued) buffers exercise placeholder fusion.
        for name in ("grid0", "grid1"):
            self.register_buffer(name, torch.zeros(1, 32, 32, 2))

    def forward(self, x: torch.Tensor):
        return self.grid0, self.grid1, x


class DuplicateConstantOutputsWithAdd(torch.nn.Module):
    """Returns two equal-valued constant buffers plus the doubled input."""

    def __init__(self) -> None:
        super().__init__()
        # Two separate (but equal-valued) buffers exercise placeholder fusion.
        for name in ("grid0", "grid1"):
            self.register_buffer(name, torch.zeros(1, 32, 32, 2))

    def forward(self, x: torch.Tensor):
        return self.grid0, self.grid1, x + x


modules: Dict[str, ModuleMetadata] = {
"no_nhwc": NoNHWC(),
"parallel_clusters": ParallelClusters(),
Expand Down Expand Up @@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
module_nn = cast(torch.nn.Module, module)
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
pipeline.run()


def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None:
    """After placeholder fusion collapses the two buffers into one node, the
    duplicate output slots must be re-separated with IDENTITY nodes."""
    pipeline = PassPipeline[input_t](
        DuplicateConstantOutputs(),
        (torch.rand(1, 2, 32, 32),),
        quantize=False,
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[
            FuseEqualPlaceholdersPass,
            ToTosaMemoryFormatPass,
            EnsureUniqueOutputNodesPass,
        ],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    gm = pipeline.tester.get_artifact().exported_program().graph_module
    first, second = list(gm.graph.output_node().args[0])[:2]

    # The two slots are distinct identity nodes sharing one producer.
    assert first is not second
    for leaf in (first, second):
        assert leaf.target == exir_ops.backend.tosa.IDENTITY.default
    assert first.args[0] is second.args[0]


def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None:
    """End-to-end FP pipeline over duplicated constant outputs plus an add."""
    module = DuplicateConstantOutputsWithAdd()
    sample_inputs = (torch.rand(1, 2, 32, 32),)
    pipeline = TosaPipelineFP[input_t](module, sample_inputs, [], [])
    pipeline.run()
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
custom,
depthwise_conv2d,
gather,
identity,
matmul,
max_pool2d,
pad,
Expand Down
Loading
Loading