Skip to content

Commit d6e5db6

Browse files
committed
ENH: cupy: make isin accept int scalars
1 parent b52cb84 commit d6e5db6

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def searchsorted(
156156
x2: Array | int | float,
157157
/,
158158
*,
159-
side: Literal['left', 'right'] = 'left',
159+
side: Literal['left', 'right'] = 'lef
160160
sorter: Array | None = None
161161
) -> Array:
162162
if not isinstance(x2, cp.ndarray):
@@ -167,6 +167,15 @@ def searchsorted(
167167
return cp.searchsorted(x1, x2, side, sorter)
168168

169169

170+
# CuPy isin does not accept scalars
171+
def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False, **kwds) -> Array:
172+
if isinstance(x1, int):
173+
x1 = cp.asarray(x1)
174+
if isinstance(x2, int):
175+
x2 = cp.asarray(x2)
176+
return xp.isin(x1, x2)
177+
178+
170179
# These functions are completely new here. If the library already has them
171180
# (i.e., numpy 2.0), use the library version instead of our wrapper.
172181
if hasattr(cp, 'vecdot'):
@@ -191,7 +200,7 @@ def searchsorted(
191200
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
192201
'ceil', 'floor', 'trunc', 'take_along_axis',
193202
'broadcast_arrays', 'meshgrid',
194-
'searchsorted',
203+
'searchsorted', 'isin',
195204
]
196205

197206

0 commit comments

Comments
 (0)