Skip to content

Commit 27e9066

Browse files
authored
fix: add init_cuda fixture to tests requiring CUDA context (#1522)
Four tests in test_utils.py relied on CuPy implicitly creating a CUDA context but failed when pytest-randomly ordered them after tests using the init_cuda fixture, which pops the context on cleanup.
1 parent 5dd4ac9 commit 27e9066

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

cuda_core/tests/test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def _get_ptr(array):
345345
for view_as in ["dlpack", "cai"]
346346
],
347347
)
348-
def test_view_sliced_external(shape, slices, stride_order, view_as):
348+
def test_view_sliced_external(init_cuda, shape, slices, stride_order, view_as):
349349
if view_as == "dlpack":
350350
if np is None:
351351
pytest.skip("NumPy is not installed")
@@ -380,7 +380,7 @@ def test_view_sliced_external(shape, slices, stride_order, view_as):
380380
("stride_order", "view_as"),
381381
[(stride_order, view_as) for stride_order in ["C", "F"] for view_as in ["dlpack", "cai"]],
382382
)
383-
def test_view_sliced_external_negative_offset(stride_order, view_as):
383+
def test_view_sliced_external_negative_offset(init_cuda, stride_order, view_as):
384384
shape = (5,)
385385
if view_as == "dlpack":
386386
if np is None:
@@ -422,7 +422,7 @@ def test_view_sliced_external_negative_offset(stride_order, view_as):
422422
)
423423
@pytest.mark.parametrize("shape", [(0,), (0, 0), (0, 0, 0)])
424424
@pytest.mark.parametrize("dtype", [np.int64, np.uint8, np.float64])
425-
def test_view_zero_size_array(api, shape, dtype):
425+
def test_view_zero_size_array(init_cuda, api, shape, dtype):
426426
cp = pytest.importorskip("cupy")
427427

428428
x = cp.empty(shape, dtype=dtype)
@@ -446,7 +446,7 @@ def test_from_buffer_with_non_power_of_two_itemsize():
446446
assert view.dtype == dtype
447447

448448

449-
def test_struct_array():
449+
def test_struct_array(init_cuda):
450450
cp = pytest.importorskip("cupy")
451451

452452
x = np.array([(1.0, 2), (2.0, 3)], dtype=[("array1", np.float64), ("array2", np.int64)])

0 commit comments

Comments
 (0)