Skip to content

Commit 7e6f10b

Browse files
committed
ENH Add union1d
1 parent ebe9a5b commit 7e6f10b

4 files changed

Lines changed: 78 additions & 0 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
pad,
1313
partition,
1414
sinc,
15+
union1d,
1516
)
1617
from ._lib._at import at
1718
from ._lib._funcs import (
@@ -50,4 +51,5 @@
5051
"partition",
5152
"setdiff1d",
5253
"sinc",
54+
"union1d",
5355
]

src/array_api_extra/_delegation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,3 +895,37 @@ def isin(
895895
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
896896

897897
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
898+
899+
900+
def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
901+
"""
902+
Find the union of two arrays.
903+
904+
Return the unique, sorted array of values that are in either of the two
905+
input arrays.
906+
907+
Parameters
908+
----------
909+
a, b : Array
910+
Input arrays. They are flattened internally if they are not already 1D.
911+
912+
xp : array_namespace, optional
913+
The standard-compatible namespace for `a` and `b`. Default: infer.
914+
915+
Returns
916+
-------
917+
Array
918+
Unique, sorted union of the input arrays.
919+
"""
920+
if xp is None:
921+
xp = array_namespace(a, b)
922+
923+
if (
924+
is_numpy_namespace(xp)
925+
or is_cupy_namespace(xp)
926+
or is_dask_namespace(xp)
927+
or is_jax_namespace(xp)
928+
):
929+
return xp.union1d(a, b)
930+
931+
return _funcs.union1d(a, b, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,3 +820,11 @@ def isin( # numpydoc ignore=PR01,RT01
820820
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
821821
original_a_shape,
822822
)
823+
824+
825+
def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
826+
# numpydoc ignore=PR01,RT01
827+
"""See docstring in `array_api_extra._delegation.py`."""
828+
a = xp.reshape(a, (-1,))
829+
b = xp.reshape(b, (-1,))
830+
return xp.unique_values(xp.concat([a, b]))

tests/test_funcs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
partition,
3232
setdiff1d,
3333
sinc,
34+
union1d,
3435
)
3536
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3637
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
@@ -1529,3 +1530,36 @@ def test_kind(self, xp: ModuleType, library: Backend):
15291530
expected = xp.asarray([False, True, False, True])
15301531
res = isin(a, b, kind="sort")
15311532
xp_assert_equal(res, expected)
1533+
1534+
1535+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="unique_values returns arrays")
1536+
@pytest.mark.skip_xp_backend(
1537+
Backend.ARRAY_API_STRICTEST,
1538+
reason="data_dependent_shapes flag for unique_values is disabled",
1539+
)
1540+
class TestUnion1d:
1541+
def test_simple(self, xp: ModuleType):
1542+
a = xp.asarray([-1, 1, 0])
1543+
b = xp.asarray([2, -2, 0])
1544+
expected = xp.asarray([-2, -1, 0, 1, 2])
1545+
res = union1d(a, b)
1546+
xp_assert_equal(res, expected)
1547+
1548+
def test_2d(self, xp: ModuleType):
1549+
a = xp.asarray([[-1, 1, 0], [1, 2, 0]])
1550+
b = xp.asarray([[1, 0, 1], [-2, -1, 0]])
1551+
expected = xp.asarray([-2, -1, 0, 1, 2])
1552+
res = union1d(a, b)
1553+
xp_assert_equal(res, expected)
1554+
1555+
def test_3d(self, xp: ModuleType):
1556+
a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]])
1557+
b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]])
1558+
expected = xp.asarray([-2, -1, 0, 1, 2])
1559+
res = union1d(a, b)
1560+
xp_assert_equal(res, expected)
1561+
1562+
def test_device(self, xp: ModuleType, device: Device):
1563+
a = xp.asarray([-1, 1, 0])
1564+
b = xp.asarray([2, -2, 0])
1565+
assert get_device(union1d(a, b)) == device

0 commit comments

Comments
 (0)