1010from . import hypothesis_helpers as hh
1111from . import pytest_helpers as ph
1212from . import shape_helpers as sh
13- from . import xps
1413from .typing import DataType
1514from . import api_version
1615
17- # TODO: test with complex dtypes
18- def non_complex_dtypes ():
19- return xps .boolean_dtypes () | hh .real_dtypes
20-
2116
2217def float32 (n : Union [int , float ]) -> float :
2318 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2419
2520
26- def _float_match_complex (complex_dtype ):
27- if complex_dtype == xp .complex64 :
28- return xp .float32
29- elif complex_dtype == xp .complex128 :
30- return xp .float64
21+ def _get_ranges (dtype ):
22+ """Ranges of dtype if integer, else ranges of the matching real float."""
23+ if dh .is_int_dtype (dtype ):
24+ _real_dtype = dtype
3125 else :
32- return dh .default_float
26+ _real_dtype = dh .real_dtype_for (dtype )
27+ m , M = dh .dtype_ranges [_real_dtype ]
28+ return m , M
3329
3430
3531@given (
@@ -39,7 +35,6 @@ def _float_match_complex(complex_dtype):
3935 data = st .data (),
4036)
4137def test_astype (x_dtype , dtype , kw , data ):
42- _complex_dtypes = (xp .complex64 , xp .complex128 )
4338
4439 if xp .bool in (x_dtype , dtype ):
4540 elements_strat = hh .from_dtype (x_dtype )
@@ -52,15 +47,9 @@ def test_astype(x_dtype, dtype, kw, data):
5247 else :
5348 cast = float
5449
55- real_dtype = x_dtype
56- if x_dtype in _complex_dtypes :
57- real_dtype = _float_match_complex (x_dtype )
58- m1 , M1 = dh .dtype_ranges [real_dtype ]
59-
60- real_dtype = dtype
61- if dtype in _complex_dtypes :
62- real_dtype = _float_match_complex (x_dtype )
63- m2 , M2 = dh .dtype_ranges [real_dtype ]
50+ # generate values in range for both src and target dtypes
51+ m1 , M1 = _get_ranges (x_dtype )
52+ m2 , M2 = _get_ranges (dtype )
6453
6554 min_value = cast (max (m1 , m2 ))
6655 max_value = cast (min (M1 , M2 ))
@@ -79,7 +68,7 @@ def test_astype(x_dtype, dtype, kw, data):
7968 # according to the spec, "Casting a complex floating-point array to a real-valued
8069 # data type should not be permitted."
8170 # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
82- assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
71+ assume (not ((x_dtype in dh . complex_dtypes ) and (dtype not in dh . complex_dtypes )))
8372
8473 repro_snippet = ph .format_snippet (f"xp.astype({ x !r} , { dtype !r} , **kw) with { kw = } " )
8574 try :
0 commit comments