Skip to content

Commit eba9034

Browse files
committed
MAINT: clean up test_astype
1 parent e0f3e37 commit eba9034

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,19 @@
1414
from .typing import DataType
1515
from . import api_version
1616

17-
# TODO: test with complex dtypes
18-
def non_complex_dtypes():
19-
return xps.boolean_dtypes() | hh.real_dtypes
20-
2117

2218
def float32(n: Union[int, float]) -> float:
2319
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2420

2521

26-
def _float_match_complex(complex_dtype):
27-
if complex_dtype == xp.complex64:
28-
return xp.float32
29-
elif complex_dtype == xp.complex128:
30-
return xp.float64
22+
def _get_ranges(dtype):
23+
"""Ranges of dtype if integer, else ranges of the matching real float."""
24+
if dh.is_int_dtype(dtype):
25+
_real_dtype = dtype
3126
else:
32-
return dh.default_float
27+
_real_dtype = dh.real_dtype_for(dtype)
28+
m, M = dh.dtype_ranges[_real_dtype]
29+
return m, M
3330

3431

3532
@given(
@@ -39,7 +36,6 @@ def _float_match_complex(complex_dtype):
3936
data=st.data(),
4037
)
4138
def test_astype(x_dtype, dtype, kw, data):
42-
_complex_dtypes = (xp.complex64, xp.complex128)
4339

4440
if xp.bool in (x_dtype, dtype):
4541
elements_strat = hh.from_dtype(x_dtype)
@@ -52,15 +48,9 @@ def test_astype(x_dtype, dtype, kw, data):
5248
else:
5349
cast = float
5450

55-
real_dtype = x_dtype
56-
if x_dtype in _complex_dtypes:
57-
real_dtype = _float_match_complex(x_dtype)
58-
m1, M1 = dh.dtype_ranges[real_dtype]
59-
60-
real_dtype = dtype
61-
if dtype in _complex_dtypes:
62-
real_dtype = _float_match_complex(x_dtype)
63-
m2, M2 = dh.dtype_ranges[real_dtype]
51+
# generate values in range for both src and target dtypes
52+
m1, M1 = _get_ranges(x_dtype)
53+
m2, M2 = _get_ranges(dtype)
6454

6555
min_value = cast(max(m1, m2))
6656
max_value = cast(min(M1, M2))
@@ -79,7 +69,7 @@ def test_astype(x_dtype, dtype, kw, data):
7969
# according to the spec, "Casting a complex floating-point array to a real-valued
8070
# data type should not be permitted."
8171
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
82-
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
72+
assume(not ((x_dtype in dh.complex_dtypes) and (dtype not in dh.complex_dtypes)))
8373

8474
repro_snippet = ph.format_snippet(f"xp.astype({x!r}, {dtype!r}, **kw) with {kw = }")
8575
try:

0 commit comments

Comments
 (0)