Skip to content

Commit 9ee6341

Browse files
committed
ENH: add minimal dlpack tests
1 parent cf7bc9f commit 9ee6341

1 file changed

Lines changed: 63 additions & 0 deletions

File tree

array_api_tests/test_dlpack.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from enum import Enum
2+
3+
from hypothesis import given, strategies as st
4+
from . import _array_module as xp
5+
from . import pytest_helpers as ph
6+
from . import hypothesis_helpers as hh
7+
8+
# dlpack Enum values,
9+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack_device__.html
10+
11+
class DLPackDeviceEnum(Enum):
12+
CPU = 1
13+
CUDA = 2
14+
CPU_PINNED = 3
15+
OPENCL = 4
16+
VULKAN = 7
17+
METAL = 8
18+
VPI = 9
19+
ROCM = 10
20+
CUDA_MANAGED = 13
21+
ONE_API = 14
22+
23+
24+
@given(dtype=hh.all_dtypes, data=st.data())
25+
def test_dlpack_device(dtype, data):
26+
"""Test the array object __dlpack_device__ method."""
27+
# TODO: 1. generate inputs on non-default devices
28+
x = xp.empty(3, dtype=dtype)
29+
device_type, device_id = x.__dlpack_device__()
30+
31+
assert device_type in DLPackDeviceEnum
32+
assert isinstance(device_id, int)
33+
34+
35+
@given(
36+
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, max_side=2)),
37+
copy_kw=hh.kwargs(copy=st.booleans()),
38+
data=st.data()
39+
)
40+
def test_from_dlpack(x, copy_kw, data):
41+
devices = xp.__array_namespace_info__().devices()
42+
tgt_device_kw = data.draw(
43+
hh.kwargs(device=st.sampled_from(devices) | st.none())
44+
)
45+
# TODO: 1. test copy; 2. generate inputs on non-default devices
46+
tgt_device = tgt_device_kw['device'] if tgt_device_kw else None
47+
48+
49+
repro_snippet = ph.format_snippet(
50+
f"y = from_dlpack({x!r}, **tgt_device_kw, **copy_kw) with {tgt_device_kw=} and {copy_kw=}"
51+
)
52+
try:
53+
y = xp.from_dlpack(x, **tgt_device_kw, **copy_kw)
54+
55+
if tgt_device is None:
56+
assert y.device == x.device
57+
assert xp.all(y == x)
58+
else:
59+
assert y.device == tgt_device
60+
61+
except Exception as exc:
62+
ph.add_note(exc, repro_snippet)
63+
raise

0 commit comments

Comments
 (0)