1414from .typing import DataType
1515from . import api_version
1616
17- # TODO: test with complex dtypes
18- def non_complex_dtypes ():
19- return xps .boolean_dtypes () | hh .real_dtypes
20-
2117
2218def float32 (n : Union [int , float ]) -> float :
2319 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2420
2521
26- def _float_match_complex (complex_dtype ):
27- if complex_dtype == xp .complex64 :
28- return xp .float32
29- elif complex_dtype == xp .complex128 :
30- return xp .float64
22+ def _get_ranges (dtype ):
23+ """Ranges of dtype if integer, else ranges of the matching real float."""
24+ if dh .is_int_dtype (dtype ):
25+ _real_dtype = dtype
3126 else :
32- return dh .default_float
27+ _real_dtype = dh .real_dtype_for (dtype )
28+ m , M = dh .dtype_ranges [_real_dtype ]
29+ return m , M
3330
3431
3532@given (
@@ -39,7 +36,6 @@ def _float_match_complex(complex_dtype):
3936 data = st .data (),
4037)
4138def test_astype (x_dtype , dtype , kw , data ):
42- _complex_dtypes = (xp .complex64 , xp .complex128 )
4339
4440 if xp .bool in (x_dtype , dtype ):
4541 elements_strat = hh .from_dtype (x_dtype )
@@ -52,15 +48,9 @@ def test_astype(x_dtype, dtype, kw, data):
5248 else :
5349 cast = float
5450
55- real_dtype = x_dtype
56- if x_dtype in _complex_dtypes :
57- real_dtype = _float_match_complex (x_dtype )
58- m1 , M1 = dh .dtype_ranges [real_dtype ]
59-
60- real_dtype = dtype
61- if dtype in _complex_dtypes :
62- real_dtype = _float_match_complex (x_dtype )
63- m2 , M2 = dh .dtype_ranges [real_dtype ]
51+ # generate values in range for both src and target dtypes
52+ m1 , M1 = _get_ranges (x_dtype )
53+ m2 , M2 = _get_ranges (dtype )
6454
6555 min_value = cast (max (m1 , m2 ))
6656 max_value = cast (min (M1 , M2 ))
@@ -79,7 +69,7 @@ def test_astype(x_dtype, dtype, kw, data):
7969 # according to the spec, "Casting a complex floating-point array to a real-valued
8070 # data type should not be permitted."
8171 # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
82- assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
72+ assume (not ((x_dtype in dh . complex_dtypes ) and (dtype not in dh . complex_dtypes )))
8373
8474 repro_snippet = ph .format_snippet (f"xp.astype({ x !r} , { dtype !r} , **kw) with { kw = } " )
8575 try :
0 commit comments