@@ -3646,20 +3646,24 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36463646 an array does not require a copy.
36473647
36483648 Default: ``True``.
3649- nan : {int, float, bool}, optional
3650- Value to be used to fill ``NaN`` values.
3649+ nan : {scalar, array_like}, optional
3650+ Values to be used to fill ``NaN`` values. If no values are passed then
3651+ ``NaN`` values will be replaced with ``0.0``.
3652+ Expected to have a real-valued data type for the values.
36513653
36523654 Default: ``0.0``.
3653- posinf : {int, float, bool, None }, optional
3654- Value to be used to fill positive infinity values. If no value is
3655+ posinf : {None, scalar, array_like }, optional
3656+ Values to be used to fill positive infinity values. If no values are
36553657 passed then positive infinity values will be replaced with a very
36563658 large number.
3659+ Expected to have a real-valued data type for the values.
36573660
36583661 Default: ``None``.
3659- neginf : {int, float, bool, None} optional
3660- Value to be used to fill negative infinity values. If no value is
3662+ neginf : {None, scalar, array_like}, optional
3663+ Values to be used to fill negative infinity values. If no values are
36613664 passed then negative infinity values will be replaced with a very
36623665 small (or negative) number.
3666+ Expected to have a real-valued data type for the values.
36633667
36643668 Default: ``None``.
36653669
@@ -3687,13 +3691,22 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36873691 array(-1.79769313e+308)
36883692 >>> np.nan_to_num(np.array(np.nan))
36893693 array(0.)
3694+
36903695 >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
36913696 >>> np.nan_to_num(x)
36923697 array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000,
36933698 -1.28000000e+002, 1.28000000e+002])
36943699 >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
36953700 array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
36963701 1.2800000e+02])
3702+
3703+ >>> nan = np.array([11, 12, -9999, 13, 14])
3704+ >>> posinf = np.array([33333333, 11, 12, 13, 14])
3705+ >>> neginf = np.array([11, 33333333, 12, 13, 14])
3706+ >>> np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
3707+ array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
3708+ 1.2800000e+02])
3709+
36973710 >>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
36983711 >>> np.nan_to_num(y)
36993712 array([1.79769313e+308 +0.00000000e+000j, # may vary
@@ -3706,33 +3719,27 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37063719
37073720 dpnp .check_supported_arrays_type (x )
37083721
3709- # Python boolean is a subtype of an integer
3710- # so additional check for bool is not needed.
3711- if not isinstance (nan , (int , float )):
3712- raise TypeError (
3713- "nan must be a scalar of an integer, float, bool, "
3714- f"but got { type (nan )} "
3715- )
3716- x_type = x .dtype .type
3722+ def _check_nan_inf (val , name ):
3723+ # Python boolean is a subtype of an integer
3724+ if not isinstance (val , (int , float )):
3725+ val = dpnp .asarray (
3726+ val , sycl_queue = x .sycl_queue , usm_type = x .usm_type
3727+ )
3728+ if dpnp .issubdtype (val .dtype , dpnp .complexfloating ):
3729+ raise TypeError (f"{ name } must not be of a complex type" )
3730+ return val
37173731
3718- if not issubclass (x_type , dpnp .inexact ):
3732+ x_type = x .dtype .type
3733+ if not dpnp .issubdtype (x_type , dpnp .inexact ):
37193734 return dpnp .copy (x ) if copy else dpnp .get_result_array (x )
37203735
37213736 max_f , min_f = _get_max_min (x .real .dtype )
3737+
3738+ nan = _check_nan_inf (nan , "nan" )
37223739 if posinf is not None :
3723- if not isinstance (posinf , (int , float )):
3724- raise TypeError (
3725- "posinf must be a scalar of an integer, float, bool, "
3726- f"or be None, but got { type (posinf )} "
3727- )
3728- max_f = posinf
3740+ max_f = _check_nan_inf (posinf , "posinf" )
37293741 if neginf is not None :
3730- if not isinstance (neginf , (int , float )):
3731- raise TypeError (
3732- "neginf must be a scalar of an integer, float, bool, "
3733- f"or be None, but got { type (neginf )} "
3734- )
3735- min_f = neginf
3742+ min_f = _check_nan_inf (neginf , "neginf" )
37363743
37373744 if copy :
37383745 out = dpnp .empty_like (x )
@@ -3741,19 +3748,45 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37413748 raise ValueError ("copy is required for read-only array `x`" )
37423749 out = x
37433750
3744- x_ary = dpnp .get_usm_ndarray (x )
3745- out_ary = dpnp .get_usm_ndarray (out )
3746-
3747- q = x .sycl_queue
3748- _manager = dpu .SequentialOrderManager [q ]
3751+ # handle a special case when nan and infs are all scalars
3752+ if all (dpnp .isscalar (el ) for el in (nan , max_f , min_f )):
3753+ x_ary = dpnp .get_usm_ndarray (x )
3754+ out_ary = dpnp .get_usm_ndarray (out )
3755+
3756+ q = x .sycl_queue
3757+ _manager = dpu .SequentialOrderManager [q ]
3758+
3759+ h_ev , comp_ev = ufi ._nan_to_num (
3760+ x_ary ,
3761+ nan ,
3762+ max_f ,
3763+ min_f ,
3764+ out_ary ,
3765+ q ,
3766+ depends = _manager .submitted_events ,
3767+ )
37493768
3750- h_ev , comp_ev = ufi ._nan_to_num (
3751- x_ary , nan , max_f , min_f , out_ary , q , depends = _manager .submitted_events
3752- )
3769+ _manager .add_event_pair (h_ev , comp_ev )
37533770
3754- _manager . add_event_pair ( h_ev , comp_ev )
3771+ return dpnp . get_result_array ( out )
37553772
3756- return dpnp .get_result_array (out )
3773+ # handle a common case with broadcasting of input nan and infs
3774+ if dpnp .issubdtype (x_type , dpnp .complexfloating ):
3775+ parts = (x .real , x .imag )
3776+ parts_out = (out .real , out .imag )
3777+ else :
3778+ parts = (x ,)
3779+ parts_out = (out ,)
3780+
3781+ for part , part_out in zip (parts , parts_out ):
3782+ nan_mask = dpnp .isnan (part )
3783+ posinf_mask = dpnp .isposinf (part )
3784+ neginf_mask = dpnp .isneginf (part )
3785+
3786+ part = dpnp .where (nan_mask , nan , part , out = part_out )
3787+ part = dpnp .where (posinf_mask , max_f , part , out = part_out )
3788+ part = dpnp .where (neginf_mask , min_f , part , out = part_out )
3789+ return out
37573790
37583791
37593792_NEGATIVE_DOCSTRING = """
0 commit comments