Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 83 additions & 78 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
from torch._export.utils import _get_shape_env_from_gm
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.nn.modules import Module
Expand Down Expand Up @@ -335,10 +336,12 @@ def add_passes(self, passes: Sequence[ExportPass | None]):
if p is not None:
self.add_pass(p)

def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
shape_env = _get_shape_env_from_gm(graph_module)
return TosaLoweringContext(self.tosa_spec, shape_env)

def _transform(self, graph_module: GraphModule):
shape_env = graph_module.shape_env
with TosaLoweringContext(self.tosa_spec, shape_env):
return self(graph_module).graph_module
return self(graph_module).graph_module

def add_pass(self, pipeline_pass):
if type(pipeline_pass) in self._skip_pass_types:
Expand Down Expand Up @@ -532,87 +535,89 @@ def transform_to_backend_pipeline(
f"No pass pipeline found for TOSA specification: {self.tosa_spec}"
)

return self._tosa_pipeline(exported_program, graph_module)
with self._tosa_context(graph_module):
return self._tosa_pipeline(exported_program, graph_module)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
# Preprocessing passes
self.add_pass(RemoveGraphAssertsPass(tfa_pass=True))
self.add_pass(ConstantFoldingPass())

# Transformation passes (pre scalar -> tensor)
self.add_passes(
[
DecomposeIndexCopyPass(tfa_pass=True),
DecomposeSelectScatterPass(tfa_pass=True),
DecomposeSliceScatterPass(tfa_pass=True),
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
DecomposeEmbeddingPass(tfa_pass=True),
DecomposeScaledDotProductAttentionPass(tfa_pass=True),
DecomposeRoundPass(tfa_pass=True),
DecomposeLogitPass(tfa_pass=True),
PromoteBoolOperandsPass(tfa_pass=True),
DecomposeSignPass(tfa_pass=True),
DecomposeTrilPass(tfa_pass=True),
DecomposeAddmmPass(tfa_pass=True),
DecomposeRemainderPass(tfa_pass=True),
DecomposeFloorDividePass(tfa_pass=True),
DecomposeDivTensorModePass(tfa_pass=True),
DecomposeWhereScalarOtherPass(tfa_pass=True),
DecomposeEinsumPass(tfa_pass=True),
RewriteInplaceArithmeticPass(tfa_pass=True),
DecomposeAddSubAlphaPass(tfa_pass=True),
DecomposeLeakyReLUPass(tfa_pass=True),
DecomposeGroupNormPass(tfa_pass=True),
DecomposeLayerNormPass(tfa_pass=True),
DecomposeVarPass(tfa_pass=True),
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
]
)
with self._tosa_context(graph_module):
# Preprocessing passes
self.add_pass(RemoveGraphAssertsPass(tfa_pass=True))
self.add_pass(ConstantFoldingPass())

# Transformation passes (pre scalar -> tensor)
self.add_passes(
[
DecomposeIndexCopyPass(tfa_pass=True),
DecomposeSelectScatterPass(tfa_pass=True),
DecomposeSliceScatterPass(tfa_pass=True),
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
DecomposeEmbeddingPass(tfa_pass=True),
DecomposeScaledDotProductAttentionPass(tfa_pass=True),
DecomposeRoundPass(tfa_pass=True),
DecomposeLogitPass(tfa_pass=True),
PromoteBoolOperandsPass(tfa_pass=True),
DecomposeSignPass(tfa_pass=True),
DecomposeTrilPass(tfa_pass=True),
DecomposeAddmmPass(tfa_pass=True),
DecomposeRemainderPass(tfa_pass=True),
DecomposeFloorDividePass(tfa_pass=True),
DecomposeDivTensorModePass(tfa_pass=True),
DecomposeWhereScalarOtherPass(tfa_pass=True),
DecomposeEinsumPass(tfa_pass=True),
RewriteInplaceArithmeticPass(tfa_pass=True),
DecomposeAddSubAlphaPass(tfa_pass=True),
DecomposeLeakyReLUPass(tfa_pass=True),
DecomposeGroupNormPass(tfa_pass=True),
DecomposeLayerNormPass(tfa_pass=True),
DecomposeVarPass(tfa_pass=True),
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
]
)

# Scalars -> tensors
self.add_passes(
[
ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
ScalarsToAttributePass(tfa_pass=True),
ControlFlowConstInlinePass(tfa_pass=True),
]
)
# Scalars -> tensors
self.add_passes(
[
ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
ScalarsToAttributePass(tfa_pass=True),
ControlFlowConstInlinePass(tfa_pass=True),
]
)

# Transformation passes (post scalar removal)
self.add_passes(
[
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
DecomposeGruPass(tfa_pass=True),
DecomposeLstmPass(tfa_pass=True),
DecomposeRnnPass(tfa_pass=True),
DecomposeNotEqualPass(tfa_pass=True),
DecomposeCosineSimilarityPass(tfa_pass=True),
DecomposeGluPass(tfa_pass=True),
DecomposeDivPass(tfa_pass=True),
DecomposeLinalgVectorNormPass(tfa_pass=True),
DecomposeSqrtPass(tfa_pass=True),
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
DecomposeAvgPool2dPass(tfa_pass=True),
DecomposeSoftmaxPass(
tfa_pass=True,
),
ConvertMinMaxPass(tfa_pass=True),
AccumulateIndexPutPass(tfa_pass=True),
DecomposeMatmulPass(tfa_pass=True),
]
)
# Transformation passes (post scalar removal)
self.add_passes(
[
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
DecomposeGruPass(tfa_pass=True),
DecomposeLstmPass(tfa_pass=True),
DecomposeRnnPass(tfa_pass=True),
DecomposeNotEqualPass(tfa_pass=True),
DecomposeCosineSimilarityPass(tfa_pass=True),
DecomposeGluPass(tfa_pass=True),
DecomposeDivPass(tfa_pass=True),
DecomposeLinalgVectorNormPass(tfa_pass=True),
DecomposeSqrtPass(tfa_pass=True),
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
DecomposeAvgPool2dPass(tfa_pass=True),
DecomposeSoftmaxPass(
tfa_pass=True,
),
ConvertMinMaxPass(tfa_pass=True),
AccumulateIndexPutPass(tfa_pass=True),
DecomposeMatmulPass(tfa_pass=True),
]
)

# Postprocessing passes
self.add_passes(
[
ReplaceInfAndLimitValuesPass(tfa_pass=True),
DecomposeMaskedFillPass(tfa_pass=True),
]
)
# Postprocessing passes
self.add_passes(
[
ReplaceInfAndLimitValuesPass(tfa_pass=True),
DecomposeMaskedFillPass(tfa_pass=True),
]
)

return self._transform(graph_module)
return self._transform(graph_module)

def __call__(self, module: Module) -> PassResult:
try:
Expand Down
35 changes: 25 additions & 10 deletions backends/test/suite/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Any

Expand All @@ -9,19 +14,29 @@
from executorch.backends.test.suite.runner import run_test


FLOW_TEST_CASE_TIMEOUTS = {
"backends/test/suite/models/": 1200,
"backends/test/suite/operators/": 120,
}


def pytest_collection_modifyitems(config, items):
for item in items:
callspec = getattr(item, "callspec", None)
if callspec is None:
continue
flow = callspec.params.get("test_runner")
if not isinstance(flow, TestFlow):
continue
test_name = item.originalname or item.name
if flow.should_skip_test(test_name):
item.add_marker(
pytest.mark.skip(reason=f"Skipped by {flow.name} skip_patterns")
)
if callspec is not None:
flow = callspec.params.get("test_runner")
if isinstance(flow, TestFlow):
test_name = item.originalname or item.name
if flow.should_skip_test(test_name):
item.add_marker(
pytest.mark.skip(reason=f"Skipped by {flow.name} skip_patterns")
)

item_path = str(getattr(item, "path", ""))
for suite_prefix, timeout_s in FLOW_TEST_CASE_TIMEOUTS.items():
if suite_prefix in item_path:
item.add_marker(pytest.mark.timeout(timeout_s))
break


def pytest_configure(config):
Expand Down
Loading