Skip to content

Commit 6fecefe

Browse files
Move take_along_axis() to dpctl_ext/tensor and reuse it in dpnp
1 parent f63f2f0 commit 6fecefe

File tree

4 files changed

+133
-1
lines changed

4 files changed

+133
-1
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
put,
4747
put_along_axis,
4848
take,
49+
take_along_axis,
4950
)
5051
from dpctl_ext.tensor._manipulation_functions import (
5152
roll,
@@ -66,6 +67,7 @@
6667
"reshape",
6768
"roll",
6869
"take",
70+
"take_along_axis",
6971
"to_numpy",
7072
"tril",
7173
"triu",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,58 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
384384
return
385385

386386

387+
def _take_multi_index(ary, inds, p, mode=0):
388+
if not isinstance(ary, dpt.usm_ndarray):
389+
raise TypeError(
390+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
391+
)
392+
ary_nd = ary.ndim
393+
p = normalize_axis_index(operator.index(p), ary_nd)
394+
mode = operator.index(mode)
395+
if mode not in [0, 1]:
396+
raise ValueError(
397+
"Invalid value for mode keyword, only 0 or 1 is supported"
398+
)
399+
if not isinstance(inds, (list, tuple)):
400+
inds = (inds,)
401+
402+
exec_q, res_usm_type = _get_indices_queue_usm_type(
403+
inds, ary.sycl_queue, ary.usm_type
404+
)
405+
if exec_q is None:
406+
raise dpctl.utils.ExecutionPlacementError(
407+
"Can not automatically determine where to allocate the "
408+
"result or performance execution. "
409+
"Use `usm_ndarray.to_device` method to migrate data to "
410+
"be associated with the same queue."
411+
)
412+
413+
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
414+
415+
ind0 = inds[0]
416+
ary_sh = ary.shape
417+
p_end = p + len(inds)
418+
if 0 in ary_sh[p:p_end] and ind0.size != 0:
419+
raise IndexError("cannot take non-empty indices from an empty axis")
420+
res_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
421+
res = dpt.empty(
422+
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
423+
)
424+
_manager = dpctl.utils.SequentialOrderManager[exec_q]
425+
dep_ev = _manager.submitted_events
426+
hev, take_ev = ti._take(
427+
src=ary,
428+
ind=inds,
429+
dst=res,
430+
axis_start=p,
431+
mode=mode,
432+
sycl_queue=exec_q,
433+
depends=dep_ev,
434+
)
435+
_manager.add_event_pair(hev, take_ev)
436+
return res
437+
438+
387439
def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None):
388440
"""
389441
from_numpy(arg, device=None, usm_type="device", sycl_queue=None)

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_extract_impl,
4242
_nonzero_impl,
4343
_put_multi_index,
44+
_take_multi_index,
4445
)
4546
from ._numpy_helper import normalize_axis_index
4647

@@ -561,3 +562,80 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
561562
out = orig_out
562563

563564
return out
565+
566+
567+
def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
568+
"""
569+
Returns elements from an array at the one-dimensional indices specified
570+
by ``indices`` along a provided ``axis``.
571+
572+
Args:
573+
x (usm_ndarray):
574+
input array. Must be compatible with ``indices``, except for the
575+
axis (dimension) specified by ``axis``.
576+
indices (usm_ndarray):
577+
array indices. Must have the same rank (i.e., number of dimensions)
578+
as ``x``.
579+
axis: int
580+
axis along which to select values. If ``axis`` is negative, the
581+
function determines the axis along which to select values by
582+
counting from the last dimension. Default: ``-1``.
583+
mode (str, optional):
584+
How out-of-bounds indices will be handled. Possible values
585+
are:
586+
587+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
588+
negative indices.
589+
- ``"clip"``: clips indices to (``0 <= i < n``).
590+
591+
Default: ``"wrap"``.
592+
593+
Returns:
594+
usm_ndarray:
595+
an array having the same data type as ``x``. The returned array has
596+
the same rank (i.e., number of dimensions) as ``x`` and a shape
597+
determined according to broadcasting rules, except for the axis
598+
(dimension) specified by ``axis`` whose size must equal the size
599+
of the corresponding axis (dimension) in ``indices``.
600+
601+
Note:
602+
Treatment of the out-of-bound indices in ``indices`` array is controlled
603+
by the value of ``mode`` keyword.
604+
"""
605+
if not isinstance(x, dpt.usm_ndarray):
606+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
607+
if not isinstance(indices, dpt.usm_ndarray):
608+
raise TypeError(
609+
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
610+
)
611+
x_nd = x.ndim
612+
if x_nd != indices.ndim:
613+
raise ValueError(
614+
"Number of dimensions in the first and the second "
615+
"argument arrays must be equal"
616+
)
617+
pp = normalize_axis_index(operator.index(axis), x_nd)
618+
out_usm_type = dpctl.utils.get_coerced_usm_type(
619+
(x.usm_type, indices.usm_type)
620+
)
621+
exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue))
622+
if exec_q is None:
623+
raise dpctl.utils.ExecutionPlacementError(
624+
"Execution placement can not be unambiguously inferred "
625+
"from input arguments. "
626+
)
627+
mode_i = _get_indexing_mode(mode)
628+
indexes_dt = (
629+
dpt.uint64
630+
if indices.dtype == dpt.uint64
631+
else ti.default_device_index_type(exec_q.sycl_device)
632+
)
633+
_ind = tuple(
634+
(
635+
indices
636+
if i == pp
637+
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
638+
)
639+
for i in range(x_nd)
640+
)
641+
return _take_multi_index(x, _ind, 0, mode=mode_i)

dpnp/dpnp_iface_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2295,7 +2295,7 @@ def take_along_axis(a, indices, axis=-1, mode="wrap"):
22952295
usm_a = dpnp.get_usm_ndarray(a)
22962296
usm_ind = dpnp.get_usm_ndarray(indices)
22972297

2298-
usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
2298+
usm_res = dpt_ext.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
22992299
return dpnp_array._create_from_usm_ndarray(usm_res)
23002300

23012301

0 commit comments

Comments
 (0)