Skip to content

Commit 1241dad

Browse files
committed
ENH: allow python scalars in the 2nd argument of searchsorted
The upcoming Array API revision 2025.12 will allow the second argument of searchsorted to be a python scalar, cf data-apis/array-api#982
1 parent 9858e36 commit 1241dad

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

cupy/_sorting/search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,10 @@ def _searchsorted(a, v, side, sorter, assume_increasing):
430430
raise NotImplementedError('Only int or ndarray are supported for a')
431431

432432
if not isinstance(v, cupy.ndarray):
433-
raise NotImplementedError('Only int or ndarray are supported for v')
433+
if not isinstance(v, int | float | complex):
434+
raise NotImplementedError(
435+
'Only python scalars or ndarrays are supported for v')
436+
v = cupy.asarray(v, dtype=a.dtype)
434437

435438
if a.ndim > 1:
436439
raise ValueError('object too deep for desired array')

tests/cupy_tests/sorting_tests/test_search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,8 @@ def test_searchsorted(self, xp, dtype):
699699
x = testing.shaped_arange(self.shape, xp, dtype)
700700
bins = xp.array(self.bins)
701701
y = xp.searchsorted(bins, x, side=self.side)
702-
return y,
702+
y1 = xp.searchsorted(bins, 2, side=self.side) # python scalar for `v`
703+
return y, y1
703704

704705
@testing.for_all_dtypes(no_bool=True)
705706
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)