Skip to content

Commit 0d6432d

Browse files
authored
Merge pull request #428 from ev-br/cleanup_test_astype
MAINT: clean up test_astype
2 parents e0f3e37 + c85e01e commit 0d6432d

File tree

1 file changed

+11
-22
lines changed

1 file changed

+11
-22
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,22 @@
1010
from . import hypothesis_helpers as hh
1111
from . import pytest_helpers as ph
1212
from . import shape_helpers as sh
13-
from . import xps
1413
from .typing import DataType
1514
from . import api_version
1615

17-
# TODO: test with complex dtypes
18-
def non_complex_dtypes():
19-
return xps.boolean_dtypes() | hh.real_dtypes
20-
2116

2217
def float32(n: Union[int, float]) -> float:
2318
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2419

2520

26-
def _float_match_complex(complex_dtype):
27-
if complex_dtype == xp.complex64:
28-
return xp.float32
29-
elif complex_dtype == xp.complex128:
30-
return xp.float64
21+
def _get_ranges(dtype):
22+
"""Ranges of dtype if integer, else ranges of the matching real float."""
23+
if dh.is_int_dtype(dtype):
24+
_real_dtype = dtype
3125
else:
32-
return dh.default_float
26+
_real_dtype = dh.real_dtype_for(dtype)
27+
m, M = dh.dtype_ranges[_real_dtype]
28+
return m, M
3329

3430

3531
@given(
@@ -39,7 +35,6 @@ def _float_match_complex(complex_dtype):
3935
data=st.data(),
4036
)
4137
def test_astype(x_dtype, dtype, kw, data):
42-
_complex_dtypes = (xp.complex64, xp.complex128)
4338

4439
if xp.bool in (x_dtype, dtype):
4540
elements_strat = hh.from_dtype(x_dtype)
@@ -52,15 +47,9 @@ def test_astype(x_dtype, dtype, kw, data):
5247
else:
5348
cast = float
5449

55-
real_dtype = x_dtype
56-
if x_dtype in _complex_dtypes:
57-
real_dtype = _float_match_complex(x_dtype)
58-
m1, M1 = dh.dtype_ranges[real_dtype]
59-
60-
real_dtype = dtype
61-
if dtype in _complex_dtypes:
62-
real_dtype = _float_match_complex(x_dtype)
63-
m2, M2 = dh.dtype_ranges[real_dtype]
50+
# generate values in range for both src and target dtypes
51+
m1, M1 = _get_ranges(x_dtype)
52+
m2, M2 = _get_ranges(dtype)
6453

6554
min_value = cast(max(m1, m2))
6655
max_value = cast(min(M1, M2))
@@ -79,7 +68,7 @@ def test_astype(x_dtype, dtype, kw, data):
7968
# according to the spec, "Casting a complex floating-point array to a real-valued
8069
# data type should not be permitted."
8170
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
82-
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
71+
assume(not ((x_dtype in dh.complex_dtypes) and (dtype not in dh.complex_dtypes)))
8372

8473
repro_snippet = ph.format_snippet(f"xp.astype({x!r}, {dtype!r}, **kw) with {kw = }")
8574
try:

0 commit comments

Comments
 (0)