Skip to content

Commit 192bd93

Browse files
Extend ._tensor_impl with advanced indexing functions (#2777)
This PR extends `_tensor_impl` in `dpctl_ext.tensor` with the advanced indexing (`_extract`, `_place`, `_nonzero`, `mask_positions`), repeat (`_cumsum_1d`), and `_eye` functions. It also adds `eye()`, `extract()`, `nonzero()`, `place()`, `put_along_axis()`, and `take_along_axis()` to `dpctl_ext.tensor` and updates the corresponding dpnp functions to use these implementations internally.
1 parent 195b893 commit 192bd93

18 files changed

+4809
-52
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ set(_static_lib_sources
4545
)
4646
set(_tensor_impl_sources
4747
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_ctors.cpp
48-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
48+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
5050
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_as_contig.cpp
5151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
5252
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
5353
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
5454
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
5555
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
56-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
57-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
56+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
57+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
5858
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
6060
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,19 @@
3535
to_numpy,
3636
)
3737
from dpctl_ext.tensor._ctors import (
38+
eye,
3839
full,
3940
tril,
4041
triu,
4142
)
4243
from dpctl_ext.tensor._indexing_functions import (
44+
extract,
45+
nonzero,
46+
place,
4347
put,
48+
put_along_axis,
4449
take,
50+
take_along_axis,
4551
)
4652
from dpctl_ext.tensor._manipulation_functions import (
4753
roll,
@@ -52,12 +58,18 @@
5258
"asnumpy",
5359
"astype",
5460
"copy",
61+
"extract",
62+
"eye",
5563
"from_numpy",
5664
"full",
65+
"nonzero",
66+
"place",
5767
"put",
68+
"put_along_axis",
5869
"reshape",
5970
"roll",
6071
"take",
72+
"take_along_axis",
6173
"to_numpy",
6274
"tril",
6375
"triu",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
# *****************************************************************************
2828

2929
import builtins
30+
import operator
31+
from numbers import Integral
3032

3133
import dpctl
3234
import dpctl.memory as dpm
@@ -39,8 +41,11 @@
3941

4042
# TODO: revert to `import dpctl.tensor...`
4143
# when dpnp fully migrates dpctl/tensor
44+
import dpctl_ext.tensor as dpt_ext
4245
import dpctl_ext.tensor._tensor_impl as ti
4346

47+
from ._numpy_helper import normalize_axis_index
48+
4449
__doc__ = (
4550
"Implementation module for copy- and cast- operations on "
4651
":class:`dpctl.tensor.usm_ndarray`."
@@ -130,6 +135,307 @@ def _copy_from_numpy_into(dst, np_ary):
130135
)
131136

132137

138+
def _extract_impl(ary, ary_mask, axis=0):
    """
    Extract elements of ``ary`` selected by ``ary_mask``, where the mask
    is applied to the dimensions of ``ary`` starting at ``axis``.

    Args:
        ary: array to extract from; must be a ``dpt.usm_ndarray``.
        ary_mask: mask given as a ``dpt.usm_ndarray`` or a
            ``numpy.ndarray``. A NumPy mask is converted with
            ``dpt.asarray`` using the USM type and queue of ``ary``.
        axis: first dimension of ``ary`` covered by the mask. It is
            passed through ``operator.index`` and
            ``normalize_axis_index``.

    Returns:
        A newly allocated array with the dtype of ``ary`` and shape
        ``ary.shape[:axis] + (mask_count,) + ary.shape[axis + mask_nd:]``,
        where ``mask_count`` is the value returned by
        ``ti.mask_positions`` and ``mask_nd`` is ``ary_mask.ndim``.

    Raises:
        TypeError: if ``ary`` is not a ``usm_ndarray`` or ``ary_mask`` is
            neither a ``usm_ndarray`` nor a ``numpy.ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for ``ary`` and ``ary_mask``.
        ValueError: if the mask dimensions placed at ``axis`` do not fit
            within the dimensions of ``ary``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if isinstance(ary_mask, dpt.usm_ndarray):
        dst_usm_type = dpctl.utils.get_coerced_usm_type(
            (ary.usm_type, ary_mask.usm_type)
        )
        exec_q = dpctl.utils.get_execution_queue(
            (ary.sycl_queue, ary_mask.sycl_queue)
        )
        if exec_q is None:
            raise dpctl.utils.ExecutionPlacementError(
                "arrays have different associated queues. "
                "Use `y.to_device(x.device)` to migrate."
            )
    elif isinstance(ary_mask, np.ndarray):
        dst_usm_type = ary.usm_type
        exec_q = ary.sycl_queue
        ary_mask = dpt.asarray(
            ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q
        )
    else:
        raise TypeError(
            "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
            f"{type(ary_mask)}"
        )
    ary_nd = ary.ndim
    pp = normalize_axis_index(operator.index(axis), ary_nd)
    mask_nd = ary_mask.ndim
    if pp < 0 or pp + mask_nd > ary_nd:
        # The offending argument is `axis`; the message previously named a
        # parameter `p` that this function does not have.
        raise ValueError(
            "Parameter axis is inconsistent with input array dimensions"
        )
    mask_nelems = ary_mask.size
    # Use the narrower integer type for the cumulative sum buffer whenever
    # the number of mask elements is below int32_t_max.
    cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
    cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
    exec_q = cumsum.sycl_queue
    _manager = dpctl.utils.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    mask_count = ti.mask_positions(
        ary_mask, cumsum, sycl_queue=exec_q, depends=dep_evs
    )
    dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
    dst = dpt.empty(
        dst_shape, dtype=ary.dtype, usm_type=dst_usm_type, device=ary.device
    )
    if dst.size == 0:
        # Nothing to copy, so the extraction kernel is not submitted.
        return dst
    hev, ev = ti._extract(
        src=ary,
        cumsum=cumsum,
        axis_start=pp,
        axis_end=pp + mask_nd,
        dst=dst,
        sycl_queue=exec_q,
        depends=dep_evs,
    )
    _manager.add_event_pair(hev, ev)
    return dst
203+
204+
205+
def _get_indices_queue_usm_type(inds, queue, usm_type):
    """
    Validate a sequence of indices and derive a common queue and USM type.

    Each element of `inds` must be a NumPy ndarray or usm_ndarray whose
    dtype kind is signed or unsigned integer, or a Python integer. At
    least one element must be an array.

    The queue and USM type of every usm_ndarray in `inds` are combined
    with the `queue` and `usm_type` arguments.

    Returns:
        tuple: `(q, usm_type)`, where `q` is the result of
        `dpctl.utils.get_execution_queue` over the collected queues and
        `usm_type` is the result of `dpctl.utils.get_coerced_usm_type`
        over the collected USM types. `q` is returned without being
        checked here.

    Raises:
        IndexError: if an array index has a non-integer dtype kind.
        TypeError: if an element is neither an array nor an integer, or
            if no element is an array.
    """
    array_types = (np.ndarray, dpt.usm_ndarray)
    seen_queues = [queue]
    seen_usm_types = [usm_type]
    found_array = False
    for idx in inds:
        if not isinstance(idx, array_types):
            # Non-array entries are accepted only when they are integers.
            if isinstance(idx, Integral):
                continue
            raise TypeError(
                "all elements of `ind` expected to be usm_ndarrays, "
                f"NumPy arrays, or integers, found {type(idx)}"
            )
        found_array = True
        if idx.dtype.kind not in "ui":
            raise IndexError(
                "arrays used as indices must be of integer (or boolean) "
                "type"
            )
        if isinstance(idx, dpt.usm_ndarray):
            seen_queues.append(idx.sycl_queue)
            seen_usm_types.append(idx.usm_type)
    if not found_array:
        raise TypeError(
            "at least one element of `inds` expected to be an array"
        )
    common_usm_type = dpctl.utils.get_coerced_usm_type(seen_usm_types)
    common_queue = dpctl.utils.get_execution_queue(seen_queues)
    return common_queue, common_usm_type
239+
240+
241+
def _nonzero_impl(ary):
    """
    Compute per-dimension index arrays from ``ary`` using
    ``ti.mask_positions`` and ``ti._nonzero``.

    Args:
        ary: input array; must be a ``dpt.usm_ndarray``.

    Returns:
        tuple: ``ary.ndim`` arrays, each being one row of an index array
        of shape ``(ary.ndim, mask_count)``, where ``mask_count`` is the
        value returned by ``ti.mask_positions``.

    Raises:
        TypeError: if ``ary`` is not a ``usm_ndarray``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    queue = ary.sycl_queue
    n_elems = ary.size
    # Use the narrower integer type for the cumulative sum buffer whenever
    # the element count is below int32_t_max.
    if n_elems < int32_t_max:
        running_dt = dpt.int32
    else:
        running_dt = dpt.int64
    running_total = dpt.empty(
        n_elems, dtype=running_dt, sycl_queue=queue, order="C"
    )
    order_manager = dpctl.utils.SequentialOrderManager[queue]
    pending = order_manager.submitted_events
    n_selected = ti.mask_positions(
        ary, running_total, sycl_queue=queue, depends=pending
    )
    coords = dpt.empty(
        (ary.ndim, n_selected),
        dtype=ti.default_device_index_type(queue.sycl_device),
        usm_type=ary.usm_type,
        sycl_queue=queue,
        order="C",
    )
    host_ev, nonzero_ev = ti._nonzero(running_total, coords, ary.shape, queue)
    # One row of `coords` per dimension of `ary`.
    per_axis = tuple(coords[dim, :] for dim in range(ary.ndim))
    order_manager.add_event_pair(host_ev, nonzero_ev)
    return per_axis
270+
271+
272+
def _prepare_indices_arrays(inds, q, usm_type):
    """
    Turn a mix of index arrays and scalar indices into arrays of one
    common integral dtype and one common shape.

    Every element of `inds` that is not already a usm_ndarray is converted
    with `dpt.asarray` on queue `q` with USM type `usm_type`. The arrays
    are then cast to their common result type and broadcast against each
    other.

    Returns:
        The result of `dpt.broadcast_arrays` applied to the promoted
        index arrays.

    Raises:
        ValueError: if the common result type is not a signed or unsigned
            integer type.
    """
    # anything that is not a usm_ndarray -> usm_ndarray
    as_arrays = tuple(
        (
            item
            if isinstance(item, dpt.usm_ndarray)
            else dpt.asarray(item, usm_type=usm_type, sycl_queue=q)
        )
        for item in inds
    )

    # common integral type, if one exists
    common_dt = dpt.result_type(*as_arrays)
    if common_dt.kind not in "ui":
        raise ValueError(
            "cannot safely promote indices to an integer data type"
        )
    promoted = tuple(
        arr if arr.dtype == common_dt else dpt.astype(arr, common_dt)
        for arr in as_arrays
    )

    # common shape
    return dpt.broadcast_arrays(*promoted)
312+
313+
314+
def _put_multi_index(ary, inds, p, vals, mode=0):
    """
    Write `vals` into `ary` at positions given by the index arrays `inds`,
    which address consecutive dimensions of `ary` starting at `p`.

    Args:
        ary: destination array; must be a ``dpt.usm_ndarray``. It is
            passed to ``ti._put`` as ``dst``.
        inds: a single index or a list/tuple of indices, validated by
            ``_get_indices_queue_usm_type``.
        p: first indexed dimension of `ary`; normalized with
            ``normalize_axis_index``.
        vals: values to write. Anything that is not a usm_ndarray is
            converted with ``dpt.asarray`` to the dtype of `ary`. The
            values are broadcast to
            ``ary.shape[:p] + index_shape + ary.shape[p + len(inds):]``.
        mode: integer 0 or 1, forwarded to ``ti._put``.

    Returns:
        None.

    Raises:
        TypeError: if `ary` is not a ``usm_ndarray``.
        ValueError: if `mode` is not 0 or 1.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined.
        IndexError: if an indexed dimension has length zero while the
            index arrays are non-empty.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    p = normalize_axis_index(operator.index(p), ary.ndim)
    mode = operator.index(mode)
    if mode not in (0, 1):
        raise ValueError(
            "Invalid value for mode keyword, only 0 or 1 is supported"
        )
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)

    exec_q, coerced_usm_type = _get_indices_queue_usm_type(
        inds, ary.sycl_queue, ary.usm_type
    )

    if exec_q is not None:
        if isinstance(vals, dpt.usm_ndarray):
            # Array values take part in queue and USM type resolution.
            exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
            coerced_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    coerced_usm_type,
                    vals.usm_type,
                )
            )
        else:
            vals = dpt.asarray(
                vals,
                dtype=ary.dtype,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)

    first_ind = inds[0]
    dst_shape = ary.shape
    p_stop = p + len(inds)
    if first_ind.size != 0 and 0 in dst_shape[p:p_stop]:
        raise IndexError(
            "cannot put into non-empty indices along an empty axis"
        )
    vals_shape = dst_shape[:p] + first_ind.shape + dst_shape[p_stop:]
    # The cast (if any) happens before the pending events are collected.
    rhs = vals if vals.dtype == ary.dtype else dpt_ext.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, vals_shape)
    order_manager = dpctl.utils.SequentialOrderManager[exec_q]
    pending = order_manager.submitted_events
    host_ev, put_ev = ti._put(
        dst=ary,
        ind=inds,
        val=rhs,
        axis_start=p,
        mode=mode,
        sycl_queue=exec_q,
        depends=pending,
    )
    order_manager.add_event_pair(host_ev, put_ev)
385+
386+
387+
def _take_multi_index(ary, inds, p, mode=0):
    """
    Gather elements of `ary` at positions given by the index arrays `inds`,
    which address consecutive dimensions of `ary` starting at `p`.

    Args:
        ary: source array; must be a ``dpt.usm_ndarray``.
        inds: a single index or a list/tuple of indices, validated by
            ``_get_indices_queue_usm_type``.
        p: first indexed dimension of `ary`; normalized with
            ``normalize_axis_index``.
        mode: integer 0 or 1, forwarded to ``ti._take``.

    Returns:
        A newly allocated array with the dtype of `ary` and shape
        ``ary.shape[:p] + index_shape + ary.shape[p + len(inds):]``.

    Raises:
        TypeError: if `ary` is not a ``usm_ndarray``.
        ValueError: if `mode` is not 0 or 1.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined.
        IndexError: if an indexed dimension has length zero while the
            index arrays are non-empty.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    p = normalize_axis_index(operator.index(p), ary.ndim)
    mode = operator.index(mode)
    if mode not in (0, 1):
        raise ValueError(
            "Invalid value for mode keyword, only 0 or 1 is supported"
        )
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)

    exec_q, res_usm_type = _get_indices_queue_usm_type(
        inds, ary.sycl_queue, ary.usm_type
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)

    first_ind = inds[0]
    src_shape = ary.shape
    p_stop = p + len(inds)
    if first_ind.size != 0 and 0 in src_shape[p:p_stop]:
        raise IndexError("cannot take non-empty indices from an empty axis")
    out_shape = src_shape[:p] + first_ind.shape + src_shape[p_stop:]
    out = dpt.empty(
        out_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )
    order_manager = dpctl.utils.SequentialOrderManager[exec_q]
    pending = order_manager.submitted_events
    host_ev, take_ev = ti._take(
        src=ary,
        ind=inds,
        dst=out,
        axis_start=p,
        mode=mode,
        sycl_queue=exec_q,
        depends=pending,
    )
    order_manager.add_event_pair(host_ev, take_ev)
    return out
437+
438+
133439
def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None):
134440
"""
135441
from_numpy(arg, device=None, usm_type="device", sycl_queue=None)

0 commit comments

Comments
 (0)