-
Notifications
You must be signed in to change notification settings - Fork 609
Expand file tree
/
Copy patharray_api.py
More file actions
138 lines (111 loc) · 4.14 KB
/
array_api.py
File metadata and controls
138 lines (111 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""
from typing import (
Any,
Callable,
Optional,
Union,
)
import array_api_compat
import numpy as np
from packaging.version import (
Version,
)
# Type alias for array API compatible arrays (NumPy arrays and arrays from
# other backends). Because ``Any`` is part of the union, static type checkers
# treat this alias as essentially as permissive as ``Any``; it serves mainly
# as documentation of intent.
Array = Union[np.ndarray, Any]  # Any to support JAX, PyTorch, etc. arrays
def support_array_api(version: str) -> Callable:
    """Build a decorator that tags a function with an array API version.

    The returned decorator records ``version`` on its argument as the
    ``array_api_version`` attribute and returns that same function object;
    no wrapper function is created.

    Parameters
    ----------
    version : str
        The version of the array API

    Returns
    -------
    Callable
        A decorator that sets ``array_api_version`` on the function it
        receives and returns the function unchanged.

    Examples
    --------
    >>> @support_array_api(version="2022.12")
    ... def f(x):
    ...     pass
    >>> f.array_api_version
    '2022.12'
    """

    def _tag(fn: Callable) -> Callable:
        # Mark the function in place; callers get the original object back.
        setattr(fn, "array_api_version", version)
        return fn

    return _tag
# The array API standard added take_along_axis in
# https://github.com/data-apis/array-api/pull/816 (part of the 2024.12 revision).
# For namespaces implementing an older revision, xp_take_along_axis below falls
# back to a pure array-API implementation adapted from
# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
def xp_swapaxes(a: Array, axis1: int, axis2: int) -> Array:
    """Return ``a`` with axes ``axis1`` and ``axis2`` interchanged.

    Built on ``permute_dims`` from the array's own namespace, so it works for
    any array API compatible array. Negative axes are resolved like ordinary
    Python list indices into the permutation.
    """
    xp = array_api_compat.array_namespace(a)
    perm = list(range(a.ndim))
    first, second = perm[axis1], perm[axis2]
    perm[axis1] = second
    perm[axis2] = first
    return xp.permute_dims(a, perm)
def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
    """Select values from ``arr`` using matching index slices along ``axis``.

    Mirrors ``take_along_axis`` from the array API standard. When the array
    namespace reports revision 2024.12 or newer and provides
    ``take_along_axis``, the native function is used. Otherwise the
    selection axis is moved last, ``arr`` is flattened, and the values are
    gathered with ``take``.

    Parameters
    ----------
    arr : Array
        Source array.
    indices : Array
        Integer indices. In the fallback path its shape must equal the shape
        of ``arr`` on every axis except ``axis`` (no broadcasting is done).
        Negative indices count from the end of ``axis``.
    axis : int
        Axis along which values are selected; may be negative.

    Returns
    -------
    Array
        The selected values, with the same shape as ``indices``.
    """
    xp = array_api_compat.array_namespace(arr)
    if Version(xp.__array_api_version__) >= Version("2024.12") and hasattr(
        xp, "take_along_axis"
    ):
        # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
        return xp.take_along_axis(arr, indices, axis=axis)

    # Move the selection axis last so that, after flattening, each row of
    # length m is one 1-D slice of `arr` along `axis`.
    arr = xp_swapaxes(arr, axis, -1)
    indices = xp_swapaxes(indices, axis, -1)
    m = arr.shape[-1]
    n = indices.shape[-1]
    # Wrap negative indices into [0, m). Without this, the row offset added
    # below would make a negative index select from the previous row.
    indices = xp.where(indices < 0, indices + m, indices)
    shape = (*arr.shape[:-1], n)
    arr = xp.reshape(arr, (-1,))
    if n != 0:
        indices = xp.reshape(indices, (-1, n))
    else:
        # reshape(..., (-1, 0)) is ambiguous, so spell out an empty 2-D shape.
        indices = xp.reshape(indices, (0, 0))
    # Shift row r by r * m so its indices address row r of the flat `arr`.
    offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis]
    indices = xp.reshape(offset + indices, (-1,))
    out = xp.take(arr, indices)
    out = xp.reshape(out, shape)
    return xp_swapaxes(out, axis, -1)
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
    """Reduce (sum) the values of ``src`` onto the positions named by ``index``.

    Only JAX arrays are handled. The arguments are forwarded positionally to
    ``deepmd.jax.common.scatter_sum``, whose result is returned.

    Raises
    ------
    NotImplementedError
        If ``input`` is not a JAX array.
    """
    if not array_api_compat.is_jax_array(input):
        raise NotImplementedError("Only JAX arrays are supported.")
    # Deferred import: deepmd.jax is only imported on the JAX code path.
    from deepmd.jax.common import scatter_sum as _jax_scatter_sum

    return _jax_scatter_sum(input, dim, index, src)
def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
    """Accumulate ``values`` into ``x`` at the positions given by ``indices``.

    Behaviour depends on the array type of ``x``:

    - NumPy: uses the namespace's ``add.at``, updating ``x`` in place, and
      returns ``x``.
    - JAX: uses the functional ``x.at[indices].add(values)`` and returns the
      new array; ``x`` itself is not modified.
    - Anything else: loops over ``indices`` (indexed along its first axis),
      adding ``values[i, ...]`` to ``x[indices[i], ...]`` in place, and
      returns ``x``.

    Callers should always use the returned array.
    """
    xp = array_api_compat.array_namespace(x, indices, values)
    if array_api_compat.is_numpy_array(x):
        # In-place, unbuffered accumulation.
        xp.add.at(x, indices, values)
        return x
    if array_api_compat.is_jax_array(x):
        # JAX arrays are immutable: build and return an updated copy.
        return x.at[indices].add(values)
    # Generic fallback (e.g. array_api_strict) that only needs basic indexing.
    # It runs one Python-level update per index, so it is slow for large inputs.
    for pos in range(indices.shape[0]):
        target = int(indices[pos])
        x[target, ...] = x[target, ...] + values[pos, ...]
    return x
def xp_bincount(x: Array, weights: Optional[Array] = None, minlength: int = 0) -> Array:
    """Count the number of occurrences of each value in ``x``.

    NumPy and JAX arrays are dispatched to their namespace's ``bincount``.
    Other array types use a generic fallback built on ``xp_add_at``.

    Parameters
    ----------
    x : Array
        1-D array of non-negative integers.
    weights : Array, optional
        Per-element weights aligned with ``x``. Each occurrence of ``x[i]``
        adds ``weights[i]`` to its bin. When omitted, every occurrence adds one.
    minlength : int, default 0
        Minimum length of the returned array.

    Returns
    -------
    Array
        1-D array of length ``max(minlength, max(x) + 1)``, or ``minlength``
        when ``x`` is empty. In the fallback path its dtype is that of
        ``weights``, or of ``x`` when ``weights`` is omitted.
    """
    xp = array_api_compat.array_namespace(x)
    if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
        return xp.bincount(x, weights=weights, minlength=minlength)
    if weights is None:
        # Unweighted count: every occurrence contributes one.
        weights = xp.ones_like(x)
    # xp.max raises on an empty array; an empty x yields `minlength` zero bins,
    # matching numpy.bincount.
    size = minlength
    if x.shape[0] > 0:
        size = max(minlength, int(xp.max(x)) + 1)
    # Allocate on the same device as `weights` so the accumulation below
    # never mixes devices.
    result = xp.zeros(
        (size,), dtype=weights.dtype, device=array_api_compat.device(weights)
    )
    return xp_add_at(result, x, weights)