Skip to content

Commit bcba790

Browse files
authored
Decouple transforms test from Arm/TOSA test utils (pytorch#20495)
test_to_contiguous_channels_last_pass.py imported backends.arm.test.common solely for the parametrize helper, which transitively pulled in the Arm backend and tosa-tools (tosa_serializer, tosa) just to collect a backend-agnostic transforms test. Add a neutral parametrize helper under backends/transforms/test/common.py and import it from there instead. This also fixes the test's xfail=xfails kwarg (the helper's parameter is xfails), which was crashing collection. Authored with Claude.
1 parent 46c784f commit bcba790

2 files changed

Lines changed: 71 additions & 2 deletions

File tree

backends/transforms/test/common.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Callable, ParamSpec, TypeVar
8+
9+
import pytest
10+
11+
xfail_type = str | tuple[str, type[Exception]]
12+
_P = ParamSpec("_P")
13+
_R = TypeVar("_R")
14+
Decorator = Callable[[Callable[_P, _R]], Callable[_P, _R]]
15+
16+
17+
def parametrize(
18+
arg_name: str,
19+
test_data: dict[str, Any],
20+
xfails: dict[str, xfail_type] | None = None,
21+
skips: dict[str, str] | None = None,
22+
strict: bool = True,
23+
flakies: dict[str, int] | None = None,
24+
) -> Decorator:
25+
"""Backend-neutral version of pytest.mark.parametrize with some syntactic
26+
sugar and added xfail functionality.
27+
28+
- test_data is expected as a dict of (id, test_data) pairs
29+
- allows specifying a dict of (id, failure_reason) pairs to mark specific
30+
tests as xfail. failure_reason can be str or tuple[str, type[Exception]].
31+
Strings set the reason for failure, the exception type sets the expected
32+
error.
33+
"""
34+
xfails = xfails or {}
35+
skips = skips or {}
36+
flakies = flakies or {}
37+
38+
def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]:
39+
pytest_testsuite = []
40+
for id, test_parameters in test_data.items():
41+
if id in flakies:
42+
marker = (pytest.mark.flaky(reruns=flakies[id]),)
43+
elif id in skips:
44+
# fail markers do not work with 'buck' based ci, so use skip instead
45+
marker = (pytest.mark.skip(reason=skips[id]),)
46+
elif id in xfails:
47+
xfail_info = xfails[id]
48+
reason = ""
49+
raises = None
50+
if isinstance(xfail_info, str):
51+
reason = xfail_info
52+
elif isinstance(xfail_info, tuple):
53+
reason, raises = xfail_info
54+
else:
55+
raise RuntimeError(
56+
"xfail info needs to be str, or tuple[str, type[Exception]]"
57+
)
58+
marker = (
59+
pytest.mark.xfail(reason=reason, raises=raises, strict=strict),
60+
)
61+
else:
62+
marker = ()
63+
64+
pytest_param = pytest.param(test_parameters, id=id, marks=marker)
65+
pytest_testsuite.append(pytest_param)
66+
decorator = pytest.mark.parametrize(arg_name, pytest_testsuite)
67+
return decorator(func)
68+
69+
return decorator_func

backends/transforms/test/test_to_contiguous_channels_last_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010
import torch
11-
from executorch.backends.arm.test import common
11+
from executorch.backends.transforms.test import common
1212
from executorch.exir import to_edge_transform_and_lower
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass
@@ -695,6 +695,6 @@ def test_permute_view_counts(case: PermuteCountTestCase) -> None:
695695
@pytest.mark.skip(
696696
reason="Proof of concept - currently no permute-view passes implemented."
697697
)
698-
@common.parametrize("case", cases_channels_last, xfail=xfails)
698+
@common.parametrize("case", cases_channels_last, xfails=xfails)
699699
def test_permute_view_counts_channels_last(case: PermuteCountTestCase) -> None:
700700
run_test(case)

0 commit comments

Comments
 (0)