Skip to content

Commit 11f4d3f

Browse files
authored
Merge pull request data-apis#374 from ev-br/searchsorted_scalars_cupy
ENH: cupy: add a workaround for cp.searchorted 2nd argument
2 parents 8b22efb + 661b531 commit 11f4d3f

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,24 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
149149
return tuple(cp.meshgrid(*arrays, indexing=indexing))
150150

151151

152+
# Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum
153+
# supported version
154+
def searchsorted(
155+
x1: Array,
156+
x2: Array | int | float,
157+
/,
158+
*,
159+
side: Literal['left', 'right'] = 'left',
160+
sorter: Array | None = None
161+
) -> Array:
162+
if not isinstance(x2, cp.ndarray):
163+
if not isinstance(x2, int | float | complex):
164+
raise NotImplementedError(
165+
'Only python scalars or ndarrays are supported for x2')
166+
x2 = cp.asarray(x2)
167+
return cp.searchsorted(x1, x2, side, sorter)
168+
169+
152170
# These functions are completely new here. If the library already has them
153171
# (i.e., numpy 2.0), use the library version instead of our wrapper.
154172
if hasattr(cp, 'vecdot'):
@@ -172,7 +190,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
172190
'bitwise_invert', 'bitwise_right_shift',
173191
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
174192
'ceil', 'floor', 'trunc', 'take_along_axis',
175-
'broadcast_arrays', 'meshgrid']
193+
'broadcast_arrays', 'meshgrid',
194+
'searchsorted',
195+
]
176196

177197

178198
def __dir__() -> list[str]:

0 commit comments

Comments
 (0)