Skip to content

Commit 5a9c14c

Browse files
Add copy_usm_ndarray_into_usm_ndarray implementation
1 parent dcc421b commit 5a9c14c

2 files changed

Lines changed: 370 additions & 0 deletions

File tree

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
33+
//===----------------------------------------------------------------------===//
34+
35+
#include <algorithm>
36+
#include <complex>
37+
#include <cstddef>
38+
#include <cstdint>
39+
#include <stdexcept>
40+
#include <sycl/sycl.hpp>
41+
#include <thread>
42+
#include <type_traits>
43+
#include <utility>
44+
45+
#include "dpnp4pybind11.hpp"
46+
#include <pybind11/complex.h>
47+
#include <pybind11/numpy.h>
48+
#include <pybind11/pybind11.h>
49+
#include <pybind11/stl.h>
50+
51+
#include "kernels/copy_and_cast.hpp"
52+
#include "utils/memory_overlap.hpp"
53+
#include "utils/offset_utils.hpp"
54+
#include "utils/output_validation.hpp"
55+
#include "utils/sycl_alloc_utils.hpp"
56+
#include "utils/type_dispatch.hpp"
57+
#include "utils/type_utils.hpp"
58+
59+
#include "copy_as_contig.hpp"
60+
#include "simplify_iteration_space.hpp"
61+
62+
namespace dpctl
63+
{
64+
namespace tensor
65+
{
66+
namespace py_internal
67+
{
68+
69+
namespace td_ns = dpctl::tensor::type_dispatch;
70+
71+
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t;
72+
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_contig_fn_ptr_t;
73+
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t;
74+
75+
// File-local tables of copy-and-cast kernel entry points. Each table is
// read as table[dst_type_id][src_type_id] by
// copy_usm_ndarray_into_usm_ndarray and is filled in by
// init_copy_and_cast_usm_to_usm_dispatch_tables.
//
// Table for arbitrary strided layouts (shape and strides packed on device).
static copy_and_cast_generic_fn_ptr_t
76+
copy_and_cast_generic_dispatch_table[td_ns::num_types][td_ns::num_types];
77+
// Table for iteration spaces that simplify to a single strided dimension.
static copy_and_cast_1d_fn_ptr_t
78+
copy_and_cast_1d_dispatch_table[td_ns::num_types][td_ns::num_types];
79+
// Table for the case where both arrays are contiguous.
static copy_and_cast_contig_fn_ptr_t
80+
copy_and_cast_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
81+
82+
namespace py = pybind11;
83+
84+
using dpctl::utils::keep_args_alive;
85+
86+
std::pair<sycl::event, sycl::event> copy_usm_ndarray_into_usm_ndarray(
87+
const dpctl::tensor::usm_ndarray &src,
88+
const dpctl::tensor::usm_ndarray &dst,
89+
sycl::queue &exec_q,
90+
const std::vector<sycl::event> &depends = {})
91+
{
92+
// array dimensions must be the same
93+
int src_nd = src.get_ndim();
94+
int dst_nd = dst.get_ndim();
95+
96+
if (src_nd != dst_nd) {
97+
throw py::value_error("Array dimensions are not the same.");
98+
}
99+
100+
// shapes must be the same
101+
const py::ssize_t *src_shape = src.get_shape_raw();
102+
const py::ssize_t *dst_shape = dst.get_shape_raw();
103+
104+
bool shapes_equal(true);
105+
std::size_t src_nelems(1);
106+
107+
for (int i = 0; shapes_equal && (i < src_nd); ++i) {
108+
src_nelems *= static_cast<std::size_t>(src_shape[i]);
109+
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
110+
}
111+
if (!shapes_equal) {
112+
throw py::value_error("Array shapes are not the same.");
113+
}
114+
115+
if (src_nelems == 0) {
116+
// nothing to do
117+
return std::make_pair(sycl::event(), sycl::event());
118+
}
119+
120+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
121+
122+
// check compatibility of execution queue and allocation queue
123+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
124+
throw py::value_error(
125+
"Execution queue is not compatible with allocation queues");
126+
}
127+
128+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
129+
130+
int src_typenum = src.get_typenum();
131+
int dst_typenum = dst.get_typenum();
132+
133+
auto array_types = td_ns::usm_ndarray_types();
134+
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
135+
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
136+
137+
char *src_data = src.get_data();
138+
char *dst_data = dst.get_data();
139+
140+
// check that arrays do not overlap, and concurrent copying is safe.
141+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
142+
if (overlap(src, dst)) {
143+
// TODO: could use a temporary, but this is done by the caller
144+
throw py::value_error("Arrays index overlapping segments of memory");
145+
}
146+
147+
bool is_src_c_contig = src.is_c_contiguous();
148+
bool is_src_f_contig = src.is_f_contiguous();
149+
150+
bool is_dst_c_contig = dst.is_c_contiguous();
151+
bool is_dst_f_contig = dst.is_f_contiguous();
152+
153+
// check for applicability of special cases:
154+
// (both C-contiguous || both F-contiguous)
155+
bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
156+
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
157+
if (both_c_contig || both_f_contig) {
158+
159+
sycl::event copy_ev;
160+
if (src_type_id == dst_type_id) {
161+
162+
int src_elem_size = src.get_elemsize();
163+
164+
copy_ev = exec_q.memcpy(static_cast<void *>(dst_data),
165+
static_cast<const void *>(src_data),
166+
src_nelems * src_elem_size, depends);
167+
}
168+
else {
169+
auto contig_fn =
170+
copy_and_cast_contig_dispatch_table[dst_type_id][src_type_id];
171+
copy_ev =
172+
contig_fn(exec_q, src_nelems, src_data, dst_data, depends);
173+
}
174+
// make sure src and dst are not GC-ed before copy_ev is complete
175+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {copy_ev}),
176+
copy_ev);
177+
}
178+
179+
if ((src_type_id == dst_type_id) && (src_nd > 1)) {
180+
if (is_dst_c_contig) {
181+
return py_as_c_contig(src, dst, exec_q, depends);
182+
}
183+
else if (is_dst_f_contig) {
184+
return py_as_f_contig(src, dst, exec_q, depends);
185+
}
186+
}
187+
188+
auto const &src_strides = src.get_strides_vector();
189+
auto const &dst_strides = dst.get_strides_vector();
190+
191+
using shT = std::vector<py::ssize_t>;
192+
shT simplified_shape;
193+
shT simplified_src_strides;
194+
shT simplified_dst_strides;
195+
py::ssize_t src_offset(0);
196+
py::ssize_t dst_offset(0);
197+
198+
int nd = src_nd;
199+
const py::ssize_t *shape = src_shape;
200+
201+
// nd, simplified_* and *_offset are modified by reference
202+
dpctl::tensor::py_internal::simplify_iteration_space(
203+
nd, shape, src_strides, dst_strides,
204+
// output
205+
simplified_shape, simplified_src_strides, simplified_dst_strides,
206+
src_offset, dst_offset);
207+
208+
if (nd < 2) {
209+
if (nd == 1) {
210+
std::array<py::ssize_t, 1> shape_arr = {simplified_shape[0]};
211+
std::array<py::ssize_t, 1> src_strides_arr = {
212+
simplified_src_strides[0]};
213+
std::array<py::ssize_t, 1> dst_strides_arr = {
214+
simplified_dst_strides[0]};
215+
216+
sycl::event copy_and_cast_1d_event;
217+
if ((src_strides_arr[0] == 1) && (dst_strides_arr[0] == 1) &&
218+
(src_offset == 0) && (dst_offset == 0))
219+
{
220+
auto contig_fn =
221+
copy_and_cast_contig_dispatch_table[dst_type_id]
222+
[src_type_id];
223+
copy_and_cast_1d_event =
224+
contig_fn(exec_q, src_nelems, src_data, dst_data, depends);
225+
}
226+
else {
227+
auto fn =
228+
copy_and_cast_1d_dispatch_table[dst_type_id][src_type_id];
229+
copy_and_cast_1d_event =
230+
fn(exec_q, src_nelems, shape_arr, src_strides_arr,
231+
dst_strides_arr, src_data, src_offset, dst_data,
232+
dst_offset, depends);
233+
}
234+
return std::make_pair(
235+
keep_args_alive(exec_q, {src, dst}, {copy_and_cast_1d_event}),
236+
copy_and_cast_1d_event);
237+
}
238+
else if (nd == 0) { // case of a scalar
239+
assert(src_nelems == 1);
240+
std::array<py::ssize_t, 1> shape_arr = {1};
241+
std::array<py::ssize_t, 1> src_strides_arr = {1};
242+
std::array<py::ssize_t, 1> dst_strides_arr = {1};
243+
244+
auto fn = copy_and_cast_1d_dispatch_table[dst_type_id][src_type_id];
245+
246+
sycl::event copy_and_cast_0d_event = fn(
247+
exec_q, src_nelems, shape_arr, src_strides_arr, dst_strides_arr,
248+
src_data, src_offset, dst_data, dst_offset, depends);
249+
250+
return std::make_pair(
251+
keep_args_alive(exec_q, {src, dst}, {copy_and_cast_0d_event}),
252+
copy_and_cast_0d_event);
253+
}
254+
}
255+
256+
// Generic implementation
257+
auto copy_and_cast_fn =
258+
copy_and_cast_generic_dispatch_table[dst_type_id][src_type_id];
259+
260+
std::vector<sycl::event> host_task_events;
261+
host_task_events.reserve(2);
262+
263+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
264+
auto ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
265+
exec_q, host_task_events, simplified_shape, simplified_src_strides,
266+
simplified_dst_strides);
267+
auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_tuple));
268+
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);
269+
const py::ssize_t *shape_strides = shape_strides_owner.get();
270+
271+
const sycl::event &copy_and_cast_generic_ev = copy_and_cast_fn(
272+
exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
273+
dst_offset, depends, {copy_shape_ev});
274+
275+
// async free of shape_strides temporary
276+
const auto &temporaries_cleanup_ev =
277+
dpctl::tensor::alloc_utils::async_smart_free(
278+
exec_q, {copy_and_cast_generic_ev}, shape_strides_owner);
279+
host_task_events.push_back(temporaries_cleanup_ev);
280+
281+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events),
282+
copy_and_cast_generic_ev);
283+
}
284+
285+
void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
286+
{
287+
using namespace td_ns;
288+
289+
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastContigFactory;
290+
DispatchTableBuilder<copy_and_cast_contig_fn_ptr_t,
291+
CopyAndCastContigFactory, num_types>
292+
dtb_contig;
293+
dtb_contig.populate_dispatch_table(copy_and_cast_contig_dispatch_table);
294+
295+
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastGenericFactory;
296+
DispatchTableBuilder<copy_and_cast_generic_fn_ptr_t,
297+
CopyAndCastGenericFactory, num_types>
298+
dtb_generic;
299+
dtb_generic.populate_dispatch_table(copy_and_cast_generic_dispatch_table);
300+
301+
using dpctl::tensor::kernels::copy_and_cast::CopyAndCast1DFactory;
302+
DispatchTableBuilder<copy_and_cast_1d_fn_ptr_t, CopyAndCast1DFactory,
303+
num_types>
304+
dtb_1d;
305+
dtb_1d.populate_dispatch_table(copy_and_cast_1d_dispatch_table);
306+
}
307+
308+
} // namespace py_internal
309+
} // namespace tensor
310+
} // namespace dpctl
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file declares functions of dpctl.tensor._tensor_impl extensions
33+
//===----------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <sycl/sycl.hpp>
37+
#include <utility>
38+
#include <vector>
39+
40+
#include "dpnp4pybind11.hpp"
41+
#include <pybind11/pybind11.h>
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace py_internal
48+
{
49+
50+
/*! @brief Copies the elements of `src` into `dst`, casting between element
 *  types when they differ.
 *
 *  @param src      Array to read from.
 *  @param dst      Array to write into; must have the same number of
 *                  dimensions and the same shape as `src`.
 *  @param exec_q   Queue the work is submitted to.
 *  @param depends  Events the submitted work has to wait for.
 *
 *  @return Pair of events; for the code paths implemented directly by this
 *          function, the first keeps the arguments alive and the second is
 *          the event of the copy itself.
 *
 *  @throws py::value_error if dimensions or shapes differ, if `exec_q` is
 *          not compatible with the arrays, or if the arrays overlap.
 */
extern std::pair<sycl::event, sycl::event> copy_usm_ndarray_into_usm_ndarray(
51+
const dpctl::tensor::usm_ndarray &src,
52+
const dpctl::tensor::usm_ndarray &dst,
53+
sycl::queue &exec_q,
54+
const std::vector<sycl::event> &depends = {});
55+
56+
/*! @brief Populates the copy-and-cast dispatch tables that
 *  `copy_usm_ndarray_into_usm_ndarray` looks its kernels up in.
 *
 *  NOTE(review): the tables are only filled in here, so this presumably has
 *  to run before the first copy - confirm against the module init code.
 */
extern void init_copy_and_cast_usm_to_usm_dispatch_tables();
57+
58+
} // namespace py_internal
59+
} // namespace tensor
60+
} // namespace dpctl

0 commit comments

Comments
 (0)