forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_set_functions.py
More file actions
133 lines (106 loc) · 4.1 KB
/
Copy path_set_functions.py
File metadata and controls
133 lines (106 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import NamedTuple
import numpy as np
from ._array_object import Array
from ._flags import requires_data_dependent_shapes, requires_api_version
from ._helpers import _maybe_normalize_py_scalars
from ._dtypes import _result_type
# Note: np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).
# Note: The various unique() functions are supposed to return multiple NaNs.
# NumPy's np.unique() collapses NaNs by default, so every np.unique() call in
# this module passes equal_nan=False to keep them distinct. (Earlier versions of
# this module left this as a TODO; see https://github.com/numpy/numpy/issues/20326.)
# Note: The functions here return a namedtuple (np.unique() returns a normal
# tuple).
class UniqueAllResult(NamedTuple):
    """Named result of ``unique_all``: unique values plus indices, inverse indices and counts."""

    # Unique elements of the input, as returned by np.unique().
    values: Array
    # Indices into the input from np.unique(..., return_index=True).
    indices: Array
    # Inverse indices from np.unique(..., return_inverse=True), reshaped by
    # unique_all to the input's shape.
    inverse_indices: Array
    # Per-value counts from np.unique(..., return_counts=True).
    counts: Array
class UniqueCountsResult(NamedTuple):
    """Named result of ``unique_counts``: unique values and their counts."""

    # Unique elements of the input, as returned by np.unique().
    values: Array
    # Per-value counts from np.unique(..., return_counts=True).
    counts: Array
class UniqueInverseResult(NamedTuple):
    """Named result of ``unique_inverse``: unique values and inverse indices."""

    # Unique elements of the input, as returned by np.unique().
    values: Array
    # Inverse indices from np.unique(..., return_inverse=True), reshaped by
    # unique_inverse to the input's shape.
    inverse_indices: Array
@requires_data_dependent_shapes
def unique_all(x: Array, /) -> UniqueAllResult:
    """
    Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

    Returns the unique values of ``x`` together with the first-occurrence
    indices, inverse indices and counts reported by NumPy. NaNs are kept
    distinct (``equal_nan=False``). See NumPy's docstring for more information.
    """
    uniq, first_idx, inv, cnt = np.unique(
        x._array,
        return_index=True,
        return_inverse=True,
        return_counts=True,
        equal_nan=False,
    )
    # np.unique() flattens the inverse indices, but the array API requires
    # them to have the same shape as x.
    # See https://github.com/numpy/numpy/issues/20638
    inv = inv.reshape(x.shape)
    dev = x.device
    return UniqueAllResult(
        *(Array._new(part, device=dev) for part in (uniq, first_idx, inv, cnt))
    )
@requires_data_dependent_shapes
def unique_counts(x: Array, /) -> UniqueCountsResult:
    """
    Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
    See its docstring for more information.

    Returns the unique values of ``x`` and how many times each occurs.
    NaNs are kept distinct (``equal_nan=False``).
    """
    # With only return_counts=True, np.unique() returns exactly (values, counts);
    # unpack explicitly rather than relying on positional splatting.
    values, counts = np.unique(
        x._array,
        return_counts=True,
        return_index=False,
        return_inverse=False,
        equal_nan=False,
    )
    return UniqueCountsResult(
        Array._new(values, device=x.device),
        Array._new(counts, device=x.device),
    )
@requires_data_dependent_shapes
def unique_inverse(x: Array, /) -> UniqueInverseResult:
    """
    Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

    Returns the unique values of ``x`` and the inverse indices, shaped like
    ``x``. NaNs are kept distinct (``equal_nan=False``). See NumPy's
    docstring for more information.
    """
    uniq, inv = np.unique(
        x._array,
        return_inverse=True,
        return_index=False,
        return_counts=False,
        equal_nan=False,
    )
    # NumPy hands back a flattened inverse; the array API wants x's shape.
    # See https://github.com/numpy/numpy/issues/20638
    inv = inv.reshape(x.shape)
    dev = x.device
    return UniqueInverseResult(
        values=Array._new(uniq, device=dev),
        inverse_indices=Array._new(inv, device=dev),
    )
@requires_data_dependent_shapes
def unique_values(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.

    Returns only the unique values of ``x``; NaNs are kept distinct
    (``equal_nan=False``). See NumPy's docstring for more information.
    """
    distinct = np.unique(
        x._array,
        equal_nan=False,
        return_index=False,
        return_inverse=False,
        return_counts=False,
    )
    return Array._new(distinct, device=x.device)
@requires_api_version('2025.12')
def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.isin <numpy.isin>`.
    See its docstring for more information.

    Tests each element of ``x1`` for membership in ``x2``. When ``invert`` is
    True the result is negated, i.e. True where an element of ``x1`` is *not*
    found in ``x2``.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` are on different devices.
    """
    # Input handling mirrors _elementwise_functions.py::_binary_ufunc_proto.
    x1, x2 = _maybe_normalize_py_scalars(x1, x2, "integer", "isin")
    if x1.device != x2.device:
        raise ValueError(
            f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined."
        )
    # Called only to raise on disallowed dtype combinations; the result is unused.
    _result_type(x1.dtype, x2.dtype)
    # The 0-D -> 1-D normalization that Array._normalize_two_args performs for
    # elementwise functions is deliberately skipped here.
    # Forward `invert`: previously it was accepted but silently ignored.
    return Array._new(
        np.isin(x1._array, x2._array, invert=invert), device=x1.device
    )