Skip to content

Commit f63f2f0

Browse files
Move put_along_axis to dpctl_ext/tensor and reuse it in dpnp
1 parent 7feb4ee commit f63f2f0

File tree

4 files changed

+243
-1
lines changed

4 files changed

+243
-1
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
nonzero,
4545
place,
4646
put,
47+
put_along_axis,
4748
take,
4849
)
4950
from dpctl_ext.tensor._manipulation_functions import (
@@ -61,6 +62,7 @@
6162
"nonzero",
6263
"place",
6364
"put",
65+
"put_along_axis",
6466
"reshape",
6567
"roll",
6668
"take",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import builtins
3030
import operator
31+
from numbers import Integral
3132

3233
import dpctl
3334
import dpctl.memory as dpm
@@ -40,6 +41,7 @@
4041

4142
# TODO: revert to `import dpctl.tensor...`
4243
# when dpnp fully migrates dpctl/tensor
44+
import dpctl_ext.tensor as dpt_ext
4345
import dpctl_ext.tensor._tensor_impl as ti
4446

4547
from ._numpy_helper import normalize_axis_index
@@ -200,6 +202,42 @@ def _extract_impl(ary, ary_mask, axis=0):
200202
return dst
201203

202204

205+
def _get_indices_queue_usm_type(inds, queue, usm_type):
206+
"""
207+
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
208+
dtype or Python integers. At least one must be an array.
209+
210+
For each array, the queue and usm type are appended to `queue_list` and
211+
`usm_type_list`, respectively.
212+
"""
213+
queues = [queue]
214+
usm_types = [usm_type]
215+
any_array = False
216+
for ind in inds:
217+
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
218+
any_array = True
219+
if ind.dtype.kind not in "ui":
220+
raise IndexError(
221+
"arrays used as indices must be of integer (or boolean) "
222+
"type"
223+
)
224+
if isinstance(ind, dpt.usm_ndarray):
225+
queues.append(ind.sycl_queue)
226+
usm_types.append(ind.usm_type)
227+
elif not isinstance(ind, Integral):
228+
raise TypeError(
229+
"all elements of `ind` expected to be usm_ndarrays, "
230+
f"NumPy arrays, or integers, found {type(ind)}"
231+
)
232+
if not any_array:
233+
raise TypeError(
234+
"at least one element of `inds` expected to be an array"
235+
)
236+
usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
237+
q = dpctl.utils.get_execution_queue(queues)
238+
return q, usm_type
239+
240+
203241
def _nonzero_impl(ary):
204242
if not isinstance(ary, dpt.usm_ndarray):
205243
raise TypeError(
@@ -231,6 +269,121 @@ def _nonzero_impl(ary):
231269
return res
232270

233271

272+
def _prepare_indices_arrays(inds, q, usm_type):
273+
"""
274+
Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
275+
a queue (assumed to be common to arrays in inds), and a usm type.
276+
277+
Python scalar integers are promoted to arrays on the provided queue and
278+
with the provided usm type. All arrays are then promoted to a common
279+
integral type (if possible) before being broadcast to a common shape.
280+
"""
281+
# scalar integers -> arrays
282+
inds = tuple(
283+
map(
284+
lambda ind: (
285+
ind
286+
if isinstance(ind, dpt.usm_ndarray)
287+
else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
288+
),
289+
inds,
290+
)
291+
)
292+
293+
# promote to a common integral type if possible
294+
ind_dt = dpt.result_type(*inds)
295+
if ind_dt.kind not in "ui":
296+
raise ValueError(
297+
"cannot safely promote indices to an integer data type"
298+
)
299+
inds = tuple(
300+
map(
301+
lambda ind: (
302+
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
303+
),
304+
inds,
305+
)
306+
)
307+
308+
# broadcast
309+
inds = dpt.broadcast_arrays(*inds)
310+
311+
return inds
312+
313+
314+
def _put_multi_index(ary, inds, p, vals, mode=0):
315+
if not isinstance(ary, dpt.usm_ndarray):
316+
raise TypeError(
317+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
318+
)
319+
ary_nd = ary.ndim
320+
p = normalize_axis_index(operator.index(p), ary_nd)
321+
mode = operator.index(mode)
322+
if mode not in [0, 1]:
323+
raise ValueError(
324+
"Invalid value for mode keyword, only 0 or 1 is supported"
325+
)
326+
if not isinstance(inds, (list, tuple)):
327+
inds = (inds,)
328+
329+
exec_q, coerced_usm_type = _get_indices_queue_usm_type(
330+
inds, ary.sycl_queue, ary.usm_type
331+
)
332+
333+
if exec_q is not None:
334+
if not isinstance(vals, dpt.usm_ndarray):
335+
vals = dpt.asarray(
336+
vals,
337+
dtype=ary.dtype,
338+
usm_type=coerced_usm_type,
339+
sycl_queue=exec_q,
340+
)
341+
else:
342+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
343+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
344+
(
345+
coerced_usm_type,
346+
vals.usm_type,
347+
)
348+
)
349+
if exec_q is None:
350+
raise dpctl.utils.ExecutionPlacementError(
351+
"Can not automatically determine where to allocate the "
352+
"result or performance execution. "
353+
"Use `usm_ndarray.to_device` method to migrate data to "
354+
"be associated with the same queue."
355+
)
356+
357+
inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)
358+
359+
ind0 = inds[0]
360+
ary_sh = ary.shape
361+
p_end = p + len(inds)
362+
if 0 in ary_sh[p:p_end] and ind0.size != 0:
363+
raise IndexError(
364+
"cannot put into non-empty indices along an empty axis"
365+
)
366+
expected_vals_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
367+
if vals.dtype == ary.dtype:
368+
rhs = vals
369+
else:
370+
rhs = dpt_ext.astype(vals, ary.dtype)
371+
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
372+
_manager = dpctl.utils.SequentialOrderManager[exec_q]
373+
dep_ev = _manager.submitted_events
374+
hev, put_ev = ti._put(
375+
dst=ary,
376+
ind=inds,
377+
val=rhs,
378+
axis_start=p,
379+
mode=mode,
380+
sycl_queue=exec_q,
381+
depends=dep_ev,
382+
)
383+
_manager.add_event_pair(hev, put_ev)
384+
return
385+
386+
234387
def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None):
235388
"""
236389
from_numpy(arg, device=None, usm_type="device", sycl_queue=None)

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ._copy_utils import (
4141
_extract_impl,
4242
_nonzero_impl,
43+
_put_multi_index,
4344
)
4445
from ._numpy_helper import normalize_axis_index
4546

@@ -54,6 +55,12 @@ def _get_indexing_mode(name):
5455
)
5556

5657

58+
def _range(sh_i, i, nd, q, usm_t, dt):
59+
ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q)
60+
ind.shape = tuple(sh_i if i == j else 1 for j in range(nd))
61+
return ind
62+
63+
5764
def extract(condition, arr):
5865
"""extract(condition, arr)
5966
@@ -343,6 +350,86 @@ def put_vec_duplicates(vec, ind, vals):
343350
_manager.add_event_pair(hev, put_ev)
344351

345352

353+
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
354+
"""
355+
Puts elements into an array at the one-dimensional indices specified by
356+
``indices`` along a provided ``axis``.
357+
358+
Args:
359+
x (usm_ndarray):
360+
input array. Must be compatible with ``indices``, except for the
361+
axis (dimension) specified by ``axis``.
362+
indices (usm_ndarray):
363+
array indices. Must have the same rank (i.e., number of dimensions)
364+
as ``x``.
365+
vals (usm_ndarray):
366+
Array of values to be put into ``x``.
367+
Must be broadcastable to the shape of ``indices``.
368+
axis: int
369+
axis along which to select values. If ``axis`` is negative, the
370+
function determines the axis along which to select values by
371+
counting from the last dimension. Default: ``-1``.
372+
mode (str, optional):
373+
How out-of-bounds indices will be handled. Possible values
374+
are:
375+
376+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
377+
negative indices.
378+
- ``"clip"``: clips indices to (``0 <= i < n``).
379+
380+
Default: ``"wrap"``.
381+
382+
.. note::
383+
384+
If input array ``indices`` contains duplicates, a race condition
385+
occurs, and the value written into corresponding positions in ``x``
386+
may vary from run to run. Preserving sequential semantics in handing
387+
the duplicates to achieve deterministic behavior requires additional
388+
work.
389+
"""
390+
if not isinstance(x, dpt.usm_ndarray):
391+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
392+
if not isinstance(indices, dpt.usm_ndarray):
393+
raise TypeError(
394+
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
395+
)
396+
x_nd = x.ndim
397+
if x_nd != indices.ndim:
398+
raise ValueError(
399+
"Number of dimensions in the first and the second "
400+
"argument arrays must be equal"
401+
)
402+
pp = normalize_axis_index(operator.index(axis), x_nd)
403+
if isinstance(vals, dpt.usm_ndarray):
404+
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
405+
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
406+
else:
407+
queues_ = [x.sycl_queue, indices.sycl_queue]
408+
usm_types_ = [x.usm_type, indices.usm_type]
409+
exec_q = dpctl.utils.get_execution_queue(queues_)
410+
if exec_q is None:
411+
raise dpctl.utils.ExecutionPlacementError(
412+
"Execution placement can not be unambiguously inferred "
413+
"from input arguments. "
414+
)
415+
out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
416+
mode_i = _get_indexing_mode(mode)
417+
indexes_dt = (
418+
dpt.uint64
419+
if indices.dtype == dpt.uint64
420+
else ti.default_device_index_type(exec_q.sycl_device)
421+
)
422+
_ind = tuple(
423+
(
424+
indices
425+
if i == pp
426+
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
427+
)
428+
for i in range(x_nd)
429+
)
430+
return _put_multi_index(x, _ind, 0, vals, mode=mode_i)
431+
432+
346433
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
347434
"""take(x, indices, axis=None, out=None, mode="wrap")
348435

dpnp/dpnp_iface_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1807,7 +1807,7 @@ def put_along_axis(a, ind, values, axis, mode="wrap"):
18071807
values, usm_type=a.usm_type, sycl_queue=a.sycl_queue
18081808
)
18091809

1810-
dpt.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode)
1810+
dpt_ext.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode)
18111811

18121812

18131813
def putmask(x1, mask, values):

0 commit comments

Comments
 (0)