Skip to content

Commit 613a3d1

Browse files
committed
Use requires_module mark for numpy version checks in mutation tests
Replace inline skipif version check with requires_module(np, "2.1") from the shared test helpers, consistent with other test files. Made-with: Cursor
1 parent 1fda19c commit 613a3d1

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

cuda_core/tests/graph/test_graphdef_mutation.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,14 @@
33

44
"""Tests for mutating a graph definition (edge changes, node removal)."""
55

6-
import pytest
6+
import numpy as np
77
from helpers.collection_interface_testers import assert_mutable_set_interface
88
from helpers.graph_kernels import compile_parallel_kernels
9+
from helpers.marks import requires_module
910

1011
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource
1112
from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode
1213

13-
try:
14-
import numpy as np
15-
16-
_has_numpy_2_1 = tuple(int(i) for i in np.__version__.split(".")[:2]) >= (2, 1)
17-
except ImportError:
18-
np = None
19-
_has_numpy_2_1 = False
20-
21-
_need_numpy_2_1 = pytest.mark.skipif(not _has_numpy_2_1, reason="need numpy 2.1+")
22-
2314

2415
class YRig:
2516
"""Test rigging for graph mutation tests. Constructs a Y-shaped graph with
@@ -143,7 +134,7 @@ def close(self):
143134
self._buf.close()
144135

145136

146-
@_need_numpy_2_1
137+
@requires_module(np, "2.1")
147138
class TestMutateYRig:
148139
"""Tests that mutate the Y-shaped graph built by YRig."""
149140

@@ -269,7 +260,7 @@ def test_adjacency_set_property_setter(init_cuda):
269260
assert hub.pred == set()
270261

271262

272-
@_need_numpy_2_1
263+
@requires_module(np, "2.1")
273264
def test_convert_linear_to_fan_in(init_cuda):
274265
"""Chain four computations sequentially, then rewire so all pairs run in
275266
parallel feeding into a reduce node.

0 commit comments

Comments
 (0)