Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions backends/arm/_passes/fuse_constant_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type
from collections.abc import Mapping
from typing import Sequence, Set, Type

import torch._export.utils
import torch.fx
Expand All @@ -18,6 +19,7 @@
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
FuseEqualPlaceholdersPass,
)
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import (
create_constant_placeholder,
Expand Down Expand Up @@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.exported_program = exported_program

@staticmethod
def _is_tosa_dialect_op(target) -> bool:
target_str = str(target)
return (
"executorch.exir.dialects.backend._ops.tosa." in target_str
or "<EdgeOpOverload: tosa." in target_str
)

@staticmethod
def _arg_contains_symbolic_shape(arg) -> bool:
if isinstance(arg, torch.fx.Node):
if meta_has_shape_mark(arg.meta):
return True
return FuseConstantArgsPass._arg_contains_symbolic_shape(
arg.meta.get("val")
)
if isinstance(arg, torch.SymInt):
return True
if isinstance(arg, Mapping):
return any(
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
for k, v in arg.items()
)
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
return any(
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
)
return False

def _propagate_special_dtype(self, from_nodes, to_node, data):
"""Propagate special dtype meta if it exists."""
special_dtypes = set()
Expand Down Expand Up @@ -142,13 +174,13 @@ def call(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target in [
exir_ops.backend.tosa.MATMUL.default,
exir_ops.backend.tosa.RESCALE.default,
exir_ops.backend.tosa.RESIZE.default,
exir_ops.backend.tosa.TABLE.default,
exir_ops.backend.tosa.TRANSPOSE.default,
]:
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
if (
self._is_tosa_dialect_op(node.target)
or self._arg_contains_symbolic_shape(node.args)
or self._arg_contains_symbolic_shape(node.kwargs)
):
continue

input_nodes = node.all_input_nodes
Expand All @@ -164,7 +196,6 @@ def call(self, graph_module):
)
if not all(input_nodes_constant):
continue

try:
did_fuse = self._fuse_nodes(node)
if did_fuse:
Expand Down
77 changes: 76 additions & 1 deletion backends/arm/test/passes/test_fuse_constant_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import operator
from typing import cast, ClassVar, Dict, Protocol, Tuple

import executorch.backends.arm.tosa.dialect # noqa: F401
import torch
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
Expand All @@ -15,8 +16,15 @@
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.test.harness.stages import StageType
from executorch.backends.test.program_builder import ProgramBuilder
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.graph_signature import InputKind

input_t = Tuple[torch.Tensor] # Input x
input_t2 = Tuple[torch.Tensor, torch.Tensor]
Expand Down Expand Up @@ -270,3 +278,70 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None:
for node in pass_result.graph_module.graph.nodes
if node.op == "placeholder"
] == ["aten_cat_default_fused_const"]


def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None:
    """``_is_tosa_dialect_op`` accepts TOSA dialect targets and rejects ATen ops."""

    class _StringlyTosaTarget:
        # Mimics a target whose only TOSA fingerprint is its string form.
        def __str__(self) -> str:
            return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default"

    dialect_checks = [
        (_StringlyTosaTarget(), True),
        (exir_ops.backend.tosa.GATHER.default, True),
        (torch.ops.aten.add.Tensor, False),
    ]
    for target, expected in dialect_checks:
        assert FuseConstantArgsPass._is_tosa_dialect_op(target) == expected


def test_fuse_constant_args_identifies_symbolic_shape_args() -> None:
    """Args holding a SHAPE-marked fx node are flagged; concrete args are not."""
    fx_graph = torch.fx.Graph()
    marked_node = fx_graph.placeholder("shape")
    marked_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE

    symbolic_args = (marked_node, [1, 2])
    concrete_args = ([1, 2], {"pad": (0, 0)})

    assert FuseConstantArgsPass._arg_contains_symbolic_shape(symbolic_args)
    assert not FuseConstantArgsPass._arg_contains_symbolic_shape(concrete_args)


def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None:
    """FuseConstantArgsPass must leave tosa.GATHER untouched without warnings.

    Builds a program whose GATHER op has only constant inputs, runs the
    pass, then checks that no fusion-failure warning mentioning GATHER
    was logged and that the GATHER call survives exactly once.
    """
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        builder = ProgramBuilder()
        values_input = builder.placeholder(
            "values",
            torch.randn(1, 4, 3),
            input_kind=InputKind.CONSTANT_TENSOR,
        )
        indices_input = builder.placeholder(
            "indices",
            torch.tensor([[0, 2]], dtype=torch.int32),
            input_kind=InputKind.CONSTANT_TENSOR,
        )
        gather_node = builder.call_operator(
            exir_ops.backend.tosa.GATHER.default,
            (values_input, indices_input),
        )
        builder.output([gather_node])

        program = builder.get_program()
        module = program.graph_module

        with caplog.at_level("WARNING"):
            FuseConstantArgsPass(program)(module)

        pass_logger_name = "executorch.backends.arm._passes.fuse_constant_ops_pass"
        pass_messages = [
            record.getMessage()
            for record in caplog.records
            if record.name == pass_logger_name
        ]
        # No fusion-failure warning may mention GATHER.
        assert all(
            "Failed to fuse constant op" not in message or "GATHER" not in message
            for message in pass_messages
        )
        # The GATHER call_function node must survive the pass exactly once.
        surviving_gathers = [
            node
            for node in module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.backend.tosa.GATHER.default
        ]
        assert len(surviving_gathers) == 1
Loading