-
Notifications
You must be signed in to change notification settings - Fork 609
Expand file tree
/
Copy patharray_api.py
More file actions
121 lines (99 loc) · 3.99 KB
/
array_api.py
File metadata and controls
121 lines (99 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""
from typing import (
Any,
)
import array_api_compat
import numpy as np
from packaging.version import (
Version,
)
# Type alias for array_api compatible arrays.
# ``np.ndarray`` is spelled out for readability; the ``Any`` member admits
# arrays from the other backends handled in this module (JAX, PyTorch,
# array_api_strict, ...), so effectively any array object is accepted.
Array = np.ndarray | Any  # Any to support JAX, PyTorch, etc. arrays
# The array API standard added take_along_axis in
# https://github.com/data-apis/array-api/pull/816 (first part of the 2024.12
# revision). For namespaces that implement an older revision,
# xp_take_along_axis below falls back to a pure Python implementation adapted
# from https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
def xp_swapaxes(a: Array, axis1: int, axis2: int) -> Array:
    """Interchange two axes of an array using only array API operations.

    Parameters
    ----------
    a : Array
        Input array from any array-API-compatible namespace.
    axis1, axis2 : int
        The axes to exchange. Negative values count from the last axis.

    Returns
    -------
    Array
        ``a`` with ``axis1`` and ``axis2`` exchanged, produced by
        ``permute_dims``.
    """
    namespace = array_api_compat.array_namespace(a)
    perm = list(range(a.ndim))
    # Plain list indexing resolves negative axes for us.
    perm[axis2], perm[axis1] = perm[axis1], perm[axis2]
    return namespace.permute_dims(a, perm)
def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
    """Gather values from ``arr`` along ``axis`` using per-slice ``indices``.

    Namespaces implementing array API revision 2024.12 or later provide
    ``take_along_axis`` natively and are dispatched to it. For older
    namespaces the gather is emulated: the target axis is moved last, the
    data is flattened, and each row of indices is shifted by its row offset
    before a single ``xp.take``.

    Parameters
    ----------
    arr : Array
        Source array.
    indices : Array
        Integer indices into ``arr`` along ``axis``. The fallback path builds
        the output shape from ``arr``'s other dimensions, so ``indices`` is
        expected to have the same ndim as ``arr`` and match it in every
        dimension except ``axis`` (no broadcasting).
    axis : int
        Axis along which values are gathered.

    Returns
    -------
    Array
        Array shaped like ``arr`` except that ``axis`` has length
        ``indices.shape[axis]``.
    """
    xp = array_api_compat.array_namespace(arr)
    if Version(xp.__array_api_version__) >= Version("2024.12"):
        # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
        return xp.take_along_axis(arr, indices, axis=axis)

    # Fallback: work on the last axis, where each row is contiguous after
    # flattening.
    src = xp_swapaxes(arr, axis, -1)
    idx = xp_swapaxes(indices, axis, -1)
    row_len = src.shape[-1]
    n_take = idx.shape[-1]
    out_shape = (*src.shape[:-1], n_take)

    flat_src = xp.reshape(src, (-1,))
    # A (-1, 0) reshape is ambiguous, so the empty case is spelled out.
    rows = (-1, n_take) if n_take != 0 else (0, 0)
    idx2d = xp.reshape(idx, rows)
    # Shift row r's indices by r * row_len so they address the flat buffer.
    # NOTE(review): negative indices are not normalized here; they would
    # read from the preceding row rather than wrap within their own row.
    base = xp.arange(idx2d.shape[0], dtype=idx2d.dtype) * row_len
    flat_idx = xp.reshape(base[:, xp.newaxis] + idx2d, (-1,))

    gathered = xp.take(flat_src, flat_idx)
    return xp_swapaxes(xp.reshape(gathered, out_shape), axis, -1)
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
    """Sum values of ``src`` into ``input`` at the positions given by ``index``.

    Follows ``torch.scatter_add`` semantics. For a 3-D array and ``dim=0``,
    ``out[index[i, j, k], j, k] += src[i, j, k]`` for every ``(i, j, k)`` in
    ``index.shape``. Only the leading ``index.shape`` region of ``src``
    participates.

    Parameters
    ----------
    input : Array
        Array that receives the sums. The NumPy and PyTorch paths do not
        modify it and return a new array.
    dim : int
        Axis along which ``index`` selects destination positions.
    index : Array
        Integer destination indices; same ndim as ``input`` and ``src``.
    src : Array
        Values to scatter-add.

    Returns
    -------
    Array
        The accumulated result.

    Raises
    ------
    NotImplementedError
        If ``input`` is not a NumPy, JAX or PyTorch array.
    """
    if array_api_compat.is_jax_array(input):
        # Delegates to the JAX backend helper (presumably the same semantics;
        # its implementation lives in deepmd.jax.common).
        from deepmd.jax.common import (
            scatter_sum,
        )

        return scatter_sum(
            input,
            dim,
            index,
            src,
        )
    if array_api_compat.is_torch_array(input):
        # PyTorch: use scatter_add (non-mutating version)
        import torch

        return torch.scatter_add(input, dim, index, src)
    if array_api_compat.is_numpy_array(input):
        # NumPy: emulate torch.scatter_add on a copy. The sparse open grid
        # supplies the coordinates of every non-scatter axis; the scatter
        # axis is replaced by ``index``. np.add.at is unbuffered, so
        # repeated destinations accumulate instead of overwriting.
        out = np.array(input, copy=True)
        coords = list(np.indices(index.shape, sparse=True))
        coords[dim] = index
        src_part = src[tuple(slice(0, size) for size in index.shape)]
        np.add.at(out, tuple(coords), src_part)
        return out
    raise NotImplementedError("Only NumPy, JAX and PyTorch arrays are supported.")
def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
    """Add ``values`` to ``x`` at ``indices`` along the first axis.

    Whether ``x`` is mutated depends on the backend:

    * NumPy: accumulated in place with ``add.at``; ``x`` itself is returned.
    * JAX: functional ``x.at[indices].add(values)``; a new array is returned.
    * PyTorch: out-of-place ``torch.index_add`` along dim 0; a new tensor is
      returned.
    * Anything else (e.g. array_api_strict): a Python loop writes into ``x``
      in place and returns it.

    Parameters
    ----------
    x : Array
        Destination array.
    indices : Array
        Integer positions along the first axis of ``x``. The fallback loop
        iterates over ``indices.shape[0]`` entries.
    values : Array
        Values to add. The fallback reads ``values[i, ...]`` for the i-th
        index, so its leading axis must match ``indices``.

    Returns
    -------
    Array
        The updated array (see above for in-place vs. new).
    """
    # Resolved up front so mixed/unsupported namespaces fail early.
    ns = array_api_compat.array_namespace(x, indices, values)

    if array_api_compat.is_numpy_array(x):
        ns.add.at(x, indices, values)
        return x

    if array_api_compat.is_jax_array(x):
        return x.at[indices].add(values)

    if array_api_compat.is_torch_array(x):
        import torch

        return torch.index_add(x, 0, indices, values)

    # Generic fallback using only basic indexing; O(len(indices)) Python
    # iterations, so it is slow for large inputs.
    for pos in range(indices.shape[0]):
        row = int(indices[pos])
        x[row, ...] = x[row, ...] + values[pos, ...]
    return x
def xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) -> Array:
    """Count occurrences of each value in ``x`` (optionally weighted).

    NumPy, JAX and PyTorch arrays use their native ``bincount``. Other
    namespaces (e.g. array_api_strict) use a generic fallback built on
    ``xp_add_at`` that mirrors ``numpy.bincount``.

    Parameters
    ----------
    x : Array
        1-D array of non-negative integers.
    weights : Array or None, optional
        Per-element weights, indexed like ``x``. If ``None``, the fallback
        uses ones in ``x``'s dtype, so plain counts keep an integer dtype.
    minlength : int, optional
        Minimum number of bins in the result.

    Returns
    -------
    Array
        1-D array whose entry ``i`` is the (weighted) count of ``i`` in
        ``x``. In the fallback its length is ``max(minlength, max(x) + 1)``,
        or ``minlength`` when ``x`` is empty.
    """
    xp = array_api_compat.array_namespace(x)
    if (
        array_api_compat.is_numpy_array(x)
        or array_api_compat.is_jax_array(x)
        or array_api_compat.is_torch_array(x)
    ):
        return xp.bincount(x, weights=weights, minlength=minlength)

    if weights is None:
        weights = xp.ones_like(x)
    # The array API leaves ``max`` of a zero-size array implementation-defined
    # (NumPy-backed namespaces raise), so an empty ``x`` must not reach it.
    # numpy.bincount returns ``minlength`` zeros in that case.
    if x.shape[0] == 0:
        nbins = minlength
    else:
        nbins = max(minlength, int(xp.max(x)) + 1)
    result = xp.zeros((nbins,), dtype=weights.dtype)
    return xp_add_at(result, x, weights)