Skip to content

Commit 3870f10

Browse files
committed
TST: searchsorted: parameterize over xpx_searchsorted and _funcs_searchsorted
1 parent 48199cf commit 3870f10

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

tests/test_funcs.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import warnings
3+
from collections.abc import Callable
34
from types import ModuleType
45
from typing import Any, Literal, cast
56

@@ -29,12 +30,15 @@
2930
one_hot,
3031
pad,
3132
partition,
32-
searchsorted,
3333
setdiff1d,
3434
sinc,
3535
union1d,
3636
)
37+
from array_api_extra import (
38+
searchsorted as xpx_searchsorted,
39+
)
3740
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
41+
from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted
3842
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
3943
from array_api_extra._lib._utils._compat import (
4044
array_namespace,
@@ -58,7 +62,8 @@
5862
lazy_xp_function(pad)
5963
# FIXME calls in1d which calls xp.unique_values without size
6064
lazy_xp_function(setdiff1d, jax_jit=False)
61-
lazy_xp_function(searchsorted)
65+
lazy_xp_function(xpx_searchsorted)
66+
lazy_xp_function(_funcs_searchsorted)
6267
lazy_xp_function(sinc)
6368

6469

@@ -1772,7 +1777,7 @@ class TestSearchsorted:
17721777
def test_input_validation(self, xp: ModuleType):
17731778
message = "`side` must be either 'left' or 'right'."
17741779
with pytest.raises(ValueError, match=message):
1775-
_ = searchsorted(xp.asarray([1, 2]), xp.asarray([1, 2]), side="center") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
1780+
_ = xpx_searchsorted(xp.asarray([1, 2]), xp.asarray([1, 2]), side="center") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
17761781

17771782
@pytest.mark.parametrize("side", ["left", "right"])
17781783
@pytest.mark.parametrize("ties", [False, True])
@@ -1781,6 +1786,7 @@ def test_input_validation(self, xp: ModuleType):
17811786
)
17821787
@pytest.mark.parametrize("nans_x", [False, True])
17831788
@pytest.mark.parametrize("infs_x", [False, True])
1789+
@pytest.mark.parametrize("searchsorted", [xpx_searchsorted, _funcs_searchsorted])
17841790
def test_nd(
17851791
self,
17861792
side: Literal["left", "right"],
@@ -1789,9 +1795,16 @@ def test_nd(
17891795
nans_x: bool,
17901796
infs_x: bool,
17911797
xp: ModuleType,
1798+
searchsorted: Callable[..., Array],
17921799
):
1793-
if nans_x and is_torch_namespace(xp):
1800+
if nans_x and is_torch_namespace(xp) and searchsorted == xpx_searchsorted:
17941801
pytest.skip("torch sorts NaNs differently")
1802+
if isinstance(shape, tuple) and searchsorted == _funcs_searchsorted:
1803+
message = (
1804+
"Redundant; `xpx_searchsorted` delegates to "
1805+
"`_funcs_searchsorted` for multidimensional input."
1806+
)
1807+
pytest.skip(message)
17951808
rng = np.random.default_rng(945298725498274853)
17961809
x = rng.integers(5, size=shape) if ties else rng.random(shape)
17971810
# float32 is to accommodate JAX - nextafter with `float64` is too small?

0 commit comments

Comments
 (0)