Skip to content

Commit a1eea4e

Browse files
Move eye() to dpctl_ext/tensor and reuse it in dpnp
1 parent b60d095 commit a1eea4e

File tree

3 files changed

+131
-1
lines changed

3 files changed

+131
-1
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
to_numpy,
3636
)
3737
from dpctl_ext.tensor._ctors import (
38+
eye,
3839
full,
3940
tril,
4041
triu,
@@ -58,6 +59,7 @@
5859
"astype",
5960
"copy",
6061
"extract",
62+
"eye",
6163
"from_numpy",
6264
"full",
6365
"nonzero",

dpctl_ext/tensor/_ctors.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,38 @@ def _cast_fill_val(fill_val, dt):
5858
return fill_val
5959

6060

61+
def _ensure_native_dtype_device_support(dtype, dev) -> None:
62+
"""Check that dtype is natively supported by device.
63+
64+
Arg:
65+
dtype:
66+
Elemental data-type
67+
dev (:class:`dpctl.SyclDevice`):
68+
The device about which the query is being made.
69+
Returns:
70+
None
71+
Raise:
72+
ValueError:
73+
if device does not natively support this `dtype`.
74+
"""
75+
if dtype in [dpt.float64, dpt.complex128] and not dev.has_aspect_fp64:
76+
raise ValueError(
77+
f"Device {dev.name} does not provide native support "
78+
"for double-precision floating point type."
79+
)
80+
if (
81+
dtype
82+
in [
83+
dpt.float16,
84+
]
85+
and not dev.has_aspect_fp16
86+
):
87+
raise ValueError(
88+
f"Device {dev.name} does not provide native support "
89+
"for half-precision floating point type."
90+
)
91+
92+
6193
def _to_scalar(obj, sc_ty):
6294
"""A way to convert object to NumPy scalar type.
6395
Raises OverflowError if obj can not be represented
@@ -67,6 +99,102 @@ def _to_scalar(obj, sc_ty):
6799
return zd_arr[()]
68100

69101

102+
def eye(
103+
n_rows,
104+
n_cols=None,
105+
/,
106+
*,
107+
k=0,
108+
dtype=None,
109+
order="C",
110+
device=None,
111+
usm_type="device",
112+
sycl_queue=None,
113+
):
114+
"""
115+
eye(n_rows, n_cols=None, /, *, k=0, dtype=None, \
116+
device=None, usm_type="device", sycl_queue=None)
117+
118+
Creates :class:`dpctl.tensor.usm_ndarray` with ones on the `k`-th
119+
diagonal.
120+
121+
Args:
122+
n_rows (int):
123+
number of rows in the output array.
124+
n_cols (int, optional):
125+
number of columns in the output array. If ``None``,
126+
``n_cols = n_rows``. Default: ``None``
127+
k (int):
128+
index of the diagonal, with ``0`` as the main diagonal.
129+
A positive value of ``k`` is a superdiagonal, a negative value
130+
is a subdiagonal.
131+
Raises :exc:`TypeError` if ``k`` is not an integer.
132+
Default: ``0``
133+
dtype (optional):
134+
data type of the array. Can be typestring,
135+
a :class:`numpy.dtype` object, :mod:`numpy` char string, or
136+
a NumPy scalar type. Default: ``None``
137+
order ("C" or "F"):
138+
memory layout for the array. Default: ``"C"``
139+
device (optional):
140+
array API concept of device where the output array
141+
is created. ``device`` can be ``None``, a oneAPI filter selector
142+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
143+
a non-partitioned SYCL device, an instance of
144+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
145+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
146+
Default: ``None``
147+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
148+
The type of SYCL USM allocation for the output array.
149+
Default: ``"device"``
150+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
151+
The SYCL queue to use
152+
for output array allocation and copying. ``sycl_queue`` and
153+
``device`` are complementary arguments, i.e. use one or another.
154+
If both are specified, a :exc:`TypeError` is raised unless both
155+
imply the same underlying SYCL queue to be used. If both are
156+
``None``, a cached queue targeting default-selected device is
157+
used for allocation and population. Default: ``None``
158+
159+
Returns:
160+
usm_ndarray:
161+
A diagonal matrix.
162+
"""
163+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
164+
raise ValueError(
165+
"Unrecognized order keyword value, expecting 'F' or 'C'."
166+
)
167+
order = order[0].upper()
168+
n_rows = operator.index(n_rows)
169+
n_cols = n_rows if n_cols is None else operator.index(n_cols)
170+
k = operator.index(k)
171+
if k >= n_cols or -k >= n_rows:
172+
return dpt.zeros(
173+
(n_rows, n_cols),
174+
dtype=dtype,
175+
order=order,
176+
device=device,
177+
usm_type=usm_type,
178+
sycl_queue=sycl_queue,
179+
)
180+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
181+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
182+
dtype = _get_dtype(dtype, sycl_queue)
183+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
184+
res = dpt.usm_ndarray(
185+
(n_rows, n_cols),
186+
dtype=dtype,
187+
buffer=usm_type,
188+
order=order,
189+
buffer_ctor_kwargs={"queue": sycl_queue},
190+
)
191+
if n_rows != 0 and n_cols != 0:
192+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
193+
hev, eye_ev = ti._eye(k, dst=res, sycl_queue=sycl_queue)
194+
_manager.add_event_pair(hev, eye_ev)
195+
return res
196+
197+
70198
def _validate_fill_value(fill_val):
71199
"""Validates that `fill_val` is a numeric or boolean scalar."""
72200
# TODO: verify if `np.True_` and `np.False_` should be instances of

dpnp/dpnp_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def eye(
196196
order = "C"
197197

198198
"""Creates `dpnp_array` with ones on the `k`th diagonal."""
199-
array_obj = dpt.eye(
199+
array_obj = dpt_ext.eye(
200200
N,
201201
M,
202202
k=k,

0 commit comments

Comments
 (0)