11import math
22import warnings
3+ from collections .abc import Callable
34from types import ModuleType
45from typing import Any , Literal , cast
56
2930 one_hot ,
3031 pad ,
3132 partition ,
32- searchsorted ,
3333 setdiff1d ,
3434 sinc ,
3535 union1d ,
3636)
37+ from array_api_extra import (
38+ searchsorted as xpx_searchsorted ,
39+ )
3740from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
41+ from array_api_extra ._lib ._funcs import searchsorted as _funcs_searchsorted
3842from array_api_extra ._lib ._testing import xfail , xp_assert_close , xp_assert_equal
3943from array_api_extra ._lib ._utils ._compat import (
4044 array_namespace ,
5862lazy_xp_function (pad )
5963# FIXME calls in1d which calls xp.unique_values without size
6064lazy_xp_function (setdiff1d , jax_jit = False )
61- lazy_xp_function (searchsorted )
65+ lazy_xp_function (xpx_searchsorted )
66+ lazy_xp_function (_funcs_searchsorted )
6267lazy_xp_function (sinc )
6368
6469
@@ -1772,7 +1777,7 @@ class TestSearchsorted:
17721777 def test_input_validation (self , xp : ModuleType ):
17731778 message = "`side` must be either 'left' or 'right'."
17741779 with pytest .raises (ValueError , match = message ):
1775- _ = searchsorted (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]), side = "center" ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
1780+ _ = xpx_searchsorted (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]), side = "center" ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
17761781
17771782 @pytest .mark .parametrize ("side" , ["left" , "right" ])
17781783 @pytest .mark .parametrize ("ties" , [False , True ])
@@ -1781,6 +1786,7 @@ def test_input_validation(self, xp: ModuleType):
17811786 )
17821787 @pytest .mark .parametrize ("nans_x" , [False , True ])
17831788 @pytest .mark .parametrize ("infs_x" , [False , True ])
1789+ @pytest .mark .parametrize ("searchsorted" , [xpx_searchsorted , _funcs_searchsorted ])
17841790 def test_nd (
17851791 self ,
17861792 side : Literal ["left" , "right" ],
@@ -1789,9 +1795,16 @@ def test_nd(
17891795 nans_x : bool ,
17901796 infs_x : bool ,
17911797 xp : ModuleType ,
1798+ searchsorted : Callable [..., Array ],
17921799 ):
1793- if nans_x and is_torch_namespace (xp ):
1800+ if nans_x and is_torch_namespace (xp ) and searchsorted == xpx_searchsorted :
17941801 pytest .skip ("torch sorts NaNs differently" )
1802+ if isinstance (shape , tuple ) and searchsorted == _funcs_searchsorted :
1803+ message = (
1804+ "Redundant; `xpx_searchsorted` delegates to "
1805+ "`_funcs_searchsorted` for multidimensional input."
1806+ )
1807+ pytest .skip (message )
17951808 rng = np .random .default_rng (945298725498274853 )
17961809 x = rng .integers (5 , size = shape ) if ties else rng .random (shape )
17971810 # float32 is to accommodate JAX - nextafter with `float64` is too small?
0 commit comments