@@ -647,9 +647,9 @@ def searchsorted(
647647 Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
648648 were inserted before the indices, the resulting array would remain sorted.
649649
650- The behavior of this function is similar to that of the homonymous function in the
651- array API standard, but it relaxes the requirement that `x1` must be
652- one-dimensional. The function is vectorized, treating slices along the last axis
650+ The behavior of this function is similar to that of `array_api.searchsorted`,
651+ but it relaxes the requirement that `x1` must be one-dimensional.
652+ This function is vectorized, treating slices along the last axis
653653 as elements and preceding axes as batch (or "loop") dimensions.
654654
655655 Parameters
@@ -701,11 +701,11 @@ def searchsorted(
701701 raise ValueError (message )
702702
703703 xp_default_int = _funcs .default_dtype (xp , kind = "integral" )
704- y_0d = xp . asarray ( x2 ) .ndim == 0
705- x_1d = x1 .ndim <= 1
704+ x2_0d = x2 .ndim == 0
705+ x1_1d = x1 .ndim <= 1
706706
707- if x_1d or is_torch_namespace (xp ):
708- x2 = xp .reshape (x2 , ()) if (y_0d and x_1d ) else x2
707+ if x1_1d or is_torch_namespace (xp ):
708+ x2 = xp .reshape (x2 , ()) if (x2_0d and x1_1d ) else x2
709709 out = xp .searchsorted (x1 , x2 , side = side )
710710 return xp .astype (out , xp_default_int , copy = False )
711711
0 commit comments