@@ -58,6 +58,38 @@ def _cast_fill_val(fill_val, dt):
5858 return fill_val
5959
6060
61+ def _ensure_native_dtype_device_support (dtype , dev ) -> None :
62+ """Check that dtype is natively supported by device.
63+
64+ Arg:
65+ dtype:
66+ Elemental data-type
67+ dev (:class:`dpctl.SyclDevice`):
68+ The device about which the query is being made.
69+ Returns:
70+ None
71+ Raise:
72+ ValueError:
73+ if device does not natively support this `dtype`.
74+ """
75+ if dtype in [dpt .float64 , dpt .complex128 ] and not dev .has_aspect_fp64 :
76+ raise ValueError (
77+ f"Device { dev .name } does not provide native support "
78+ "for double-precision floating point type."
79+ )
80+ if (
81+ dtype
82+ in [
83+ dpt .float16 ,
84+ ]
85+ and not dev .has_aspect_fp16
86+ ):
87+ raise ValueError (
88+ f"Device { dev .name } does not provide native support "
89+ "for half-precision floating point type."
90+ )
91+
92+
6193def _to_scalar (obj , sc_ty ):
6294 """A way to convert object to NumPy scalar type.
6395 Raises OverflowError if obj can not be represented
@@ -67,6 +99,102 @@ def _to_scalar(obj, sc_ty):
6799 return zd_arr [()]
68100
69101
102+ def eye (
103+ n_rows ,
104+ n_cols = None ,
105+ / ,
106+ * ,
107+ k = 0 ,
108+ dtype = None ,
109+ order = "C" ,
110+ device = None ,
111+ usm_type = "device" ,
112+ sycl_queue = None ,
113+ ):
114+ """
115+ eye(n_rows, n_cols=None, /, *, k=0, dtype=None, \
116+ device=None, usm_type="device", sycl_queue=None)
117+
118+ Creates :class:`dpctl.tensor.usm_ndarray` with ones on the `k`-th
119+ diagonal.
120+
121+ Args:
122+ n_rows (int):
123+ number of rows in the output array.
124+ n_cols (int, optional):
125+ number of columns in the output array. If ``None``,
126+ ``n_cols = n_rows``. Default: ``None``
127+ k (int):
128+ index of the diagonal, with ``0`` as the main diagonal.
129+ A positive value of ``k`` is a superdiagonal, a negative value
130+ is a subdiagonal.
131+ Raises :exc:`TypeError` if ``k`` is not an integer.
132+ Default: ``0``
133+ dtype (optional):
134+ data type of the array. Can be typestring,
135+ a :class:`numpy.dtype` object, :mod:`numpy` char string, or
136+ a NumPy scalar type. Default: ``None``
137+ order ("C" or "F"):
138+ memory layout for the array. Default: ``"C"``
139+ device (optional):
140+ array API concept of device where the output array
141+ is created. ``device`` can be ``None``, a oneAPI filter selector
142+ string, an instance of :class:`dpctl.SyclDevice` corresponding to
143+ a non-partitioned SYCL device, an instance of
144+ :class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
145+ returned by :attr:`dpctl.tensor.usm_ndarray.device`.
146+ Default: ``None``
147+ usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
148+ The type of SYCL USM allocation for the output array.
149+ Default: ``"device"``
150+ sycl_queue (:class:`dpctl.SyclQueue`, optional):
151+ The SYCL queue to use
152+ for output array allocation and copying. ``sycl_queue`` and
153+ ``device`` are complementary arguments, i.e. use one or another.
154+ If both are specified, a :exc:`TypeError` is raised unless both
155+ imply the same underlying SYCL queue to be used. If both are
156+ ``None``, a cached queue targeting default-selected device is
157+ used for allocation and population. Default: ``None``
158+
159+ Returns:
160+ usm_ndarray:
161+ A diagonal matrix.
162+ """
163+ if not isinstance (order , str ) or len (order ) == 0 or order [0 ] not in "CcFf" :
164+ raise ValueError (
165+ "Unrecognized order keyword value, expecting 'F' or 'C'."
166+ )
167+ order = order [0 ].upper ()
168+ n_rows = operator .index (n_rows )
169+ n_cols = n_rows if n_cols is None else operator .index (n_cols )
170+ k = operator .index (k )
171+ if k >= n_cols or - k >= n_rows :
172+ return dpt .zeros (
173+ (n_rows , n_cols ),
174+ dtype = dtype ,
175+ order = order ,
176+ device = device ,
177+ usm_type = usm_type ,
178+ sycl_queue = sycl_queue ,
179+ )
180+ dpctl .utils .validate_usm_type (usm_type , allow_none = False )
181+ sycl_queue = normalize_queue_device (sycl_queue = sycl_queue , device = device )
182+ dtype = _get_dtype (dtype , sycl_queue )
183+ _ensure_native_dtype_device_support (dtype , sycl_queue .sycl_device )
184+ res = dpt .usm_ndarray (
185+ (n_rows , n_cols ),
186+ dtype = dtype ,
187+ buffer = usm_type ,
188+ order = order ,
189+ buffer_ctor_kwargs = {"queue" : sycl_queue },
190+ )
191+ if n_rows != 0 and n_cols != 0 :
192+ _manager = dpctl .utils .SequentialOrderManager [sycl_queue ]
193+ hev , eye_ev = ti ._eye (k , dst = res , sycl_queue = sycl_queue )
194+ _manager .add_event_pair (hev , eye_ev )
195+ return res
196+
197+
70198def _validate_fill_value (fill_val ):
71199 """Validates that `fill_val` is a numeric or boolean scalar."""
72200 # TODO: verify if `np.True_` and `np.False_` should be instances of
0 commit comments