diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 574756dd002..51dd5fa91c1 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -159,6 +159,7 @@ from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager +from torch._export.utils import _get_shape_env_from_gm from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult from torch.nn.modules import Module @@ -335,10 +336,12 @@ def add_passes(self, passes: Sequence[ExportPass | None]): if p is not None: self.add_pass(p) + def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext: + shape_env = _get_shape_env_from_gm(graph_module) + return TosaLoweringContext(self.tosa_spec, shape_env) + def _transform(self, graph_module: GraphModule): - shape_env = graph_module.shape_env - with TosaLoweringContext(self.tosa_spec, shape_env): - return self(graph_module).graph_module + return self(graph_module).graph_module def add_pass(self, pipeline_pass): if type(pipeline_pass) in self._skip_pass_types: @@ -532,87 +535,89 @@ def transform_to_backend_pipeline( f"No pass pipeline found for TOSA specification: {self.tosa_spec}" ) - return self._tosa_pipeline(exported_program, graph_module) + with self._tosa_context(graph_module): + return self._tosa_pipeline(exported_program, graph_module) def transform_for_annotation_pipeline(self, graph_module: GraphModule): - # Preprocessing passes - self.add_pass(RemoveGraphAssertsPass(tfa_pass=True)) - self.add_pass(ConstantFoldingPass()) - - # Transformation passes (pre scalar -> tensor) - self.add_passes( - [ - DecomposeIndexCopyPass(tfa_pass=True), - DecomposeSelectScatterPass(tfa_pass=True), - DecomposeSliceScatterPass(tfa_pass=True), - ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), - ConvertInt64OutputOpsToInt32Pass(tfa_pass=True), - InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True), - DecomposeEmbeddingPass(tfa_pass=True), - DecomposeScaledDotProductAttentionPass(tfa_pass=True), - DecomposeRoundPass(tfa_pass=True), - DecomposeLogitPass(tfa_pass=True), - PromoteBoolOperandsPass(tfa_pass=True), - DecomposeSignPass(tfa_pass=True), - DecomposeTrilPass(tfa_pass=True), - DecomposeAddmmPass(tfa_pass=True), - DecomposeRemainderPass(tfa_pass=True), - DecomposeFloorDividePass(tfa_pass=True), - DecomposeDivTensorModePass(tfa_pass=True), - DecomposeWhereScalarOtherPass(tfa_pass=True), - DecomposeEinsumPass(tfa_pass=True), - RewriteInplaceArithmeticPass(tfa_pass=True), - DecomposeAddSubAlphaPass(tfa_pass=True), - DecomposeLeakyReLUPass(tfa_pass=True), - DecomposeGroupNormPass(tfa_pass=True), - DecomposeLayerNormPass(tfa_pass=True), - DecomposeVarPass(tfa_pass=True), - DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), - ] - ) + with self._tosa_context(graph_module): + # Preprocessing passes + self.add_pass(RemoveGraphAssertsPass(tfa_pass=True)) + self.add_pass(ConstantFoldingPass()) + + # Transformation passes (pre scalar -> tensor) + self.add_passes( + [ + DecomposeIndexCopyPass(tfa_pass=True), + DecomposeSelectScatterPass(tfa_pass=True), + DecomposeSliceScatterPass(tfa_pass=True), + ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), + ConvertInt64OutputOpsToInt32Pass(tfa_pass=True), + InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True), + DecomposeEmbeddingPass(tfa_pass=True), + DecomposeScaledDotProductAttentionPass(tfa_pass=True), + DecomposeRoundPass(tfa_pass=True), + DecomposeLogitPass(tfa_pass=True), + PromoteBoolOperandsPass(tfa_pass=True), + DecomposeSignPass(tfa_pass=True), + DecomposeTrilPass(tfa_pass=True), + DecomposeAddmmPass(tfa_pass=True), + DecomposeRemainderPass(tfa_pass=True), + DecomposeFloorDividePass(tfa_pass=True), + DecomposeDivTensorModePass(tfa_pass=True), + DecomposeWhereScalarOtherPass(tfa_pass=True), + DecomposeEinsumPass(tfa_pass=True), + RewriteInplaceArithmeticPass(tfa_pass=True), + DecomposeAddSubAlphaPass(tfa_pass=True), + DecomposeLeakyReLUPass(tfa_pass=True), + DecomposeGroupNormPass(tfa_pass=True), + DecomposeLayerNormPass(tfa_pass=True), + DecomposeVarPass(tfa_pass=True), + DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), + ] + ) - # Scalars -> tensors - self.add_passes( - [ - ReplaceScalarWithTensorByProfilePass(tfa_pass=True), - ScalarsToAttributePass(tfa_pass=True), - ControlFlowConstInlinePass(tfa_pass=True), - ] - ) + # Scalars -> tensors + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(tfa_pass=True), + ScalarsToAttributePass(tfa_pass=True), + ControlFlowConstInlinePass(tfa_pass=True), + ] + ) - # Transformation passes (post scalar removal) - self.add_passes( - [ - NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), - DecomposeGruPass(tfa_pass=True), - DecomposeLstmPass(tfa_pass=True), - DecomposeRnnPass(tfa_pass=True), - DecomposeNotEqualPass(tfa_pass=True), - DecomposeCosineSimilarityPass(tfa_pass=True), - DecomposeGluPass(tfa_pass=True), - DecomposeDivPass(tfa_pass=True), - DecomposeLinalgVectorNormPass(tfa_pass=True), - DecomposeSqrtPass(tfa_pass=True), - DecomposeAdaptiveAvgPool2dPass(tfa_pass=True), - DecomposeAvgPool2dPass(tfa_pass=True), - DecomposeSoftmaxPass( - tfa_pass=True, - ), - ConvertMinMaxPass(tfa_pass=True), - AccumulateIndexPutPass(tfa_pass=True), - DecomposeMatmulPass(tfa_pass=True), - ] - ) + # Transformation passes (post scalar removal) + self.add_passes( + [ + NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), + DecomposeGruPass(tfa_pass=True), + DecomposeLstmPass(tfa_pass=True), + DecomposeRnnPass(tfa_pass=True), + DecomposeNotEqualPass(tfa_pass=True), + DecomposeCosineSimilarityPass(tfa_pass=True), + DecomposeGluPass(tfa_pass=True), + DecomposeDivPass(tfa_pass=True), + DecomposeLinalgVectorNormPass(tfa_pass=True), + DecomposeSqrtPass(tfa_pass=True), + DecomposeAdaptiveAvgPool2dPass(tfa_pass=True), + DecomposeAvgPool2dPass(tfa_pass=True), + DecomposeSoftmaxPass( + tfa_pass=True, + ), + ConvertMinMaxPass(tfa_pass=True), + AccumulateIndexPutPass(tfa_pass=True), + DecomposeMatmulPass(tfa_pass=True), + ] + ) - # Postprocessing passes - self.add_passes( - [ - ReplaceInfAndLimitValuesPass(tfa_pass=True), - DecomposeMaskedFillPass(tfa_pass=True), - ] - ) + # Postprocessing passes + self.add_passes( + [ + ReplaceInfAndLimitValuesPass(tfa_pass=True), + DecomposeMaskedFillPass(tfa_pass=True), + ] + ) - return self._transform(graph_module) + return self._transform(graph_module) def __call__(self, module: Module) -> PassResult: try: diff --git a/backends/test/suite/conftest.py b/backends/test/suite/conftest.py index 188f54e4e42..c1604c11016 100644 --- a/backends/test/suite/conftest.py +++ b/backends/test/suite/conftest.py @@ -1,3 +1,8 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import os from typing import Any @@ -9,19 +14,29 @@ from executorch.backends.test.suite.runner import run_test +FLOW_TEST_CASE_TIMEOUTS = { + "backends/test/suite/models/": 1200, + "backends/test/suite/operators/": 120, +} + + def pytest_collection_modifyitems(config, items): for item in items: callspec = getattr(item, "callspec", None) - if callspec is None: - continue - flow = callspec.params.get("test_runner") - if not isinstance(flow, TestFlow): - continue - test_name = item.originalname or item.name - if flow.should_skip_test(test_name): - item.add_marker( - pytest.mark.skip(reason=f"Skipped by {flow.name} skip_patterns") - ) + if callspec is not None: + flow = callspec.params.get("test_runner") + if isinstance(flow, TestFlow): + test_name = item.originalname or item.name + if flow.should_skip_test(test_name): + item.add_marker( + pytest.mark.skip(reason=f"Skipped by {flow.name} skip_patterns") + ) + + item_path = str(getattr(item, "path", "")) + for suite_prefix, timeout_s in FLOW_TEST_CASE_TIMEOUTS.items(): + if suite_prefix in item_path: + item.add_marker(pytest.mark.timeout(timeout_s)) + break def pytest_configure(config):