Skip to content

Commit 7ef1220

Browse files
mdhaberlucascolley
andauthored
Apply suggestions from code review
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 8947b8c commit 7ef1220

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -647,9 +647,9 @@ def searchsorted(
647647
Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
648648
were inserted before the indices, the resulting array would remain sorted.
649649
650-
The behavior of this function is similar to that of the homonymous function in the
651-
array API standard, but it relaxes the requirement that `x1` must be
652-
one-dimensional. The function is vectorized, treating slices along the last axis
650+
The behavior of this function is similar to that of `array_api.searchsorted`,
651+
but it relaxes the requirement that `x1` must be one-dimensional.
652+
This function is vectorized, treating slices along the last axis
653653
as elements and preceding axes as batch (or "loop") dimensions.
654654
655655
Parameters
@@ -701,11 +701,11 @@ def searchsorted(
701701
raise ValueError(message)
702702

703703
xp_default_int = _funcs.default_dtype(xp, kind="integral")
704-
y_0d = xp.asarray(x2).ndim == 0
705-
x_1d = x1.ndim <= 1
704+
x2_0d = x2.ndim == 0
705+
x1_1d = x1.ndim <= 1
706706

707-
if x_1d or is_torch_namespace(xp):
708-
x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
707+
if x1_1d or is_torch_namespace(xp):
708+
x2 = xp.reshape(x2, ()) if (x2_0d and x1_1d) else x2
709709
out = xp.searchsorted(x1, x2, side=side)
710710
return xp.astype(out, xp_default_int, copy=False)
711711

tests/test_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,7 @@ def xp_searchsorted(
17891789
side: Literal["left", "right"],
17901790
xp: ModuleType,
17911791
) -> Array:
1792-
return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side)
1792+
return xp.searchsorted(a, v, side=side)
17931793

17941794

17951795
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis")
@@ -1845,7 +1845,7 @@ def test_nd(
18451845
x[mask] = np.inf
18461846
x = np.sort(x, axis=-1) # type:ignore[assignment]
18471847
x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
1848-
xp_default_int = xp.asarray(1).dtype
1848+
xp_default_int = default_dtype(xp, kind="integral")
18491849
if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0:
18501850
ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int)
18511851
else:

0 commit comments

Comments
 (0)