|
38 | 38 |
|
39 | 39 | import warnings |
40 | 40 |
|
| 41 | +import numpy |
| 42 | + |
41 | 43 | import dpnp |
42 | 44 | import dpnp.tensor as dpt |
43 | 45 | import dpnp.tensor._type_utils as dtu |
|
46 | 48 | from .exceptions import AxisError |
47 | 49 |
|
48 | 50 |
|
| 51 | +def _unwrap_index_element(x): |
| 52 | + """ |
| 53 | + Unwrap a single index element for the tensor indexing layer. |
| 54 | +
|
| 55 | + Converts dpnp arrays to usm_ndarray and array-like objects (range, list) |
| 56 | + to numpy arrays with intp dtype for NumPy-compatible advanced indexing. |
| 57 | +
|
| 58 | + """ |
| 59 | + |
| 60 | + if isinstance(x, dpt.usm_ndarray): |
| 61 | + return x |
| 62 | + if isinstance(x, dpnp_array): |
| 63 | + return x.get_array() |
| 64 | + if isinstance(x, range): |
| 65 | + return numpy.asarray(x, dtype=numpy.intp) |
| 66 | + if isinstance(x, list): |
| 67 | + # keep boolean lists as boolean |
| 68 | + arr = numpy.asarray(x) |
| 69 | + # cast empty lists (float64 in NumPy) to intp |
| 70 | + # for correct tensor indexing |
| 71 | + if arr.size == 0: |
| 72 | + arr = arr.astype(numpy.intp) |
| 73 | + return arr |
| 74 | + return x |
| 75 | + |
| 76 | + |
49 | 77 | def _get_unwrapped_index_key(key): |
50 | 78 | """ |
51 | 79 | Get an unwrapped index key. |
52 | 80 |
|
53 | 81 | Return a key where each nested instance of DPNP array is unwrapped into |
54 | | - USM ndarray for further processing in DPCTL advanced indexing functions. |
| 82 | + USM ndarray, and array-like objects (range, list) are converted to numpy |
| 83 | + arrays for further processing in advanced indexing functions. |
55 | 84 |
|
56 | 85 | """ |
57 | 86 |
|
58 | 87 | if isinstance(key, tuple): |
59 | | - if any(isinstance(x, dpnp_array) for x in key): |
60 | | - # create a new tuple from the input key with unwrapped DPNP arrays |
61 | | - return tuple( |
62 | | - x.get_array() if isinstance(x, dpnp_array) else x for x in key |
63 | | - ) |
64 | | - elif isinstance(key, dpnp_array): |
65 | | - return key.get_array() |
66 | | - return key |
| 88 | + return tuple(_unwrap_index_element(x) for x in key) |
| 89 | + return _unwrap_index_element(key) |
67 | 90 |
|
68 | 91 |
|
69 | 92 | # pylint: disable=too-many-public-methods |
|
0 commit comments