|
34 | 34 | setdiff1d, |
35 | 35 | sinc, |
36 | 36 | union1d, |
| 37 | + unravel_index, |
37 | 38 | ) |
38 | 39 | from array_api_extra import ( |
39 | 40 | searchsorted as xpx_searchsorted, |
@@ -1981,3 +1982,51 @@ def test_2d(self, xp: ModuleType): |
1981 | 1982 | def test_device(self, xp: ModuleType, device: Device): |
1982 | 1983 | a = xp.asarray([1 + 1j], device=device) |
1983 | 1984 | assert get_device(angle(a)) == device |
| 1985 | + |
| 1986 | + |
| 1987 | +class TestUnravelIndex: |
| 1988 | + def test_simple(self, xp: ModuleType): |
| 1989 | + ind = xp.asarray([22, 41, 37]) |
| 1990 | + shape = (7, 6) |
| 1991 | + expected = (xp.asarray([3, 6, 6]), xp.asarray([4, 5, 1])) |
| 1992 | + res = unravel_index(ind, shape) |
| 1993 | + for res_arr, exp_arr in zip(res, expected, strict=True): |
| 1994 | + assert_equal(res_arr, exp_arr) |
| 1995 | + |
| 1996 | + ind = xp.asarray([0, 1, 2, 3, 4, 5]) |
| 1997 | + shape = (3, 2) |
| 1998 | + expected = ( |
| 1999 | + xp.asarray([0, 0, 1, 1, 2, 2]), |
| 2000 | + xp.asarray([0, 1, 0, 1, 0, 1]), |
| 2001 | + ) |
| 2002 | + res = unravel_index(ind, shape) |
| 2003 | + for res_arr, exp_arr in zip(res, expected, strict=True): |
| 2004 | + assert_equal(res_arr, exp_arr) |
| 2005 | + |
| 2006 | + def test_indices_scalar(self, xp: ModuleType): |
| 2007 | + ind = xp.asarray(1621) |
| 2008 | + shape = (6, 7, 8, 9) |
| 2009 | + expected = (xp.asarray(3), xp.asarray(1), xp.asarray(4), xp.asarray(1)) |
| 2010 | + res = unravel_index(ind, shape) |
| 2011 | + # a tuple of integers is expected |
| 2012 | + assert res == expected |
| 2013 | + |
| 2014 | + def test_indices_2d(self, xp: ModuleType): |
| 2015 | + ind = xp.asarray([[1234], [5678]]) |
| 2016 | + shape = (10, 10, 10, 10) |
| 2017 | + expected = ( |
| 2018 | + xp.asarray([[1], [5]]), |
| 2019 | + xp.asarray([[2], [6]]), |
| 2020 | + xp.asarray([[3], [7]]), |
| 2021 | + xp.asarray([[4], [8]]), |
| 2022 | + ) |
| 2023 | + res = unravel_index(ind, shape) |
| 2024 | + for res_arr, exp_arr in zip(res, expected, strict=True): |
| 2025 | + assert_equal(res_arr, exp_arr) |
| 2026 | + |
| 2027 | + def test_device(self, xp: ModuleType, device: Device): |
| 2028 | + ind = xp.asarray([4, 1], device=device) |
| 2029 | + shape = (3, 2) |
| 2030 | + res = unravel_index(ind, shape) |
| 2031 | + for res_arr in res: |
| 2032 | + assert get_device(res_arr) == device |
0 commit comments