Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/set_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Objects in API
:toctree: generated
:template: method.rst

isin
unique_all
unique_counts
unique_inverse
Expand Down
47 changes: 45 additions & 2 deletions src/array_api_stubs/_draft/set_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,50 @@
__all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
__all__ = ["isin", "unique_all", "unique_counts", "unique_inverse", "unique_values"]


from ._types import Tuple, array
from ._types import Tuple, Union, array


def isin(
x1: Union[array, int, float, complex, bool],
x2: Union[array, int, float, complex, bool],
Comment thread
kgryte marked this conversation as resolved.
Outdated
/,
*,
invert: bool = False,
) -> array:
"""
Tests whether each element in ``x1`` is in ``x2``.
Comment thread
kgryte marked this conversation as resolved.
Outdated

Parameters
----------
x1: Union[array, int, float, complex, bool]
first input array. **May** have any data type.
Comment thread
kgryte marked this conversation as resolved.
Outdated
x2: Union[array, int, float, complex, bool]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there maybe be a constraint on the rank of x2?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We intentionally did not impose a constraint and it is not clear whether there is a conceptual reason to do so, as this API is a vectorized API for finding whether a needle (a value) is in a haystack (an array) regardless of the dimensionality of the haystack.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW (at least some) other array libraries happily accept arbitrarily shaped x1 and x2:

In [1]: import numpy as np

In [2]: a = np.arange(3*4*5).reshape(3, 4, 5)

In [3]: b = np.arange(11)

In [4]: np.isin(a, b).shape
Out[4]: (3, 4, 5)

In [5]: np.isin(b, a).shape
Out[5]: (11,)

Here's a hypothesis test data-apis/array-api-tests#407 which does not restrict the shapes, and which seems to pass on numpy,cupy, jax and torch locally.

second input array. **May** have any data type.
invert: bool
boolean indicating whether to invert the test criterion. If ``True``, the function **must** test whether each element in ``x1`` is *not* in ``x2``. If ``False``, the function **must** test whether each element in ``x1`` is in ``x2``. Default: ``False``.
Comment thread
kgryte marked this conversation as resolved.

Returns
-------
out: array
an array containing element-wise test results. The returned array **must** have the same shape as ``x1`` and **must** have a boolean data type.
Comment thread
kgryte marked this conversation as resolved.
Outdated

Notes
-----

- At least one of ``x1`` or ``x2`` **must** be an array.

- If an element in ``x1`` is in ``x2``, the corresponding element in the output array **must** be ``True``; otherwise, the corresponding element in the output array **must** be ``False``.

- Testing whether an element in ``x1`` corresponds to an element in ``x2`` **must** be determined based on value equality (see :func:`~array_api.equal`). For input arrays having floating-point data types, value-based equality implies the following behavior. When ``invert`` is ``False``,

- As ``nan`` values compare as ``False``, if an element in ``x1`` is ``nan``, the corresponding element in the returned array **must** be ``False``.
- As complex floating-point values having at least one ``nan`` component compare as ``False``, if an element in ``x1`` is a complex floating-point value having one or more ``nan`` components, the corresponding element in the returned array **must** be ``False``.
- As ``-0`` and ``+0`` compare as ``True``, if an element in ``x1`` is ``±0`` and ``x2`` contains at least one element which is ``±0``, the corresponding element in the returned array **must** be ``True``.

When ``invert`` is ``True``, the returned array **must** contain the same results as if the operation is implemented as ``logical_not(isin(x1, x2))``.

- Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is unspecified and thus implementation-defined.
"""


def unique_all(x: array, /) -> Tuple[array, array, array, array]:
Expand Down