Skip to content

Commit 1bde1a4

Browse files
committed
Rework requires() mark: rename to requires_module, use importorskip
Rename the mark to requires_module and reimplement it as a thin wrapper around pytest.importorskip, forwarding *args/**kwargs directly. Version arguments are now strings (matching importorskip's minversion parameter) rather than integer tuples. Update all call sites accordingly. Made-with: Cursor
1 parent c5bc597 commit 1bde1a4

7 files changed

Lines changed: 44 additions & 38 deletions

File tree

cuda_core/tests/graph/test_advanced.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import numpy as np
77
import pytest
88
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
9-
from helpers.marks import requires
9+
from helpers.marks import requires_module
1010

1111
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
1212

1313

14-
@requires(np, 2, 1)
14+
@requires_module(np, "2.1")
1515
def test_graph_child_graph(init_cuda):
1616
mod = compile_common_kernels()
1717
add_one = mod.get_kernel("add_one")
@@ -64,7 +64,7 @@ def test_graph_child_graph(init_cuda):
6464
b.close()
6565

6666

67-
@requires(np, 2, 1)
67+
@requires_module(np, "2.1")
6868
def test_graph_update(init_cuda):
6969
mod = compile_conditional_kernels(int)
7070
add_one = mod.get_kernel("add_one")

cuda_core/tests/graph/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pytest
88
from helpers.graph_kernels import compile_common_kernels
9-
from helpers.marks import requires
9+
from helpers.marks import requires_module
1010

1111
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
1212

@@ -117,7 +117,7 @@ def test_graph_is_join_required(init_cuda):
117117
gb.end_building().complete()
118118

119119

120-
@requires(np, 2, 1)
120+
@requires_module(np, "2.1")
121121
def test_graph_repeat_capture(init_cuda):
122122
mod = compile_common_kernels()
123123
add_one = mod.get_kernel("add_one")

cuda_core/tests/graph/test_conditional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
import numpy as np
99
import pytest
1010
from helpers.graph_kernels import compile_conditional_kernels
11-
from helpers.marks import requires
11+
from helpers.marks import requires_module
1212

1313
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
1414

1515

1616
@pytest.mark.parametrize(
1717
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
1818
)
19-
@requires(np, 2, 1)
19+
@requires_module(np, "2.1")
2020
def test_graph_conditional_if(init_cuda, condition_value):
2121
mod = compile_conditional_kernels(type(condition_value))
2222
add_one = mod.get_kernel("add_one")
@@ -80,7 +80,7 @@ def test_graph_conditional_if(init_cuda, condition_value):
8080
@pytest.mark.parametrize(
8181
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
8282
)
83-
@requires(np, 2, 1)
83+
@requires_module(np, "2.1")
8484
def test_graph_conditional_if_else(init_cuda, condition_value):
8585
mod = compile_conditional_kernels(type(condition_value))
8686
add_one = mod.get_kernel("add_one")
@@ -152,7 +152,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):
152152

153153

154154
@pytest.mark.parametrize("condition_value", [0, 1, 2, 3])
155-
@requires(np, 2, 1)
155+
@requires_module(np, "2.1")
156156
def test_graph_conditional_switch(init_cuda, condition_value):
157157
mod = compile_conditional_kernels(type(condition_value))
158158
add_one = mod.get_kernel("add_one")
@@ -243,7 +243,7 @@ def test_graph_conditional_switch(init_cuda, condition_value):
243243

244244

245245
@pytest.mark.parametrize("condition_value", [True, False, 1, 0])
246-
@requires(np, 2, 1)
246+
@requires_module(np, "2.1")
247247
def test_graph_conditional_while(init_cuda, condition_value):
248248
mod = compile_conditional_kernels(type(condition_value))
249249
add_one = mod.get_kernel("add_one")

cuda_core/tests/graph/test_device_launch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
import pytest
15-
from helpers.marks import requires
15+
from helpers.marks import requires_module
1616

1717
from cuda.core import (
1818
Device,
@@ -83,7 +83,7 @@ def _compile_device_launcher_kernel():
8383
Device().compute_capability.major < 9,
8484
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
8585
)
86-
@requires(np, 2, 1)
86+
@requires_module(np, "2.1")
8787
def test_device_launch_basic(init_cuda):
8888
"""Test basic device-side graph launch functionality.
8989
@@ -135,7 +135,7 @@ def test_device_launch_basic(init_cuda):
135135
Device().compute_capability.major < 9,
136136
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
137137
)
138-
@requires(np, 2, 1)
138+
@requires_module(np, "2.1")
139139
def test_device_launch_multiple(init_cuda):
140140
"""Test that device-side graph launch can be executed multiple times.
141141

cuda_core/tests/helpers/marks.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,43 @@
33

44
"""Reusable pytest marks for cuda_core tests."""
55

6-
import importlib
7-
import types
6+
import inspect
87

98
import pytest
109

1110

12-
def requires(module, *version):
13-
"""Skip the test if a module is missing or older than the given version.
11+
def requires_module(module, *args, **kwargs):
12+
"""Skip the test if a module is missing or older than required.
13+
14+
Thin wrapper around :func:`pytest.importorskip`. The first argument
15+
may be a module object or a string; all remaining positional and
16+
keyword arguments (``minversion``, ``reason``, ``exc_type``) are
17+
forwarded.
18+
19+
Prefer this over ``pytest.importorskip`` when:
20+
21+
- You need finer granularity than module scope or a test body; this
22+
mark can decorate classes, individual tests, or ``pytest.param`` entries.
23+
- You want to skip before fixtures run, avoiding setup costs.
24+
- The module is already imported and you want to pass it directly.
1425
1526
Usage::
1627
17-
@requires(np, 2, 1)
28+
@requires_module("numpy", "2.1")
1829
def test_foo(): ...
1930
2031
21-
@requires("scipy", 1, 12)
32+
@requires_module(np, minversion="2.1")
2233
def test_bar(): ...
2334
"""
24-
if isinstance(module, str):
25-
name = module
26-
try:
27-
module = importlib.import_module(name)
28-
except ImportError:
29-
return pytest.mark.skip(reason=f"{name} is not installed")
30-
elif isinstance(module, types.ModuleType):
31-
name = module.__name__
32-
else:
35+
if inspect.ismodule(module):
36+
module = module.__name__
37+
elif not isinstance(module, str):
3338
raise TypeError(f"expected module or string, got {type(module).__name__}")
3439

35-
n = len(version)
36-
parts = module.__version__.split(".")[:n]
37-
installed = tuple(int(p) for p in parts)
38-
ver_str = ".".join(str(v) for v in version)
39-
return pytest.mark.skipif(installed < version, reason=f"need {name} {ver_str}+")
40+
try:
41+
pytest.importorskip(module, *args, **kwargs)
42+
except pytest.skip.Skipped as exc:
43+
return pytest.mark.skipif(True, reason=str(exc))
44+
else:
45+
return pytest.mark.skipif(False, reason="")

cuda_core/tests/test_launcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import ctypes
55

66
import helpers
7-
from helpers.marks import requires
7+
from helpers.marks import requires_module
88
from helpers.misc import StreamWrapper
99

1010
try:
@@ -191,7 +191,7 @@ def test_launch_invalid_values(init_cuda):
191191

192192

193193
@pytest.mark.parametrize("python_type, cpp_type, init_value", PARAMS)
194-
@requires(np, 2, 1)
194+
@requires_module(np, "2.1")
195195
def test_launch_scalar_argument(python_type, cpp_type, init_value):
196196
dev = Device()
197197
dev.set_current()
@@ -290,7 +290,7 @@ def test_cooperative_launch():
290290
"device_memory_resource", # kludgy, but can go away after #726 is resolved
291291
pytest.param(
292292
LegacyPinnedMemoryResource,
293-
marks=requires(np, 2, 2, 5),
293+
marks=requires_module(np, "2.2.5"),
294294
),
295295
],
296296
)

cuda_core/tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ml_dtypes = None
2727
import numpy as np
2828
import pytest
29-
from helpers.marks import requires
29+
from helpers.marks import requires_module
3030

3131
from cuda.core import Device
3232
from cuda.core._dlpack import DLDeviceType
@@ -86,7 +86,7 @@ def convert_strides_to_counts(strides, itemsize):
8686
# readonly is fixed recently (numpy/numpy#26501)
8787
pytest.param(
8888
np.frombuffer(b""),
89-
marks=requires(np, 2, 1),
89+
marks=requires_module(np, "2.1"),
9090
),
9191
),
9292
)

0 commit comments

Comments
 (0)