
Commit 1dd288f

Author: zhangyue
test(conftest): joint (device, implementation_index) parametrize
Replaces the per-test `@pytest.mark.parametrize("implementation_index", ...)` + runtime `if impl not in active_indices: skip` pattern with a single hook in `conftest.pytest_generate_tests` that emits only the (device, impl) pairs actually active on each device.

Rationale: kernel dispatch is per-device, so the previous cross-device union (the `all_active_implementation_indices` helper) polluted the matrix with impls the selected device can't run, leaving runtime-skipped noise. Joint generation keeps the matrix to its semantic cell: "this device has this impl, so run it".

- `tests/conftest.py`: when both `device` and `implementation_index` are in fixturenames, emit pairs via `op_cls.active_implementation_indices(dev)`; fall back to a skipped placeholder (`id="skip"`) when no device has an active impl, avoiding `[NOTSET-...]` test IDs.
- `tests/{test_add,test_gemm,test_rms_norm,test_swiglu}.py`: drop the hardcoded `implementation_index` parametrize decorator and the runtime `active_indices` guard; conftest now handles both.
- `tests/utils.py`: remove the `all_active_implementation_indices` helper (superseded by per-device generation in conftest).

Same test outcome on Ascend CI (1935 passed / 1686 skipped), but the remaining skips are now semantically mandatory (uint dtypes unsupported by `torch_npu`, the Gemm impl=2 SFINAE-only workaround, ops missing an ascend impl on op-simple pending PR #66) rather than mechanism artifacts.
1 parent 9d7cb0e commit 1dd288f

6 files changed: 62 additions & 76 deletions
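The mechanism in one self-contained sketch (the `FAKE_ACTIVE` table is an illustrative stand-in for `op_cls.active_implementation_indices(device)`, not the repo's API):

# Minimal sketch of joint (device, implementation_index) generation.
# FAKE_ACTIVE stands in for op_cls.active_implementation_indices(device).
FAKE_ACTIVE = {"cpu": (0, 1), "npu": (0,)}


def pytest_generate_tests(metafunc):
    if {"device", "implementation_index"} <= set(metafunc.fixturenames):
        # Emit only runnable cells: (npu, 1) is never generated, so it
        # can never show up as a runtime skip.
        pairs = [(dev, idx) for dev in FAKE_ACTIVE for idx in FAKE_ACTIVE[dev]]
        metafunc.parametrize("device, implementation_index", pairs)

A test that declares both fixture names then runs once per pair: `[cpu-0]`, `[cpu-1]`, `[npu-0]`.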


tests/conftest.py

Lines changed: 59 additions & 1 deletion
@@ -191,7 +191,65 @@ def pytest_generate_tests(metafunc):
     else:
         devices = ()
 
-    metafunc.parametrize("device", devices or available)
+    devices = devices or available
+
+    # Joint `(device, implementation_index)` parametrize: generate only
+    # pairs where the op has an active implementation on that device.
+    # Avoids cross-device pollution — an impl active on `cpu` but not on
+    # `npu` no longer appears as a runtime skip in the npu column.
+    if (
+        "implementation_index" in metafunc.fixturenames
+        and "implementation_index" not in already_parametrized
+    ):
+        op_cls = _op_class_from_module(metafunc.module)
+
+        if op_cls is not None and hasattr(op_cls, "active_implementation_indices"):
+            pairs = [
+                (dev, idx)
+                for dev in devices
+                for idx in op_cls.active_implementation_indices(dev)
+            ]
+
+            if not pairs:
+                # Emit one skipped placeholder so test IDs read
+                # `[skip-dtype0-...]` instead of `[NOTSET-...]`.
+                pairs = [
+                    pytest.param(
+                        devices[0] if devices else "cpu",
+                        0,
+                        marks=pytest.mark.skip(
+                            reason=(
+                                f"{op_cls.__name__} has no active "
+                                "implementation on any available device"
+                            )
+                        ),
+                        id="skip",
+                    )
+                ]
+
+            metafunc.parametrize("device, implementation_index", pairs)
+
+            return
+
+    metafunc.parametrize("device", devices)
+
+
+def _op_class_from_module(module):
+    """Derive the `infini.ops.<Op>` class from a `tests/test_<snake>.py` module."""
+    module_name = module.__name__.rsplit(".", 1)[-1]
+
+    if not module_name.startswith("test_"):
+        return None
+
+    op_snake = module_name[len("test_") :]
+    op_pascal = "".join(part.capitalize() for part in op_snake.split("_"))
+
+    try:
+        import infini.ops as _ops
+    except ImportError:
+        return None
+
+    return getattr(_ops, op_pascal, None)
 
 
 @pytest.hookimpl(tryfirst=True)
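The snake-to-Pascal derivation in `_op_class_from_module` can be sanity-checked outside pytest; a standalone snippet reproducing just that transform for the four modules touched here:

# Reproduces the module-name -> op-class-name transform from
# _op_class_from_module, without importing infini.
def op_name_from_module(module_name):
    op_snake = module_name[len("test_"):]
    return "".join(part.capitalize() for part in op_snake.split("_"))


assert op_name_from_module("test_add") == "Add"
assert op_name_from_module("test_gemm") == "Gemm"
assert op_name_from_module("test_rms_norm") == "RmsNorm"
assert op_name_from_module("test_swiglu") == "Swiglu"

Note that `str.capitalize` lowercases the rest of each part, so a hypothetical `test_fft_2d` would map to `Fft2d`; the four current names round-trip cleanly.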

tests/test_add.py

Lines changed: 0 additions & 9 deletions
@@ -4,7 +4,6 @@
 
 from tests.utils import (
     Payload,
-    all_active_implementation_indices,
     empty_strided,
     get_stream,
     randint_strided,
@@ -36,9 +35,6 @@
         ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
     ),
 )
-@pytest.mark.parametrize(
-    "implementation_index", all_active_implementation_indices(infini.ops.Add)
-)
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -64,11 +60,6 @@ def test_add(
             "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`."
         )
 
-    active_indices = infini.ops.Add.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     if implementation_index == 1 and dtype in _UINT_DTYPES:
         pytest.skip("ATen `add` does not support unsigned integer types")
 
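The same deletion repeats in the three files below: each test body sheds the five-line guard and one decorator. Hypothetically trimmed, a test now just declares the two fixture names (real signatures keep their other parametrized arguments):

# Hypothetical trimmed shape: `device` and `implementation_index` arrive
# as a jointly generated pair that is already known to be runnable, so
# no `pytest.skip` guard for inactive implementations is needed.
def test_add(device, implementation_index, dtype, rtol, atol):
    ...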

tests/test_gemm.py

Lines changed: 1 addition & 14 deletions
@@ -2,12 +2,7 @@
 import pytest
 import torch
 
-from tests.utils import (
-    Payload,
-    all_active_implementation_indices,
-    get_stream,
-    randn_strided,
-)
+from tests.utils import Payload, get_stream, randn_strided
 
 
 @pytest.mark.auto_act_and_assert
@@ -25,9 +20,6 @@
 @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1))
 @pytest.mark.parametrize("trans_a", (False, True))
 @pytest.mark.parametrize("trans_b", (False, True))
-@pytest.mark.parametrize(
-    "implementation_index", all_active_implementation_indices(infini.ops.Gemm)
-)
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -61,11 +53,6 @@ def test_gemm(
     if device == "mlu" and dtype == torch.bfloat16:
         pytest.skip("`bfloat16` is not supported by `cnnlBatchMatMulEx`")
 
-    active_indices = infini.ops.Gemm.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16):
         pytest.skip("cuBLASLt half-precision exceeds current tolerances")
 

tests/test_rms_norm.py

Lines changed: 1 addition & 15 deletions
@@ -2,13 +2,7 @@
 import pytest
 import torch
 
-from tests.utils import (
-    Payload,
-    all_active_implementation_indices,
-    empty_strided,
-    get_stream,
-    randn_strided,
-)
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
 
 
 @pytest.mark.auto_act_and_assert
@@ -24,9 +18,6 @@
     ),
 )
 @pytest.mark.parametrize("eps", (1e-6, 1e-5))
-@pytest.mark.parametrize(
-    "implementation_index", all_active_implementation_indices(infini.ops.RmsNorm)
-)
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -48,11 +39,6 @@ def test_rms_norm(
     rtol,
     atol,
 ):
-    active_indices = infini.ops.RmsNorm.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     input = randn_strided(input_shape, input_strides, dtype=dtype, device=device)
     weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device)
     out = empty_strided(input_shape, out_strides, dtype=dtype, device=device)

tests/test_swiglu.py

Lines changed: 1 addition & 15 deletions
@@ -2,13 +2,7 @@
 import pytest
 import torch
 
-from tests.utils import (
-    Payload,
-    all_active_implementation_indices,
-    empty_strided,
-    get_stream,
-    rand_strided,
-)
+from tests.utils import Payload, empty_strided, get_stream, rand_strided
 
 
 @pytest.mark.auto_act_and_assert
@@ -25,9 +19,6 @@
         ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
     ),
 )
-@pytest.mark.parametrize(
-    "implementation_index", all_active_implementation_indices(infini.ops.Swiglu)
-)
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -47,11 +38,6 @@ def test_swiglu(
     rtol,
    atol,
 ):
-    active_indices = infini.ops.Swiglu.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     input = rand_strided(shape, input_strides, dtype=dtype, device=device)
     gate = rand_strided(shape, gate_strides, dtype=dtype, device=device)
     out = empty_strided(shape, out_strides, dtype=dtype, device=device)

tests/utils.py

Lines changed: 0 additions & 22 deletions
@@ -122,28 +122,6 @@ def get_stream(device):
     return getattr(stream, attr, 0)
 
 
-def all_active_implementation_indices(op_cls):
-    """Union of `op_cls.active_implementation_indices(device)` across every
-    locally-available torch device type.
-
-    Use as the `@pytest.mark.parametrize("implementation_index", ...)` value so
-    the test matrix grows automatically when a new backend implementation is
-    added. Per-device filtering (skipping indices not active on the currently
-    selected device) stays the test body's responsibility — see the `skip`
-    pattern in `test_gemm.py`.
-
-    Limited to `get_available_devices()` to avoid `DispatchFunc::std::abort`
-    for device types outside the build's `ActiveDevices` set (e.g., querying
-    `"cuda"` on an Ascend-only build).
-    """
-    indices = set()
-
-    for device in get_available_devices():
-        indices.update(op_cls.active_implementation_indices(device))
-
-    return tuple(sorted(indices))
-
-
 def clone_strided(input):
     output = empty_strided(
         input.size(), input.stride(), dtype=input.dtype, device=input.device
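To make the removed helper's cross-device pollution concrete, a toy calculation (index sets invented for illustration):

# Invented per-device active-index sets.
active = {"cpu": {0, 1, 2}, "npu": {0, 2}}

# Old scheme: one cross-device union fed every device's parametrize
# list, so the dead cell (npu, 1) existed and had to be runtime-skipped.
union = sorted(set().union(*active.values()))  # [0, 1, 2]
old_cells = [(dev, idx) for dev in active for idx in union]  # 6 cells

# New scheme: joint generation never creates the dead cell.
new_cells = [(dev, idx) for dev in active for idx in sorted(active[dev])]

assert ("npu", 1) in old_cells
assert ("npu", 1) not in new_cells  # 5 cells, all runnable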
