Skip to content

Commit 883a654

Browse files
committed
Add brroadcasting support for nan, posinf, and neginf kwargs
1 parent 1f7f4d9 commit 883a654

File tree

1 file changed

+70
-37
lines changed

1 file changed

+70
-37
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3646,20 +3646,24 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36463646
an array does not require a copy.
36473647
36483648
Default: ``True``.
3649-
nan : {int, float, bool}, optional
3650-
Value to be used to fill ``NaN`` values.
3649+
nan : {scalar, array_like}, optional
3650+
Values to be used to fill ``NaN`` values. If no values are passed then
3651+
``NaN`` values will be replaced with ``0.0``.
3652+
Expected to have a real-valued data type for the values.
36513653
36523654
Default: ``0.0``.
3653-
posinf : {int, float, bool, None}, optional
3654-
Value to be used to fill positive infinity values. If no value is
3655+
posinf : {None, scalar, array_like}, optional
3656+
Values to be used to fill positive infinity values. If no values are
36553657
passed then positive infinity values will be replaced with a very
36563658
large number.
3659+
Expected to have a real-valued data type for the values.
36573660
36583661
Default: ``None``.
3659-
neginf : {int, float, bool, None} optional
3660-
Value to be used to fill negative infinity values. If no value is
3662+
neginf : {None, scalar, array_like}, optional
3663+
Values to be used to fill negative infinity values. If no values are
36613664
passed then negative infinity values will be replaced with a very
36623665
small (or negative) number.
3666+
Expected to have a real-valued data type for the values.
36633667
36643668
Default: ``None``.
36653669
@@ -3687,13 +3691,22 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36873691
array(-1.79769313e+308)
36883692
>>> np.nan_to_num(np.array(np.nan))
36893693
array(0.)
3694+
36903695
>>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
36913696
>>> np.nan_to_num(x)
36923697
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000,
36933698
-1.28000000e+002, 1.28000000e+002])
36943699
>>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
36953700
array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
36963701
1.2800000e+02])
3702+
3703+
>>> nan = np.array([11, 12, -9999, 13, 14])
3704+
>>> posinf = np.array([33333333, 11, 12, 13, 14])
3705+
>>> neginf = np.array([11, 33333333, 12, 13, 14])
3706+
>>> np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
3707+
array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
3708+
1.2800000e+02])
3709+
36973710
>>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
36983711
>>> np.nan_to_num(y)
36993712
array([1.79769313e+308 +0.00000000e+000j, # may vary
@@ -3706,33 +3719,27 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37063719

37073720
dpnp.check_supported_arrays_type(x)
37083721

3709-
# Python boolean is a subtype of an integer
3710-
# so additional check for bool is not needed.
3711-
if not isinstance(nan, (int, float)):
3712-
raise TypeError(
3713-
"nan must be a scalar of an integer, float, bool, "
3714-
f"but got {type(nan)}"
3715-
)
3716-
x_type = x.dtype.type
3722+
def _check_nan_inf(val, name):
3723+
# Python boolean is a subtype of an integer
3724+
if not isinstance(val, (int, float)):
3725+
val = dpnp.asarray(
3726+
val, sycl_queue=x.sycl_queue, usm_type=x.usm_type
3727+
)
3728+
if dpnp.issubdtype(val.dtype, dpnp.complexfloating):
3729+
raise TypeError(f"{name} must not be of a complex type")
3730+
return val
37173731

3718-
if not issubclass(x_type, dpnp.inexact):
3732+
x_type = x.dtype.type
3733+
if not dpnp.issubdtype(x_type, dpnp.inexact):
37193734
return dpnp.copy(x) if copy else dpnp.get_result_array(x)
37203735

37213736
max_f, min_f = _get_max_min(x.real.dtype)
3737+
3738+
nan = _check_nan_inf(nan, "nan")
37223739
if posinf is not None:
3723-
if not isinstance(posinf, (int, float)):
3724-
raise TypeError(
3725-
"posinf must be a scalar of an integer, float, bool, "
3726-
f"or be None, but got {type(posinf)}"
3727-
)
3728-
max_f = posinf
3740+
max_f = _check_nan_inf(posinf, "posinf")
37293741
if neginf is not None:
3730-
if not isinstance(neginf, (int, float)):
3731-
raise TypeError(
3732-
"neginf must be a scalar of an integer, float, bool, "
3733-
f"or be None, but got {type(neginf)}"
3734-
)
3735-
min_f = neginf
3742+
min_f = _check_nan_inf(neginf, "neginf")
37363743

37373744
if copy:
37383745
out = dpnp.empty_like(x)
@@ -3741,19 +3748,45 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37413748
raise ValueError("copy is required for read-only array `x`")
37423749
out = x
37433750

3744-
x_ary = dpnp.get_usm_ndarray(x)
3745-
out_ary = dpnp.get_usm_ndarray(out)
3746-
3747-
q = x.sycl_queue
3748-
_manager = dpu.SequentialOrderManager[q]
3751+
# handle a special case when nan and infs are all scalars
3752+
if all(dpnp.isscalar(el) for el in (nan, max_f, min_f)):
3753+
x_ary = dpnp.get_usm_ndarray(x)
3754+
out_ary = dpnp.get_usm_ndarray(out)
3755+
3756+
q = x.sycl_queue
3757+
_manager = dpu.SequentialOrderManager[q]
3758+
3759+
h_ev, comp_ev = ufi._nan_to_num(
3760+
x_ary,
3761+
nan,
3762+
max_f,
3763+
min_f,
3764+
out_ary,
3765+
q,
3766+
depends=_manager.submitted_events,
3767+
)
37493768

3750-
h_ev, comp_ev = ufi._nan_to_num(
3751-
x_ary, nan, max_f, min_f, out_ary, q, depends=_manager.submitted_events
3752-
)
3769+
_manager.add_event_pair(h_ev, comp_ev)
37533770

3754-
_manager.add_event_pair(h_ev, comp_ev)
3771+
return dpnp.get_result_array(out)
37553772

3756-
return dpnp.get_result_array(out)
3773+
# handle a common case with broadcasting of input nan and infs
3774+
if dpnp.issubdtype(x_type, dpnp.complexfloating):
3775+
parts = (x.real, x.imag)
3776+
parts_out = (out.real, out.imag)
3777+
else:
3778+
parts = (x,)
3779+
parts_out = (out,)
3780+
3781+
for part, part_out in zip(parts, parts_out):
3782+
nan_mask = dpnp.isnan(part)
3783+
posinf_mask = dpnp.isposinf(part)
3784+
neginf_mask = dpnp.isneginf(part)
3785+
3786+
part = dpnp.where(nan_mask, nan, part, out=part_out)
3787+
part = dpnp.where(posinf_mask, max_f, part, out=part_out)
3788+
part = dpnp.where(neginf_mask, min_f, part, out=part_out)
3789+
return out
37573790

37583791

37593792
_NEGATIVE_DOCSTRING = """

0 commit comments

Comments
 (0)