Skip to content

Commit a1d6fa3

Browse files
Move tril()/triu() to dpctl_ext/tensor
1 parent a030579 commit a1d6fa3

7 files changed

Lines changed: 638 additions & 24 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ set(_tensor_impl_sources
5454
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
5555
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
5656
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
57-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
57+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
5858
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
6060
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from dpctl_ext.tensor._ctors import (
3131
full,
32+
tril,
33+
triu,
3234
)
3335
from dpctl_ext.tensor._indexing_functions import (
3436
put,
@@ -39,4 +41,6 @@
3941
"full",
4042
"put",
4143
"take",
44+
"tril",
45+
"triu",
4246
]

dpctl_ext/tensor/_ctors.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
2828

29+
import operator
2930
from numbers import Number
3031

3132
import dpctl
@@ -167,3 +168,159 @@ def full(
167168
hev, full_ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
168169
_manager.add_event_pair(hev, full_ev)
169170
return res
171+
172+
173+
def tril(x, /, *, k=0):
174+
"""
175+
Returns the lower triangular part of a matrix (or a stack of matrices)
176+
``x``.
177+
178+
The lower triangular part of the matrix is defined as the elements on and
179+
below the specified diagonal ``k``.
180+
181+
Args:
182+
x (usm_ndarray):
183+
Input array
184+
k (int, optional):
185+
Specifies the diagonal above which to set
186+
elements to zero. If ``k = 0``, the diagonal is the main diagonal.
187+
If ``k < 0``, the diagonal is below the main diagonal.
188+
If ``k > 0``, the diagonal is above the main diagonal.
189+
Default: ``0``
190+
191+
Returns:
192+
usm_ndarray:
193+
A lower-triangular array or a stack of lower-triangular arrays.
194+
"""
195+
if not isinstance(x, dpt.usm_ndarray):
196+
raise TypeError(
197+
"Expected argument of type dpctl.tensor.usm_ndarray, "
198+
f"got {type(x)}."
199+
)
200+
201+
k = operator.index(k)
202+
203+
order = "F" if (x.flags.f_contiguous) else "C"
204+
205+
shape = x.shape
206+
nd = x.ndim
207+
if nd < 2:
208+
raise ValueError("Array dimensions less than 2.")
209+
210+
q = x.sycl_queue
211+
if k >= shape[nd - 1] - 1:
212+
res = dpt.empty(
213+
x.shape,
214+
dtype=x.dtype,
215+
order=order,
216+
usm_type=x.usm_type,
217+
sycl_queue=q,
218+
)
219+
_manager = dpctl.utils.SequentialOrderManager[q]
220+
dep_evs = _manager.submitted_events
221+
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
222+
src=x, dst=res, sycl_queue=q, depends=dep_evs
223+
)
224+
_manager.add_event_pair(hev, cpy_ev)
225+
elif k < -shape[nd - 2]:
226+
res = dpt.zeros(
227+
x.shape,
228+
dtype=x.dtype,
229+
order=order,
230+
usm_type=x.usm_type,
231+
sycl_queue=q,
232+
)
233+
else:
234+
res = dpt.empty(
235+
x.shape,
236+
dtype=x.dtype,
237+
order=order,
238+
usm_type=x.usm_type,
239+
sycl_queue=q,
240+
)
241+
_manager = dpctl.utils.SequentialOrderManager[q]
242+
dep_evs = _manager.submitted_events
243+
hev, tril_ev = ti._tril(
244+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
245+
)
246+
_manager.add_event_pair(hev, tril_ev)
247+
248+
return res
249+
250+
251+
def triu(x, /, *, k=0):
252+
"""
253+
Returns the upper triangular part of a matrix (or a stack of matrices)
254+
``x``.
255+
256+
The upper triangular part of the matrix is defined as the elements on and
257+
above the specified diagonal ``k``.
258+
259+
Args:
260+
x (usm_ndarray):
261+
Input array
262+
k (int, optional):
263+
Specifies the diagonal below which to set
264+
elements to zero. If ``k = 0``, the diagonal is the main diagonal.
265+
If ``k < 0``, the diagonal is below the main diagonal.
266+
If ``k > 0``, the diagonal is above the main diagonal.
267+
Default: ``0``
268+
269+
Returns:
270+
usm_ndarray:
271+
An upper-triangular array or a stack of upper-triangular arrays.
272+
"""
273+
if not isinstance(x, dpt.usm_ndarray):
274+
raise TypeError(
275+
"Expected argument of type dpctl.tensor.usm_ndarray, "
276+
f"got {type(x)}."
277+
)
278+
279+
k = operator.index(k)
280+
281+
order = "F" if (x.flags.f_contiguous) else "C"
282+
283+
shape = x.shape
284+
nd = x.ndim
285+
if nd < 2:
286+
raise ValueError("Array dimensions less than 2.")
287+
288+
q = x.sycl_queue
289+
if k > shape[nd - 1]:
290+
res = dpt.zeros(
291+
x.shape,
292+
dtype=x.dtype,
293+
order=order,
294+
usm_type=x.usm_type,
295+
sycl_queue=q,
296+
)
297+
elif k <= -shape[nd - 2] + 1:
298+
res = dpt.empty(
299+
x.shape,
300+
dtype=x.dtype,
301+
order=order,
302+
usm_type=x.usm_type,
303+
sycl_queue=q,
304+
)
305+
_manager = dpctl.utils.SequentialOrderManager[q]
306+
dep_evs = _manager.submitted_events
307+
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
308+
src=x, dst=res, sycl_queue=q, depends=dep_evs
309+
)
310+
_manager.add_event_pair(hev, cpy_ev)
311+
else:
312+
res = dpt.empty(
313+
x.shape,
314+
dtype=x.dtype,
315+
order=order,
316+
usm_type=x.usm_type,
317+
sycl_queue=q,
318+
)
319+
_manager = dpctl.utils.SequentialOrderManager[q]
320+
dep_evs = _manager.submitted_events
321+
hev, triu_ev = ti._triu(
322+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
323+
)
324+
_manager.add_event_pair(hev, triu_ev)
325+
326+
return res

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,144 @@ sycl::event full_strided_impl(sycl::queue &q,
343343
return fill_ev;
344344
}
345345

346+
/* =========================== Tril and triu ============================== */
347+
348+
// define function type
349+
typedef sycl::event (*tri_fn_ptr_t)(sycl::queue &,
350+
ssize_t, // inner_range //ssize_t
351+
ssize_t, // outer_range
352+
char *, // src_data_ptr
353+
char *, // dst_data_ptr
354+
ssize_t, // nd
355+
ssize_t *, // shape_and_strides
356+
ssize_t, // k
357+
const std::vector<sycl::event> &,
358+
const std::vector<sycl::event> &);
359+
360+
/*!
361+
* @brief Function to copy triangular matrices from source stack to destination
362+
* stack.
363+
*
364+
* @param exec_q Sycl queue to which kernel is submitted for execution.
365+
* @param inner_range Number of elements in each matrix.
366+
* @param outer_range Number of matrices to copy.
367+
* @param src_p Kernel accessible USM pointer for the source array.
368+
* @param dst_p Kernel accessible USM pointer for the destination array.
369+
* @param nd The array dimensionality of source and destination arrays.
370+
* @param shape_and_strides Kernel accessible USM pointer to packed shape and
371+
* strides of arrays.
372+
* @param k Position of the diagonal above/below which to copy filling the rest
373+
* with zero elements.
374+
* @param depends List of events to wait for before starting computations, if
375+
* any.
376+
* @param additional_depends List of additional events to wait for before
377+
* starting computations, if any.
378+
*
379+
* @return Event to wait on to ensure that computation completes.
380+
* @defgroup CtorKernels
381+
*/
382+
template <typename Ty, bool>
383+
class tri_kernel;
384+
template <typename Ty, bool upper>
385+
sycl::event tri_impl(sycl::queue &exec_q,
386+
ssize_t inner_range,
387+
ssize_t outer_range,
388+
char *src_p,
389+
char *dst_p,
390+
ssize_t nd,
391+
ssize_t *shape_and_strides,
392+
ssize_t k,
393+
const std::vector<sycl::event> &depends,
394+
const std::vector<sycl::event> &additional_depends)
395+
{
396+
static constexpr int d2 = 2;
397+
ssize_t src_s = nd;
398+
ssize_t dst_s = 2 * nd;
399+
ssize_t nd_1 = nd - 1;
400+
ssize_t nd_2 = nd - 2;
401+
Ty *src = reinterpret_cast<Ty *>(src_p);
402+
Ty *dst = reinterpret_cast<Ty *>(dst_p);
403+
404+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
405+
406+
sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) {
407+
cgh.depends_on(depends);
408+
cgh.depends_on(additional_depends);
409+
410+
cgh.parallel_for<tri_kernel<Ty, upper>>(
411+
sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) {
412+
ssize_t outer_gid = idx[0] / inner_range;
413+
ssize_t inner_gid = idx[0] - inner_range * outer_gid;
414+
415+
ssize_t src_inner_offset = 0, dst_inner_offset = 0;
416+
bool to_copy{false};
417+
418+
{
419+
using dpctl::tensor::strides::CIndexer_array;
420+
CIndexer_array<d2, ssize_t> indexer_i(
421+
{shape_and_strides[nd_2], shape_and_strides[nd_1]});
422+
indexer_i.set(inner_gid);
423+
const std::array<ssize_t, d2> &inner = indexer_i.get();
424+
src_inner_offset =
425+
inner[0] * shape_and_strides[src_s + nd_2] +
426+
inner[1] * shape_and_strides[src_s + nd_1];
427+
dst_inner_offset =
428+
inner[0] * shape_and_strides[dst_s + nd_2] +
429+
inner[1] * shape_and_strides[dst_s + nd_1];
430+
431+
if constexpr (upper)
432+
to_copy = (inner[0] + k >= inner[1]);
433+
else
434+
to_copy = (inner[0] + k <= inner[1]);
435+
}
436+
437+
ssize_t src_offset = 0;
438+
ssize_t dst_offset = 0;
439+
{
440+
using dpctl::tensor::strides::CIndexer_vector;
441+
CIndexer_vector<ssize_t> outer(nd - d2);
442+
outer.get_displacement(
443+
outer_gid, shape_and_strides, shape_and_strides + src_s,
444+
shape_and_strides + dst_s, src_offset, dst_offset);
445+
}
446+
447+
src_offset += src_inner_offset;
448+
dst_offset += dst_inner_offset;
449+
450+
dst[dst_offset] = (to_copy) ? src[src_offset] : Ty(0);
451+
});
452+
});
453+
return tri_ev;
454+
}
455+
456+
/*!
457+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
458+
* @ingroup CtorKernels
459+
*/
460+
template <typename fnT, typename Ty>
461+
struct TrilGenericFactory
462+
{
463+
fnT get()
464+
{
465+
fnT f = tri_impl<Ty, /*tril*/ true>;
466+
return f;
467+
}
468+
};
469+
470+
/*!
471+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
472+
* @ingroup CtorKernels
473+
*/
474+
template <typename fnT, typename Ty>
475+
struct TriuGenericFactory
476+
{
477+
fnT get()
478+
{
479+
fnT f = tri_impl<Ty, /*triu*/ false>;
480+
return f;
481+
}
482+
};
483+
346484
} // namespace constructors
347485
} // namespace kernels
348486
} // namespace tensor

0 commit comments

Comments
 (0)