Skip to content

Commit 55d7169

Browse files
committed
MAINT: searchsorted: move delegation to _delegation.py
1 parent 699fc51 commit 55d7169

3 files changed

Lines changed: 76 additions & 61 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
one_hot,
1313
pad,
1414
partition,
15+
searchsorted,
1516
setdiff1d,
1617
sinc,
1718
union1d,
@@ -23,7 +24,6 @@
2324
default_dtype,
2425
kron,
2526
nunique,
26-
searchsorted,
2727
)
2828
from ._lib._lazy import lazy_apply
2929

src/array_api_extra/_delegation.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"nan_to_num",
2828
"one_hot",
2929
"pad",
30+
"searchsorted",
3031
"sinc",
3132
]
3233

@@ -632,6 +633,76 @@ def pad(
632633
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
633634

634635

636+
def searchsorted(
637+
x1: Array,
638+
x2: Array,
639+
/,
640+
*,
641+
side: Literal["left", "right"] = "left",
642+
xp: ModuleType | None = None,
643+
) -> Array:
644+
"""
645+
Find indices where elements should be inserted to maintain order.
646+
647+
Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
648+
were inserted before the indices, the resulting array would remain sorted.
649+
650+
Parameters
651+
----------
652+
x1 : Array
653+
Input array. Should have a real-valued data type. Must be sorted in ascending
654+
order along the last axis.
655+
x2 : Array
656+
Array containing search values. Should have a real-valued data type. Must have
657+
the same shape as ``x1`` except along the last axis.
658+
side : {'left', 'right'}, optional
659+
Argument controlling which index is returned if an element of ``x2`` is equal to
660+
one or more elements of ``x1``: ``'left'`` returns the index of the first of
661+
these elements; ``'right'`` returns the next index after the last of these
662+
elements. Default: ``'left'``.
663+
xp : array_namespace, optional
664+
The standard-compatible namespace for the array arguments. Default: infer.
665+
666+
Returns
667+
-------
668+
Array: integer array
669+
An array of indices with the same shape as ``x2``.
670+
671+
Examples
672+
--------
673+
>>> import array_api_strict as xp
674+
>>> import array_api_extra as xpx
675+
>>> x = xp.asarray([11, 12, 13, 13, 14, 15])
676+
>>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
677+
Array([0, 1, 5, 6], dtype=array_api_strict.int64)
678+
>>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
679+
Array(2, dtype=array_api_strict.int64)
680+
>>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
681+
Array(4, dtype=array_api_strict.int64)
682+
683+
`searchsorted` is vectorized along the last axis.
684+
685+
>>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
686+
>>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
687+
>>> xpx.searchsorted(x1, x2, xp=xp)
688+
Array([[1, 3],
689+
[2, 4]], dtype=array_api_strict.int64)
690+
"""
691+
if xp is None:
692+
xp = array_namespace(x1, x2)
693+
694+
xp_default_int = _funcs.default_dtype(xp, kind="integral")
695+
y_0d = xp.asarray(x2).ndim == 0
696+
x_1d = x1.ndim <= 1
697+
698+
if x_1d or is_torch_namespace(xp):
699+
x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
700+
out = xp.searchsorted(x1, x2, side=side)
701+
return xp.astype(out, xp_default_int, copy=False)
702+
703+
return _funcs.searchsorted(x1, x2, side=side, xp=xp)
704+
705+
635706
def setdiff1d(
636707
x1: Array | complex,
637708
x2: Array | complex,

src/array_api_extra/_lib/_funcs.py

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
array_namespace,
1313
is_dask_namespace,
1414
is_jax_array,
15-
is_torch_namespace,
1615
)
1716
from ._utils._helpers import (
1817
asarrays,
@@ -677,65 +676,10 @@ def searchsorted(
677676
/,
678677
*,
679678
side: Literal["left", "right"] = "left",
680-
xp: ModuleType | None = None,
679+
xp: ModuleType,
681680
) -> Array:
682-
"""
683-
Find indices where elements should be inserted to maintain order.
684-
685-
Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
686-
were inserted before the indices, the resulting array would remain sorted.
687-
688-
Parameters
689-
----------
690-
x1 : Array
691-
Input array. Should have a real-valued data type. Must be sorted in ascending
692-
order along the last axis.
693-
x2 : Array
694-
Array containing search values. Should have a real-valued data type. Must have
695-
the same shape as ``x1`` except along the last axis.
696-
side : {'left', 'right'}, optional
697-
Argument controlling which index is returned if an element of ``x2`` is equal to
698-
one or more elements of ``x1``: ``'left'`` returns the index of the first of
699-
these elements; ``'right'`` returns the next index after the last of these
700-
elements. Default: ``'left'``.
701-
xp : array_namespace, optional
702-
The standard-compatible namespace for the array arguments. Default: infer.
703-
704-
Returns
705-
-------
706-
Array: integer array
707-
An array of indices with the same shape as ``x2``.
708-
709-
Examples
710-
--------
711-
>>> import array_api_strict as xp
712-
>>> import array_api_extra as xpx
713-
>>> x = xp.asarray([11, 12, 13, 13, 14, 15])
714-
>>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
715-
Array([0, 1, 5, 6], dtype=array_api_strict.int64)
716-
>>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
717-
Array(2, dtype=array_api_strict.int64)
718-
>>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
719-
Array(4, dtype=array_api_strict.int64)
720-
721-
`searchsorted` is vectorized along the last axis.
722-
723-
>>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
724-
>>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
725-
>>> xpx.searchsorted(x1, x2, xp=xp)
726-
Array([[1, 3],
727-
[2, 4]], dtype=array_api_strict.int64)
728-
"""
729-
xp = array_namespace(x1, x2) if xp is None else xp
730-
xp_default_int = xp.asarray(1).dtype
731-
y_0d = xp.asarray(x2).ndim == 0
732-
x_1d = x1.ndim <= 1
733-
734-
if x_1d or is_torch_namespace(xp):
735-
x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
736-
out = xp.searchsorted(x1, x2, side=side)
737-
return xp.astype(out, xp_default_int, copy=False)
738-
681+
# numpydoc ignore=PR01,RT01
682+
"""See docstring in `array_api_extra._delegation.py`."""
739683
a = xp.full(x2.shape, 0, device=_compat.device(x1))
740684

741685
if x1.shape[-1] == 0:
@@ -757,7 +701,7 @@ def searchsorted(
757701

758702
out = xp.where(compare(x2, xp.min(x1, axis=-1, keepdims=True)), 0, b)
759703
out = xp.where(xp.isnan(x2), x1.shape[-1], out) if side == "right" else out
760-
return xp.astype(out, xp_default_int, copy=False)
704+
return xp.astype(out, default_dtype(xp, kind="integral"), copy=False)
761705

762706

763707
def setdiff1d(

0 commit comments

Comments
 (0)