Skip to content

Commit e5c5ac2

Browse files
committed
ENH: Add support for "isin"
1 parent 8f8047b commit e5c5ac2

4 files changed

Lines changed: 95 additions & 1 deletion

File tree

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
cov,
77
expand_dims,
88
isclose,
9+
isin,
910
nan_to_num,
1011
one_hot,
1112
pad,
@@ -39,6 +40,7 @@
3940
"default_dtype",
4041
"expand_dims",
4142
"isclose",
43+
"isin",
4244
"kron",
4345
"lazy_apply",
4446
"nan_to_num",

src/array_api_extra/_delegation.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,3 +836,62 @@ def argpartition(
836836
# kth is not small compared to x.size
837837

838838
return _funcs.argpartition(a, kth, axis=axis, xp=xp)
839+
840+
841+
def isin(
842+
a: Array,
843+
b: Array,
844+
/,
845+
*,
846+
assume_unique: bool = False,
847+
invert: bool = False,
848+
kind: str | None = None,
849+
xp: ModuleType | None = None,
850+
) -> Array:
851+
"""
852+
Determine whether each element in `a` is present in `b`.
853+
854+
Return a boolean array of the same shape as `a` that is True for elements
855+
that are in `b` and False otherwise.
856+
857+
Parameters
858+
----------
859+
a : array_like
860+
Input elements.
861+
b : array_like
862+
The elements against which to test each element of `a`.
863+
assume_unique : bool, optional
864+
If True, the input arrays are both assumed to be unique which can speed
865+
up the calculation. Default: False.
866+
invert : bool, optional
867+
If True, the values in the returned array are inverted. Default: False.
868+
kind : str | None, optional
869+
The algorithm or method to use. This will not affect the final result,
870+
but will affect the speed and memory use.
871+
For Numpy the options are {None, "sort", "table"}.
872+
For Jax the mapped parameter is instead `method` and the options are
873+
{"compare_all", "binary_search", "sort", and "auto" (default)}
874+
For Cupy, Dask, Torch and the default case this parameter is not present and
875+
thus ignored. Default: None.
876+
xp : array_namespace, optional
877+
The standard-compatible namespace for `a` and `b`. Default: infer.
878+
879+
Returns
880+
-------
881+
array
882+
An array having the same shape as that of `a` that is True for elements
883+
that are in `b` and False otherwise.
884+
"""
885+
if xp is None:
886+
xp = array_namespace(a, b)
887+
888+
if is_numpy_namespace(xp):
889+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind)
890+
if is_jax_namespace(xp):
891+
if kind is None:
892+
kind = "auto"
893+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind)
894+
if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp):
895+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
896+
897+
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,25 @@ def argpartition( # numpydoc ignore=PR01,RT01
801801
) -> Array:
802802
"""See docstring in `array_api_extra._delegation.py`."""
803803
return xp.argsort(x, axis=axis, stable=False)
804+
805+
806+
def isin( # numpydoc ignore=PR01,RT01
807+
a: Array,
808+
b: Array,
809+
/,
810+
*,
811+
assume_unique: bool = False,
812+
invert: bool = False,
813+
xp: ModuleType | None = None,
814+
) -> Array:
815+
"""See docstring in `array_api_extra._delegation.py`."""
816+
if xp is None:
817+
xp = array_namespace(a, b)
818+
819+
original_a_shape = a.shape
820+
a = xp.reshape(a, (-1,))
821+
b = xp.reshape(b, (-1,))
822+
return xp.reshape(
823+
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
824+
original_a_shape,
825+
)

tests/test_funcs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
default_dtype,
2323
expand_dims,
2424
isclose,
25+
isin,
2526
kron,
2627
nan_to_num,
2728
nunique,
@@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
888889
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
889890
res = isclose(a, b, equal_nan=equal_nan)
890891
assert get_device(res) == device
891-
892+
892893
def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
893894
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
894895
b = 1
@@ -1476,3 +1477,13 @@ def test_nd(self, xp: ModuleType, ndim: int):
14761477
@override
14771478
def test_input_validation(self, xp: ModuleType):
14781479
self._test_input_validation(xp)
1480+
1481+
1482+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
1483+
class TestIsIn:
1484+
def test_simple(self, xp: ModuleType):
1485+
a = xp.asarray([[0, 2], [4, 6]])
1486+
b = xp.asarray([1, 2, 3, 4])
1487+
expected = xp.asarray([[False, True], [True, False]])
1488+
res = isin(a, b)
1489+
xp_assert_equal(res, expected)

0 commit comments

Comments
 (0)