Skip to content

Commit 1077877

Browse files
committed
Add more testing when input array has any floating or complex dtype
1 parent a1cfdf1 commit 1077877

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3719,14 +3719,12 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37193719

37203720
dpnp.check_supported_arrays_type(x)
37213721

3722-
def _check_nan_inf(val, name):
3722+
def _check_nan_inf(val, val_dt):
37233723
# Python boolean is a subtype of an integer
37243724
if not isinstance(val, (int, float)):
37253725
val = dpnp.asarray(
3726-
val, sycl_queue=x.sycl_queue, usm_type=x.usm_type
3726+
val, dtype=val_dt, sycl_queue=x.sycl_queue, usm_type=x.usm_type
37273727
)
3728-
if dpnp.issubdtype(val.dtype, dpnp.complexfloating):
3729-
raise TypeError(f"{name} must not be of a complex type")
37303728
return val
37313729

37323730
x_type = x.dtype.type
@@ -3735,11 +3733,18 @@ def _check_nan_inf(val, name):
37353733

37363734
max_f, min_f = _get_max_min(x.real.dtype)
37373735

3738-
nan = _check_nan_inf(nan, "nan")
3736+
# get dtype of nan and infs values if casting required
3737+
is_complex = dpnp.issubdtype(x_type, dpnp.complexfloating)
3738+
if is_complex:
3739+
val_dt = x.real.dtype
3740+
else:
3741+
val_dt = x.dtype
3742+
3743+
nan = _check_nan_inf(nan, val_dt)
37393744
if posinf is not None:
3740-
max_f = _check_nan_inf(posinf, "posinf")
3745+
max_f = _check_nan_inf(posinf, val_dt)
37413746
if neginf is not None:
3742-
min_f = _check_nan_inf(neginf, "neginf")
3747+
min_f = _check_nan_inf(neginf, val_dt)
37433748

37443749
if copy:
37453750
out = dpnp.empty_like(x)
@@ -3771,7 +3776,7 @@ def _check_nan_inf(val, name):
37713776
return dpnp.get_result_array(out)
37723777

37733778
# handle a common case with broadcasting of input nan and infs
3774-
if dpnp.issubdtype(x_type, dpnp.complexfloating):
3779+
if is_complex:
37753780
parts = (x.real, x.imag)
37763781
parts_out = (out.real, out.imag)
37773782
else:

dpnp/tests/test_mathematical.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,10 +1480,11 @@ def test_boolean_array(self):
14801480
expected = numpy.nan_to_num(a)
14811481
assert_allclose(result, expected)
14821482

1483+
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
14831484
@pytest.mark.parametrize("kw_name", ["nan", "posinf", "neginf"])
1484-
@pytest.mark.parametrize("val", [[1, 2, -1, -2, 7], (7,), numpy.array(1)])
1485-
def test_nan_infs_array_like(self, kw_name, val):
1486-
a = numpy.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1485+
@pytest.mark.parametrize("val", [[1, 2, -1, -2, 7], (7.0,), numpy.array(1)])
1486+
def test_nan_infs_array_like(self, dt, kw_name, val):
1487+
a = numpy.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf], dtype=dt)
14871488
ia = dpnp.array(a)
14881489

14891490
result = dpnp.nan_to_num(ia, **{kw_name: val})
@@ -1494,7 +1495,7 @@ def test_nan_infs_array_like(self, kw_name, val):
14941495
@pytest.mark.parametrize("kw_name", ["nan", "posinf", "neginf"])
14951496
def test_nan_infs_complex_dtype(self, xp, kw_name):
14961497
ia = xp.array([0, 1, xp.nan, xp.inf, -xp.inf])
1497-
with pytest.raises((TypeError, ValueError), match="complex.*type"):
1498+
with pytest.raises(TypeError, match="complex"):
14981499
xp.nan_to_num(ia, **{kw_name: 1j})
14991500

15001501
def test_numpy_input_array(self):

0 commit comments

Comments
 (0)