Skip to content

Commit b60d095

Browse files
Move ti._eye() to dpctl_ext/tensor/libtensor
1 parent 6fecefe commit b60d095

File tree

5 files changed

+307
-14
lines changed

5 files changed

+307
-14
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ set(_tensor_impl_sources
5454
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
5555
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
5656
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
57-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
57+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
5858
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
6060
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ using dpctl::tensor::ssize_t;
5656

5757
template <typename Ty>
5858
class full_strided_kernel;
59-
// template <typename Ty> class eye_kernel;
59+
template <typename Ty>
60+
class eye_kernel;
6061

6162
using namespace dpctl::tensor::offset_utils;
6263

@@ -162,6 +163,99 @@ sycl::event full_strided_impl(sycl::queue &q,
162163
return fill_ev;
163164
}
164165

166+
/* ================ Eye ================== */
167+
168+
/*! @brief Signature of a type-specialized eye kernel launcher. */
using eye_fn_ptr_t =
    sycl::event (*)(sycl::queue &,                     // execution queue
                    std::size_t,                       // nelems
                    ssize_t,                           // start
                    ssize_t,                           // end
                    ssize_t,                           // step
                    char *,                            // dst_data_ptr
                    const std::vector<sycl::event> &); // depends
175+
176+
template <typename Ty>
177+
class EyeFunctor
178+
{
179+
private:
180+
Ty *p = nullptr;
181+
ssize_t start_v;
182+
ssize_t end_v;
183+
ssize_t step_v;
184+
185+
public:
186+
EyeFunctor(char *dst_p,
187+
const ssize_t v0,
188+
const ssize_t v1,
189+
const ssize_t dv)
190+
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), end_v(v1), step_v(dv)
191+
{
192+
}
193+
194+
void operator()(sycl::id<1> wiid) const
195+
{
196+
Ty set_v = 0;
197+
ssize_t i = static_cast<ssize_t>(wiid.get(0));
198+
if (i >= start_v and i <= end_v) {
199+
if ((i - start_v) % step_v == 0) {
200+
set_v = 1;
201+
}
202+
}
203+
p[i] = set_v;
204+
}
205+
};
206+
207+
/*!
208+
* @brief Function to populate 2D array with eye matrix.
209+
*
210+
* @param exec_q Sycl queue to which kernel is submitted for execution.
211+
* @param nelems Number of elements to assign.
212+
* @param start Position of the first non-zero value.
213+
* @param end Position of the last non-zero value.
214+
* @param step Number of array elements between non-zeros.
215+
* @param array_data Kernel accessible USM pointer for the destination array.
216+
* @param depends List of events to wait for before starting computations, if
217+
* any.
218+
*
219+
* @return Event to wait on to ensure that computation completes.
220+
* @defgroup CtorKernels
221+
*/
222+
template <typename Ty>
223+
sycl::event eye_impl(sycl::queue &exec_q,
224+
std::size_t nelems,
225+
const ssize_t start,
226+
const ssize_t end,
227+
const ssize_t step,
228+
char *array_data,
229+
const std::vector<sycl::event> &depends)
230+
{
231+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
232+
sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
233+
cgh.depends_on(depends);
234+
235+
using KernelName = eye_kernel<Ty>;
236+
using Impl = EyeFunctor<Ty>;
237+
238+
cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
239+
Impl(array_data, start, end, step));
240+
});
241+
242+
return eye_event;
243+
}
244+
245+
/*!
246+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
247+
* @ingroup CtorKernels
248+
*/
249+
template <typename fnT, typename Ty>
250+
struct EyeFactory
251+
{
252+
fnT get()
253+
{
254+
fnT f = eye_impl<Ty>;
255+
return f;
256+
}
257+
};
258+
165259
/* =========================== Tril and triu ============================== */
166260

167261
// define function type
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===--------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
33+
//===--------------------------------------------------------------------===//
34+
35+
#include <algorithm>
36+
#include <cstddef>
37+
#include <utility>
38+
#include <vector>
39+
40+
#include <sycl/sycl.hpp>
41+
42+
#include "dpnp4pybind11.hpp"
43+
#include <pybind11/pybind11.h>
44+
45+
#include "eye_ctor.hpp"
46+
#include "kernels/constructors.hpp"
47+
#include "utils/output_validation.hpp"
48+
#include "utils/type_dispatch.hpp"
49+
50+
namespace py = pybind11;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
namespace dpctl::tensor::py_internal
54+
{
55+
56+
using dpctl::utils::keep_args_alive;
57+
58+
using dpctl::tensor::kernels::constructors::eye_fn_ptr_t;
59+
static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types];
60+
61+
std::pair<sycl::event, sycl::event>
62+
usm_ndarray_eye(py::ssize_t k,
63+
const dpctl::tensor::usm_ndarray &dst,
64+
sycl::queue &exec_q,
65+
const std::vector<sycl::event> &depends)
66+
{
67+
// dst must be 2D
68+
69+
if (dst.get_ndim() != 2) {
70+
throw py::value_error(
71+
"usm_ndarray_eye: Expecting 2D array to populate");
72+
}
73+
74+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
75+
throw py::value_error("Execution queue is not compatible with the "
76+
"allocation queue");
77+
}
78+
79+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
80+
81+
auto array_types = td_ns::usm_ndarray_types();
82+
int dst_typenum = dst.get_typenum();
83+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
84+
85+
const py::ssize_t nelem = dst.get_size();
86+
const py::ssize_t rows = dst.get_shape(0);
87+
const py::ssize_t cols = dst.get_shape(1);
88+
if (rows == 0 || cols == 0) {
89+
// nothing to do
90+
return std::make_pair(sycl::event{}, sycl::event{});
91+
}
92+
93+
bool is_dst_c_contig = dst.is_c_contiguous();
94+
bool is_dst_f_contig = dst.is_f_contiguous();
95+
if (!is_dst_c_contig && !is_dst_f_contig) {
96+
throw py::value_error("USM array is not contiguous");
97+
}
98+
99+
py::ssize_t start;
100+
if (is_dst_c_contig) {
101+
start = (k < 0) ? -k * cols : k;
102+
}
103+
else {
104+
start = (k < 0) ? -k : k * rows;
105+
}
106+
107+
const py::ssize_t *strides = dst.get_strides_raw();
108+
py::ssize_t step;
109+
if (strides == nullptr) {
110+
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
111+
}
112+
else {
113+
step = strides[0] + strides[1];
114+
}
115+
116+
const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
117+
const py::ssize_t end = start + step * (length - 1);
118+
119+
char *dst_data = dst.get_data();
120+
sycl::event eye_event;
121+
122+
auto fn = eye_dispatch_vector[dst_typeid];
123+
124+
eye_event = fn(exec_q, static_cast<std::size_t>(nelem), start, end, step,
125+
dst_data, depends);
126+
127+
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
128+
eye_event);
129+
}
130+
131+
void init_eye_ctor_dispatch_vectors(void)
132+
{
133+
using namespace td_ns;
134+
using dpctl::tensor::kernels::constructors::EyeFactory;
135+
136+
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb;
137+
dvb.populate_dispatch_vector(eye_dispatch_vector);
138+
139+
return;
140+
}
141+
142+
} // namespace dpctl::tensor::py_internal
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===--------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
33+
//===--------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <utility>
37+
#include <vector>
38+
39+
#include <sycl/sycl.hpp>
40+
41+
#include "dpnp4pybind11.hpp"
42+
#include <pybind11/pybind11.h>
43+
44+
namespace py = pybind11;
45+
46+
namespace dpctl::tensor::py_internal
47+
{
48+
49+
extern std::pair<sycl::event, sycl::event>
50+
usm_ndarray_eye(py::ssize_t k,
51+
const dpctl::tensor::usm_ndarray &dst,
52+
sycl::queue &exec_q,
53+
const std::vector<sycl::event> &depends = {});
54+
55+
extern void init_eye_ctor_dispatch_vectors(void);
56+
57+
} // namespace dpctl::tensor::py_internal

dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
#include "copy_for_roll.hpp"
5353
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
5454
#include "device_support_queries.hpp"
55-
// #include "eye_ctor.hpp"
55+
#include "eye_ctor.hpp"
5656
#include "full_ctor.hpp"
5757
#include "integer_advanced_indexing.hpp"
5858
#include "kernels/dpctl_tensor_types.hpp"
@@ -124,7 +124,7 @@ using dpctl::tensor::py_internal::py_cumsum_1d;
124124

125125
/* ================ Eye ================== */
126126

127-
// using dpctl::tensor::py_internal::usm_ndarray_eye;
127+
using dpctl::tensor::py_internal::usm_ndarray_eye;
128128

129129
/* =========================== Tril and triu ============================== */
130130

@@ -160,7 +160,7 @@ void init_dispatch_vectors(void)
160160
// init_linear_sequences_dispatch_vectors();
161161
init_full_ctor_dispatch_vectors();
162162
init_zeros_ctor_dispatch_vectors();
163-
// init_eye_ctor_dispatch_vectors();
163+
init_eye_ctor_dispatch_vectors();
164164
init_triul_ctor_dispatch_vectors();
165165

166166
populate_masked_extract_dispatch_vectors();
@@ -348,15 +348,15 @@ PYBIND11_MODULE(_tensor_impl, m)
348348
py::arg("mode"), py::arg("sycl_queue"),
349349
py::arg("depends") = py::list());
350350

351-
// m.def("_eye", &usm_ndarray_eye,
352-
// "Fills input 2D contiguous usm_ndarray `dst` with "
353-
// "zeros outside of the diagonal "
354-
// "specified by "
355-
// "the diagonal index `k` "
356-
// "which is filled with ones."
357-
// "Returns a tuple of events: (ht_event, comp_event)",
358-
// py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
359-
// py::arg("depends") = py::list());
351+
m.def("_eye", &usm_ndarray_eye,
352+
"Fills input 2D contiguous usm_ndarray `dst` with "
353+
"zeros outside of the diagonal "
354+
"specified by "
355+
"the diagonal index `k` "
356+
"which is filled with ones."
357+
"Returns a tuple of events: (ht_event, comp_event)",
358+
py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
359+
py::arg("depends") = py::list());
360360

361361
m.def("default_device_fp_type",
362362
dpctl::tensor::py_internal::default_device_fp_type,

0 commit comments

Comments
 (0)