1212 array_namespace ,
1313 is_dask_namespace ,
1414 is_jax_array ,
15- is_torch_namespace ,
1615)
1716from ._utils ._helpers import (
1817 asarrays ,
@@ -677,65 +676,10 @@ def searchsorted(
677676 / ,
678677 * ,
679678 side : Literal ["left" , "right" ] = "left" ,
680- xp : ModuleType | None = None ,
679+ xp : ModuleType ,
681680) -> Array :
682- """
683- Find indices where elements should be inserted to maintain order.
684-
685- Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
686- were inserted before the indices, the resulting array would remain sorted.
687-
688- Parameters
689- ----------
690- x1 : Array
691- Input array. Should have a real-valued data type. Must be sorted in ascending
692- order along the last axis.
693- x2 : Array
694- Array containing search values. Should have a real-valued data type. Must have
695- the same shape as ``x1`` except along the last axis.
696- side : {'left', 'right'}, optional
697- Argument controlling which index is returned if an element of ``x2`` is equal to
698- one or more elements of ``x1``: ``'left'`` returns the index of the first of
699- these elements; ``'right'`` returns the next index after the last of these
700- elements. Default: ``'left'``.
701- xp : array_namespace, optional
702- The standard-compatible namespace for the array arguments. Default: infer.
703-
704- Returns
705- -------
706- Array: integer array
707- An array of indices with the same shape as ``x2``.
708-
709- Examples
710- --------
711- >>> import array_api_strict as xp
712- >>> import array_api_extra as xpx
713- >>> x = xp.asarray([11, 12, 13, 13, 14, 15])
714- >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
715- Array([0, 1, 5, 6], dtype=array_api_strict.int64)
716- >>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
717- Array(2, dtype=array_api_strict.int64)
718- >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
719- Array(4, dtype=array_api_strict.int64)
720-
721- `searchsorted` is vectorized along the last axis.
722-
723- >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
724- >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
725- >>> xpx.searchsorted(x1, x2, xp=xp)
726- Array([[1, 3],
727- [2, 4]], dtype=array_api_strict.int64)
728- """
729- xp = array_namespace (x1 , x2 ) if xp is None else xp
730- xp_default_int = xp .asarray (1 ).dtype
731- y_0d = xp .asarray (x2 ).ndim == 0
732- x_1d = x1 .ndim <= 1
733-
734- if x_1d or is_torch_namespace (xp ):
735- x2 = xp .reshape (x2 , ()) if (y_0d and x_1d ) else x2
736- out = xp .searchsorted (x1 , x2 , side = side )
737- return xp .astype (out , xp_default_int , copy = False )
738-
681+ # numpydoc ignore=PR01,RT01
682+ """See docstring in `array_api_extra._delegation.py`."""
739683 a = xp .full (x2 .shape , 0 , device = _compat .device (x1 ))
740684
741685 if x1 .shape [- 1 ] == 0 :
@@ -757,7 +701,7 @@ def searchsorted(
757701
758702 out = xp .where (compare (x2 , xp .min (x1 , axis = - 1 , keepdims = True )), 0 , b )
759703 out = xp .where (xp .isnan (x2 ), x1 .shape [- 1 ], out ) if side == "right" else out
760- return xp .astype (out , xp_default_int , copy = False )
704+ return xp .astype (out , default_dtype ( xp , kind = "integral" ) , copy = False )
761705
762706
763707def setdiff1d (
0 commit comments