Skip to content

Commit b00d064

Browse files
Extend ._tensor_impl with remaining functions used by dpnp (#2758)
This PR extends `_tensor_impl` in `dpctl_ext.tensor` with the remaining functions that are explicitly used in `dpnp` implementations (`_take`, `_full_usm_ndarray`, `_zeros_usm_ndarray`, `_triu`) enabling a complete switch to `dpctl_ext.tensor._tensor_impl` instead of `dpctl.tensor._tensor_impl` It also adds `take()`, `put()`, `full()`,`tril()` and `triu()` to `dpctl_ext.tensor` and updates the corresponding dpnp functions to use these implementations internally
1 parent 1a1a099 commit b00d064

23 files changed

+3325
-90
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ set(_tensor_impl_sources
5252
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
5353
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
5454
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
55-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
55+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
5656
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
5757
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
58-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
59-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
60-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
58+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
59+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
60+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
6161
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
6262
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
6363
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,22 @@
2525
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
28+
29+
30+
from dpctl_ext.tensor._ctors import (
31+
full,
32+
tril,
33+
triu,
34+
)
35+
from dpctl_ext.tensor._indexing_functions import (
36+
put,
37+
take,
38+
)
39+
40+
__all__ = [
41+
"full",
42+
"put",
43+
"take",
44+
"tril",
45+
"triu",
46+
]

dpctl_ext/tensor/_ctors.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
import operator
30+
from numbers import Number
31+
32+
import dpctl
33+
import dpctl.tensor as dpt
34+
import dpctl.utils
35+
import numpy as np
36+
from dpctl.tensor._data_types import _get_dtype
37+
from dpctl.tensor._device import normalize_queue_device
38+
39+
import dpctl_ext.tensor._tensor_impl as ti
40+
41+
42+
def _cast_fill_val(fill_val, dt):
43+
"""
44+
Casts the Python scalar `fill_val` to another Python type coercible to the
45+
requested data type `dt`, if necessary.
46+
"""
47+
val_type = type(fill_val)
48+
if val_type in [float, complex] and np.issubdtype(dt, np.integer):
49+
return int(fill_val.real)
50+
elif val_type is complex and np.issubdtype(dt, np.floating):
51+
return fill_val.real
52+
elif val_type is int and np.issubdtype(dt, np.integer):
53+
return _to_scalar(fill_val, dt)
54+
else:
55+
return fill_val
56+
57+
58+
def _to_scalar(obj, sc_ty):
59+
"""A way to convert object to NumPy scalar type.
60+
Raises OverflowError if obj can not be represented
61+
using the requested scalar type.
62+
"""
63+
zd_arr = np.asarray(obj, dtype=sc_ty)
64+
return zd_arr[()]
65+
66+
67+
def _validate_fill_value(fill_val):
68+
"""Validates that `fill_val` is a numeric or boolean scalar."""
69+
# TODO: verify if `np.True_` and `np.False_` should be instances of
70+
# Number in NumPy, like other NumPy scalars and like Python bools
71+
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
72+
if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_):
73+
raise TypeError(
74+
f"array cannot be filled with scalar of type {type(fill_val)}"
75+
)
76+
77+
78+
def full(
79+
shape,
80+
fill_value,
81+
*,
82+
dtype=None,
83+
order="C",
84+
device=None,
85+
usm_type=None,
86+
sycl_queue=None,
87+
):
88+
"""
89+
Returns a new :class:`dpctl.tensor.usm_ndarray` having a specified
90+
shape and filled with `fill_value`.
91+
92+
Args:
93+
shape (tuple):
94+
Dimensions of the array to be created.
95+
fill_value (int,float,complex,usm_ndarray):
96+
fill value
97+
dtype (optional): data type of the array. Can be typestring,
98+
a :class:`numpy.dtype` object, :mod:`numpy` char string,
99+
or a NumPy scalar type. Default: ``None``
100+
order ("C", or "F"):
101+
memory layout for the array. Default: ``"C"``
102+
device (optional): array API concept of device where the output array
103+
is created. ``device`` can be ``None``, a oneAPI filter selector
104+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
105+
a non-partitioned SYCL device, an instance of
106+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
107+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
108+
Default: ``None``
109+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
110+
The type of SYCL USM allocation for the output array.
111+
Default: ``"device"``
112+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
113+
The SYCL queue to use
114+
for output array allocation and copying. ``sycl_queue`` and
115+
``device`` are complementary arguments, i.e. use one or another.
116+
If both are specified, a :exc:`TypeError` is raised unless both
117+
imply the same underlying SYCL queue to be used. If both are
118+
``None``, a cached queue targeting default-selected device is
119+
used for allocation and population. Default: ``None``
120+
121+
Returns:
122+
usm_ndarray:
123+
New array initialized with given value.
124+
"""
125+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
126+
raise ValueError(
127+
"Unrecognized order keyword value, expecting 'F' or 'C'."
128+
)
129+
order = order[0].upper()
130+
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
131+
132+
if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)):
133+
if (
134+
isinstance(fill_value, dpt.usm_ndarray)
135+
and sycl_queue is None
136+
and device is None
137+
):
138+
sycl_queue = fill_value.sycl_queue
139+
else:
140+
sycl_queue = normalize_queue_device(
141+
sycl_queue=sycl_queue, device=device
142+
)
143+
X = dpt.asarray(
144+
fill_value,
145+
dtype=dtype,
146+
order=order,
147+
usm_type=usm_type,
148+
sycl_queue=sycl_queue,
149+
)
150+
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
151+
else:
152+
_validate_fill_value(fill_value)
153+
154+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
155+
usm_type = usm_type if usm_type is not None else "device"
156+
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
157+
res = dpt.usm_ndarray(
158+
shape,
159+
dtype=dtype,
160+
buffer=usm_type,
161+
order=order,
162+
buffer_ctor_kwargs={"queue": sycl_queue},
163+
)
164+
fill_value = _cast_fill_val(fill_value, dtype)
165+
166+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
167+
# populating new allocation, no dependent events
168+
hev, full_ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
169+
_manager.add_event_pair(hev, full_ev)
170+
return res
171+
172+
173+
def tril(x, /, *, k=0):
174+
"""
175+
Returns the lower triangular part of a matrix (or a stack of matrices)
176+
``x``.
177+
178+
The lower triangular part of the matrix is defined as the elements on and
179+
below the specified diagonal ``k``.
180+
181+
Args:
182+
x (usm_ndarray):
183+
Input array
184+
k (int, optional):
185+
Specifies the diagonal above which to set
186+
elements to zero. If ``k = 0``, the diagonal is the main diagonal.
187+
If ``k < 0``, the diagonal is below the main diagonal.
188+
If ``k > 0``, the diagonal is above the main diagonal.
189+
Default: ``0``
190+
191+
Returns:
192+
usm_ndarray:
193+
A lower-triangular array or a stack of lower-triangular arrays.
194+
"""
195+
if not isinstance(x, dpt.usm_ndarray):
196+
raise TypeError(
197+
"Expected argument of type dpctl.tensor.usm_ndarray, "
198+
f"got {type(x)}."
199+
)
200+
201+
k = operator.index(k)
202+
203+
order = "F" if (x.flags.f_contiguous) else "C"
204+
205+
shape = x.shape
206+
nd = x.ndim
207+
if nd < 2:
208+
raise ValueError("Array dimensions less than 2.")
209+
210+
q = x.sycl_queue
211+
if k >= shape[nd - 1] - 1:
212+
res = dpt.empty(
213+
x.shape,
214+
dtype=x.dtype,
215+
order=order,
216+
usm_type=x.usm_type,
217+
sycl_queue=q,
218+
)
219+
_manager = dpctl.utils.SequentialOrderManager[q]
220+
dep_evs = _manager.submitted_events
221+
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
222+
src=x, dst=res, sycl_queue=q, depends=dep_evs
223+
)
224+
_manager.add_event_pair(hev, cpy_ev)
225+
elif k < -shape[nd - 2]:
226+
res = dpt.zeros(
227+
x.shape,
228+
dtype=x.dtype,
229+
order=order,
230+
usm_type=x.usm_type,
231+
sycl_queue=q,
232+
)
233+
else:
234+
res = dpt.empty(
235+
x.shape,
236+
dtype=x.dtype,
237+
order=order,
238+
usm_type=x.usm_type,
239+
sycl_queue=q,
240+
)
241+
_manager = dpctl.utils.SequentialOrderManager[q]
242+
dep_evs = _manager.submitted_events
243+
hev, tril_ev = ti._tril(
244+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
245+
)
246+
_manager.add_event_pair(hev, tril_ev)
247+
248+
return res
249+
250+
251+
def triu(x, /, *, k=0):
252+
"""
253+
Returns the upper triangular part of a matrix (or a stack of matrices)
254+
``x``.
255+
256+
The upper triangular part of the matrix is defined as the elements on and
257+
above the specified diagonal ``k``.
258+
259+
Args:
260+
x (usm_ndarray):
261+
Input array
262+
k (int, optional):
263+
Specifies the diagonal below which to set
264+
elements to zero. If ``k = 0``, the diagonal is the main diagonal.
265+
If ``k < 0``, the diagonal is below the main diagonal.
266+
If ``k > 0``, the diagonal is above the main diagonal.
267+
Default: ``0``
268+
269+
Returns:
270+
usm_ndarray:
271+
An upper-triangular array or a stack of upper-triangular arrays.
272+
"""
273+
if not isinstance(x, dpt.usm_ndarray):
274+
raise TypeError(
275+
"Expected argument of type dpctl.tensor.usm_ndarray, "
276+
f"got {type(x)}."
277+
)
278+
279+
k = operator.index(k)
280+
281+
order = "F" if (x.flags.f_contiguous) else "C"
282+
283+
shape = x.shape
284+
nd = x.ndim
285+
if nd < 2:
286+
raise ValueError("Array dimensions less than 2.")
287+
288+
q = x.sycl_queue
289+
if k > shape[nd - 1]:
290+
res = dpt.zeros(
291+
x.shape,
292+
dtype=x.dtype,
293+
order=order,
294+
usm_type=x.usm_type,
295+
sycl_queue=q,
296+
)
297+
elif k <= -shape[nd - 2] + 1:
298+
res = dpt.empty(
299+
x.shape,
300+
dtype=x.dtype,
301+
order=order,
302+
usm_type=x.usm_type,
303+
sycl_queue=q,
304+
)
305+
_manager = dpctl.utils.SequentialOrderManager[q]
306+
dep_evs = _manager.submitted_events
307+
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
308+
src=x, dst=res, sycl_queue=q, depends=dep_evs
309+
)
310+
_manager.add_event_pair(hev, cpy_ev)
311+
else:
312+
res = dpt.empty(
313+
x.shape,
314+
dtype=x.dtype,
315+
order=order,
316+
usm_type=x.usm_type,
317+
sycl_queue=q,
318+
)
319+
_manager = dpctl.utils.SequentialOrderManager[q]
320+
dep_evs = _manager.submitted_events
321+
hev, triu_ev = ti._triu(
322+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
323+
)
324+
_manager.add_event_pair(hev, triu_ev)
325+
326+
return res

0 commit comments

Comments
 (0)