Skip to content

Commit 51ade21

Browse files
committed
initial work with a small test; fails for sparse backend as it has not argsort;
1 parent ca20f03 commit 51ade21

4 files changed

Lines changed: 154 additions & 1 deletion

File tree

src/array_api_extra/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, nan_to_num, one_hot, pad
3+
from ._delegation import (
4+
argpartition,
5+
isclose,
6+
nan_to_num,
7+
one_hot,
8+
pad,
9+
partition,
10+
)
411
from ._lib._at import at
512
from ._lib._funcs import (
613
apply_where,
@@ -23,6 +30,7 @@
2330
__all__ = [
2431
"__version__",
2532
"apply_where",
33+
"argpartition",
2634
"at",
2735
"atleast_nd",
2836
"broadcast_shapes",
@@ -37,6 +45,7 @@
3745
"nunique",
3846
"one_hot",
3947
"pad",
48+
"partition",
4049
"setdiff1d",
4150
"sinc",
4251
]

src/array_api_extra/_delegation.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,110 @@ def pad(
326326
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
327327

328328
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
329+
330+
331+
def partition(
332+
a: Array,
333+
kth: int,
334+
*,
335+
xp: ModuleType | None = None,
336+
) -> Array:
337+
"""
338+
Return a partitioned copy of an array.
339+
340+
Parameters
341+
----------
342+
a : 1-dimensional array
343+
Input array.
344+
kth : int
345+
Element index to partition by.
346+
xp : array_namespace, optional
347+
The standard-compatible namespace for `x`. Default: infer.
348+
349+
Returns
350+
-------
351+
partitioned_array
352+
Array of the same type and shape as a.
353+
"""
354+
# Validate inputs.
355+
if xp is None:
356+
xp = array_namespace(a)
357+
if a.ndim != 1:
358+
msg = "only 1-dimensional arrays are currently supported"
359+
raise NotImplementedError(msg)
360+
361+
# Delegate where possible.
362+
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
363+
return xp.partition(a, kth)
364+
if is_jax_namespace(xp):
365+
from jax import numpy
366+
367+
return numpy.partition(a, kth)
368+
369+
# Use top-k when possible:
370+
if is_torch_namespace(xp):
371+
from torch import topk
372+
373+
a_left, indices_left = topk(a, kth, largest=False, sorted=False)
374+
mask_right = xp.ones(a.shape, dtype=bool)
375+
mask_right[indices_left] = False
376+
return xp.concat((a_left, a[mask_right]))
377+
# Note: dask topk/argtopk sort the return values, so it's
378+
# not much more efficient than sorting everything when
379+
# kth is not small compared to x.size
380+
381+
return _funcs.partition(a, kth, xp=xp)
382+
383+
384+
def argpartition(
385+
a: Array,
386+
kth: int,
387+
*,
388+
xp: ModuleType | None = None,
389+
) -> Array:
390+
"""
391+
Perform an indirect partition along the given axis.
392+
393+
Parameters
394+
----------
395+
a : 1-dimensional array
396+
Input array.
397+
kth : int
398+
Element index to partition by.
399+
xp : array_namespace, optional
400+
The standard-compatible namespace for `x`. Default: infer.
401+
402+
Returns
403+
-------
404+
index_array
405+
Array of indices that partition `a` along the specified axis.
406+
"""
407+
# Validate inputs.
408+
if xp is None:
409+
xp = array_namespace(a)
410+
if a.ndim != 1:
411+
msg = "only 1-dimensional arrays are currently supported"
412+
raise NotImplementedError(msg)
413+
414+
# Delegate where possible.
415+
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
416+
return xp.argpartition(a, kth)
417+
if is_jax_namespace(xp):
418+
from jax import numpy
419+
420+
return numpy.argpartition(a, kth)
421+
422+
# Use top-k when possible:
423+
if is_torch_namespace(xp):
424+
from torch import topk
425+
426+
_, indices = topk(a, kth, largest=False, sorted=False)
427+
mask = xp.ones(a.shape, dtype=bool)
428+
mask[indices] = False
429+
indices_above = xp.arange(a.shape[0])[mask]
430+
return xp.concat((indices, indices_above))
431+
# Note: dask topk/argtopk sort the return values, so it's
432+
# not much more efficient than sorting everything when
433+
# kth is not small compared to x.size
434+
435+
return _funcs.argpartition(a, kth, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,3 +1029,23 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
10291029
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
10301030
)
10311031
return xp.sin(y) / y
1032+
1033+
1034+
def partition( # numpydoc ignore=PR01,RT01
1035+
x: Array,
1036+
kth: int, # noqa: ARG001
1037+
*,
1038+
xp: ModuleType,
1039+
) -> Array:
1040+
"""See docstring in `array_api_extra._delegation.py`."""
1041+
return xp.sort(x, stable=False)
1042+
1043+
1044+
def argpartition( # numpydoc ignore=PR01,RT01
1045+
x: Array,
1046+
kth: int, # noqa: ARG001
1047+
*,
1048+
xp: ModuleType,
1049+
) -> Array:
1050+
"""See docstring in `array_api_extra._delegation.py`."""
1051+
return xp.argsort(x, stable=False)

tests/test_funcs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from array_api_extra import (
1414
apply_where,
15+
argpartition,
1516
at,
1617
atleast_nd,
1718
broadcast_shapes,
@@ -25,6 +26,7 @@
2526
nunique,
2627
one_hot,
2728
pad,
29+
partition,
2830
setdiff1d,
2931
sinc,
3032
)
@@ -1298,3 +1300,18 @@ def test_device(self, xp: ModuleType, device: Device):
12981300

12991301
def test_xp(self, xp: ModuleType):
13001302
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
1303+
1304+
1305+
class TestPartition:
1306+
def test_basic(self, xp: ModuleType):
1307+
# Using 0-dimensional array
1308+
rng = np.random.default_rng(2847)
1309+
1310+
for _ in range(100):
1311+
n = rng.integers(1, 1000)
1312+
x = xp.asarray(rng.random(size=n))
1313+
k = int(rng.integers(1, n - 1))
1314+
y = partition(x, k)
1315+
assert xp.max(y[:k]) <= xp.min(y[k:])
1316+
y = x[argpartition(x, k)]
1317+
assert xp.max(y[:k]) <= xp.min(y[k:])

0 commit comments

Comments
 (0)