Skip to content

Commit 8bc9071

Browse files
committed
ENH: add unravel_index
1 parent 904e2e8 commit 8bc9071

5 files changed

Lines changed: 116 additions & 0 deletions

File tree

docs/api-assorted.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@
2727
setdiff1d
2828
sinc
2929
union1d
30+
unravel_index
3031
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
setdiff1d,
2020
sinc,
2121
union1d,
22+
unravel_index,
2223
)
2324
from ._lib._at import at
2425
from ._lib._funcs import (
@@ -58,4 +59,5 @@
5859
"sinc",
5960
"testing",
6061
"union1d",
62+
"unravel_index",
6163
]

src/array_api_extra/_delegation.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"pad",
3232
"searchsorted",
3333
"sinc",
34+
"unravel_index",
3435
]
3536

3637

@@ -1307,3 +1308,56 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
13071308
return xp.union1d(a, b)
13081309

13091310
return _funcs.union1d(a, b, xp=xp)
1311+
1312+
1313+
def unravel_index(
1314+
ind: Array,
1315+
shape: tuple[int, ...],
1316+
/,
1317+
*,
1318+
xp: ModuleType | None = None,
1319+
) -> tuple[Array, ...]:
1320+
"""
1321+
Convert a flat index or array of flat indices into a tuple of coordinate arrays.
1322+
1323+
Parameters
1324+
----------
1325+
ind : array
1326+
An integer array whose elements are indices into the flattened version
1327+
of an array of dimensions `shape`.
1328+
1329+
shape : tuple of ints
1330+
The shape to use for unraveling `indices`.
1331+
1332+
xp : array_namespace, optional
1333+
The standard-compatible namespace for `x`. Default: infer.
1334+
1335+
Returns
1336+
-------
1337+
tuple of array
1338+
A tuple of unraveled indices. Each array in the tuple has the same shape
1339+
as the `indices` array.
1340+
1341+
Examples
1342+
--------
1343+
>>> import array_api_extra as xpx
1344+
>>> import array_api_strict as xp
1345+
>>> xpx.unravel_index(xp.asarray([1, 2, 3, 4, 5]), (4, 3))
1346+
(
1347+
Array([0, 0, 1, 1, 1], dtype=array_api_strict.int64),
1348+
Array([1, 2, 0, 1, 2], dtype=array_api_strict.int64),
1349+
)
1350+
"""
1351+
if xp is None:
1352+
xp = array_namespace(ind)
1353+
1354+
if (
1355+
is_numpy_namespace(xp)
1356+
or is_cupy_namespace(xp)
1357+
or is_dask_namespace(xp)
1358+
or is_jax_namespace(xp)
1359+
or is_torch_namespace(xp)
1360+
):
1361+
return xp.unravel_index(ind, shape)
1362+
1363+
return _funcs.unravel_index(ind, shape)

src/array_api_extra/_lib/_funcs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,3 +757,13 @@ def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Ar
757757
if deg:
758758
a = a * 180 / xp.pi
759759
return a
760+
761+
762+
def unravel_index(ind: Array, shape: tuple[int, ...], /) -> tuple[Array, ...]:
763+
# numpydoc ignore=PR01,RT01
764+
"""See docstring in `array_api_extra._delegation.py`."""
765+
coords: list[Array] = []
766+
for dim in reversed(shape):
767+
coords.append(ind % dim)
768+
ind = ind // dim
769+
return tuple(reversed(coords))

tests/test_funcs.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
setdiff1d,
3535
sinc,
3636
union1d,
37+
unravel_index,
3738
)
3839
from array_api_extra import (
3940
searchsorted as xpx_searchsorted,
@@ -1981,3 +1982,51 @@ def test_2d(self, xp: ModuleType):
19811982
def test_device(self, xp: ModuleType, device: Device):
19821983
a = xp.asarray([1 + 1j], device=device)
19831984
assert get_device(angle(a)) == device
1985+
1986+
1987+
class TestUnravelIndex:
1988+
def test_simple(self, xp: ModuleType):
1989+
ind = xp.asarray([22, 41, 37])
1990+
shape = (7, 6)
1991+
expected = (xp.asarray([3, 6, 6]), xp.asarray([4, 5, 1]))
1992+
res = unravel_index(ind, shape)
1993+
for res_arr, exp_arr in zip(res, expected, strict=True):
1994+
assert_equal(res_arr, exp_arr)
1995+
1996+
ind = xp.asarray([0, 1, 2, 3, 4, 5])
1997+
shape = (3, 2)
1998+
expected = (
1999+
xp.asarray([0, 0, 1, 1, 2, 2]),
2000+
xp.asarray([0, 1, 0, 1, 0, 1]),
2001+
)
2002+
res = unravel_index(ind, shape)
2003+
for res_arr, exp_arr in zip(res, expected, strict=True):
2004+
assert_equal(res_arr, exp_arr)
2005+
2006+
def test_indices_scalar(self, xp: ModuleType):
2007+
ind = xp.asarray(1621)
2008+
shape = (6, 7, 8, 9)
2009+
expected = (xp.asarray(3), xp.asarray(1), xp.asarray(4), xp.asarray(1))
2010+
res = unravel_index(ind, shape)
2011+
# a tuple of integers is expected
2012+
assert res == expected
2013+
2014+
def test_indices_2d(self, xp: ModuleType):
2015+
ind = xp.asarray([[1234], [5678]])
2016+
shape = (10, 10, 10, 10)
2017+
expected = (
2018+
xp.asarray([[1], [5]]),
2019+
xp.asarray([[2], [6]]),
2020+
xp.asarray([[3], [7]]),
2021+
xp.asarray([[4], [8]]),
2022+
)
2023+
res = unravel_index(ind, shape)
2024+
for res_arr, exp_arr in zip(res, expected, strict=True):
2025+
assert_equal(res_arr, exp_arr)
2026+
2027+
def test_device(self, xp: ModuleType, device: Device):
2028+
ind = xp.asarray([4, 1], device=device)
2029+
shape = (3, 2)
2030+
res = unravel_index(ind, shape)
2031+
for res_arr in res:
2032+
assert get_device(res_arr) == device

0 commit comments

Comments
 (0)