Skip to content

Commit ac27407

Browse files
committed
ENH: cupy: make isin accept int scalars
1 parent b52cb84 commit ac27407

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ def searchsorted(
167167
return cp.searchsorted(x1, x2, side, sorter)
168168

169169

170+
# CuPy isin does not accept scalars
171+
def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False, **kwds) -> Array:
172+
if isinstance(x1, int):
173+
x1 = cp.asarray(x1)
174+
if isinstance(x2, int):
175+
x2 = cp.asarray(x2)
176+
return cp.isin(x1, x2)
177+
178+
170179
# These functions are completely new here. If the library already has them
171180
# (i.e., numpy 2.0), use the library version instead of our wrapper.
172181
if hasattr(cp, 'vecdot'):
@@ -191,7 +200,7 @@ def searchsorted(
191200
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
192201
'ceil', 'floor', 'trunc', 'take_along_axis',
193202
'broadcast_arrays', 'meshgrid',
194-
'searchsorted',
203+
'searchsorted', 'isin',
195204
]
196205

197206

0 commit comments

Comments
 (0)