Skip to content

Commit 4a0516d

Browse files
committed
task: syclqueue.copy() method
1 parent 442d61f commit 4a0516d

File tree

7 files changed

+366
-0
lines changed

7 files changed

+366
-0
lines changed

dpctl/_backend.pxd

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,18 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
511511
size_t Count,
512512
const DPCTLSyclEventRef *depEvents,
513513
size_t depEventsCount)
514+
cdef DPCTLSyclEventRef DPCTLQueue_CopyData(
515+
const DPCTLSyclQueueRef Q,
516+
void *Dest,
517+
const void *Src,
518+
size_t Count)
519+
cdef DPCTLSyclEventRef DPCTLQueue_CopyDataWithEvents(
520+
const DPCTLSyclQueueRef Q,
521+
void *Dest,
522+
const void *Src,
523+
size_t Count,
524+
const DPCTLSyclEventRef *depEvents,
525+
size_t depEventsCount)
514526
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
515527
const DPCTLSyclQueueRef Q,
516528
void *Dest,

dpctl/_sycl_queue.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef public api class SyclQueue (_SyclQueue) [
103103
cdef DPCTLSyclQueueRef get_queue_ref(self)
104104
cpdef memcpy(self, dest, src, size_t count)
105105
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
106+
cpdef copy(self, dest, src, size_t count)
107+
cpdef SyclEvent copy_async(self, dest, src, size_t count, list dEvents=*)
106108
cpdef prefetch(self, ptr, size_t count=*)
107109
cpdef mem_advise(self, ptr, size_t count, int mem)
108110
cpdef SyclEvent submit_barrier(self, dependent_events=*)

dpctl/_sycl_queue.pyx

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ from ._backend cimport ( # noqa: E211
3535
DPCTLFilterSelector_Create,
3636
DPCTLQueue_AreEq,
3737
DPCTLQueue_Copy,
38+
DPCTLQueue_CopyData,
39+
DPCTLQueue_CopyDataWithEvents,
3840
DPCTLQueue_Create,
3941
DPCTLQueue_Delete,
4042
DPCTLQueue_GetBackend,
@@ -535,6 +537,82 @@ cdef DPCTLSyclEventRef _memcpy_impl(
535537
return ERef
536538

537539

540+
cdef DPCTLSyclEventRef _copy_impl(
541+
SyclQueue q,
542+
object dst,
543+
object src,
544+
size_t byte_count,
545+
DPCTLSyclEventRef *dep_events,
546+
size_t dep_events_count
547+
) except *:
548+
cdef void *c_dst_ptr = NULL
549+
cdef void *c_src_ptr = NULL
550+
cdef DPCTLSyclEventRef ERef = NULL
551+
cdef Py_buffer src_buf_view
552+
cdef Py_buffer dst_buf_view
553+
cdef bint src_is_buf = False
554+
cdef bint dst_is_buf = False
555+
cdef int ret_code = 0
556+
557+
if isinstance(src, _Memory):
558+
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
559+
elif _is_buffer(src):
560+
ret_code = PyObject_GetBuffer(
561+
src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
562+
)
563+
if ret_code != 0: # pragma: no cover
564+
raise RuntimeError("Could not access buffer")
565+
c_src_ptr = src_buf_view.buf
566+
src_is_buf = True
567+
else:
568+
raise TypeError(
569+
"Parameter `src` should have either type "
570+
"`dpctl.memory._Memory` or a type that "
571+
"supports Python buffer protocol"
572+
)
573+
574+
if isinstance(dst, _Memory):
575+
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
576+
elif _is_buffer(dst):
577+
ret_code = PyObject_GetBuffer(
578+
dst, &dst_buf_view,
579+
PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE
580+
)
581+
if ret_code != 0: # pragma: no cover
582+
if src_is_buf:
583+
PyBuffer_Release(&src_buf_view)
584+
raise RuntimeError("Could not access buffer")
585+
c_dst_ptr = dst_buf_view.buf
586+
dst_is_buf = True
587+
else:
588+
raise TypeError(
589+
"Parameter `dst` should have either type "
590+
"`dpctl.memory._Memory` or a type that "
591+
"supports Python buffer protocol"
592+
)
593+
594+
if dep_events_count == 0 or dep_events is NULL:
595+
ERef = DPCTLQueue_CopyData(
596+
q._queue_ref, c_dst_ptr, c_src_ptr, byte_count
597+
)
598+
else:
599+
ERef = DPCTLQueue_CopyDataWithEvents(
600+
q._queue_ref,
601+
c_dst_ptr,
602+
c_src_ptr,
603+
byte_count,
604+
dep_events,
605+
dep_events_count
606+
)
607+
608+
if src_is_buf:
609+
PyBuffer_Release(&src_buf_view)
610+
if dst_is_buf:
611+
PyBuffer_Release(&dst_buf_view)
612+
613+
return ERef
614+
615+
538616
cdef class _SyclQueue:
539617
""" Barebone data owner class used by SyclQueue.
540618
"""
@@ -1426,6 +1504,82 @@ cdef class SyclQueue(_SyclQueue):
14261504

14271505
return SyclEvent._create(ERef)
14281506

1507+
cpdef copy(self, dest, src, size_t count):
1508+
"""Copy ``count`` bytes from ``src`` to ``dest`` and wait.
1509+
1510+
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1511+
byte-sized elements.
1512+
1513+
This is a synchronizing variant corresponding to
1514+
:meth:`dpctl.SyclQueue.copy_async`.
1515+
"""
1516+
cdef DPCTLSyclEventRef ERef = NULL
1517+
1518+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1519+
if (ERef is NULL):
1520+
raise RuntimeError(
1521+
"SyclQueue.copy operation encountered an error"
1522+
)
1523+
with nogil:
1524+
DPCTLEvent_Wait(ERef)
1525+
DPCTLEvent_Delete(ERef)
1526+
1527+
cpdef SyclEvent copy_async(
1528+
self, dest, src, size_t count, list dEvents=None
1529+
):
1530+
"""Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
1531+
1532+
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1533+
byte-sized elements.
1534+
1535+
Args:
1536+
dest:
1537+
Destination USM object or Python object supporting
1538+
writable buffer protocol.
1539+
src:
1540+
Source USM object or Python object supporting buffer
1541+
protocol.
1542+
count (int):
1543+
Number of bytes to copy.
1544+
dEvents (List[dpctl.SyclEvent], optional):
1545+
Events that this copy depends on.
1546+
1547+
Returns:
1548+
dpctl.SyclEvent:
1549+
Event associated with the copy operation.
1550+
"""
1551+
cdef DPCTLSyclEventRef ERef = NULL
1552+
cdef DPCTLSyclEventRef *depEvents = NULL
1553+
cdef size_t nDE = 0
1554+
1555+
if dEvents is None:
1556+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1557+
else:
1558+
nDE = len(dEvents)
1559+
depEvents = (
1560+
<DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
1561+
)
1562+
if depEvents is NULL:
1563+
raise MemoryError()
1564+
else:
1565+
for idx, de in enumerate(dEvents):
1566+
if isinstance(de, SyclEvent):
1567+
depEvents[idx] = (<SyclEvent>de).get_event_ref()
1568+
else:
1569+
free(depEvents)
1570+
raise TypeError(
1571+
"A sequence of dpctl.SyclEvent is expected"
1572+
)
1573+
ERef = _copy_impl(self, dest, src, count, depEvents, nDE)
1574+
free(depEvents)
1575+
1576+
if (ERef is NULL):
1577+
raise RuntimeError(
1578+
"SyclQueue.copy operation encountered an error"
1579+
)
1580+
1581+
return SyclEvent._create(ERef)
1582+
14291583
cpdef prefetch(self, mem, size_t count=0):
14301584
cdef void *ptr
14311585
cdef DPCTLSyclEventRef ERef = NULL
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the SyclQueue.copy."""
18+
19+
import pytest
20+
21+
import dpctl
22+
import dpctl.memory
23+
24+
25+
def _create_memory(q):
26+
nbytes = 1024
27+
mobj = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
28+
return mobj
29+
30+
31+
def test_copy_copy_host_to_host():
32+
try:
33+
q = dpctl.SyclQueue()
34+
except dpctl.SyclQueueCreationError:
35+
pytest.skip("Default constructor for SyclQueue failed")
36+
37+
src_buf = b"abcdefghijklmnopqrstuvwxyz"
38+
dst_buf = bytearray(len(src_buf))
39+
40+
q.copy(dst_buf, src_buf, len(src_buf))
41+
42+
assert dst_buf == src_buf
43+
44+
45+
def test_copy_async():
46+
try:
47+
q = dpctl.SyclQueue()
48+
except dpctl.SyclQueueCreationError:
49+
pytest.skip("Default constructor for SyclQueue failed")
50+
51+
src_buf = b"abcdefghijklmnopqrstuvwxyz"
52+
n = len(src_buf)
53+
dst_buf = bytearray(n)
54+
dst_buf2 = bytearray(n)
55+
56+
e = q.copy_async(dst_buf, src_buf, n)
57+
e2 = q.copy_async(dst_buf2, src_buf, n, [e])
58+
59+
e.wait()
60+
e2.wait()
61+
assert dst_buf == src_buf
62+
assert dst_buf2 == src_buf
63+
64+
65+
def test_copy_type_error():
66+
try:
67+
q = dpctl.SyclQueue()
68+
except dpctl.SyclQueueCreationError:
69+
pytest.skip("Default constructor for SyclQueue failed")
70+
mobj = _create_memory(q)
71+
72+
with pytest.raises(TypeError) as cm:
73+
q.copy(None, mobj, 3)
74+
assert "_Memory" in str(cm.value)
75+
76+
with pytest.raises(TypeError) as cm:
77+
q.copy(mobj, None, 3)
78+
assert "_Memory" in str(cm.value)

libsyclinterface/include/syclinterface/dpctl_sycl_queue_interface.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,47 @@ DPCTLQueue_MemcpyWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
327327
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
328328
size_t DepEventsCount);
329329

330+
/*!
331+
* @brief C-API wrapper for ``sycl::queue::copy``.
332+
*
333+
* @param QRef An opaque pointer to the ``sycl::queue``.
334+
* @param Dest A destination pointer.
335+
* @param Src A source pointer.
336+
* @param Count A number of bytes to copy.
337+
* @return An opaque pointer to the ``sycl::event`` returned by the
338+
* ``sycl::queue::copy`` function.
339+
* @ingroup QueueInterface
340+
*/
341+
DPCTL_API
342+
__dpctl_give DPCTLSyclEventRef
343+
DPCTLQueue_CopyData(__dpctl_keep const DPCTLSyclQueueRef QRef,
344+
void *Dest,
345+
const void *Src,
346+
size_t Count);
347+
348+
/*!
349+
* @brief C-API wrapper for ``sycl::queue::copy``.
350+
*
351+
* @param QRef An opaque pointer to the ``sycl::queue``.
352+
* @param Dest A destination pointer.
353+
* @param Src A source pointer.
354+
* @param Count A number of bytes to copy.
355+
* @param DepEvents A pointer to array of DPCTLSyclEventRef opaque
356+
* pointers to dependent events.
357+
* @param DepEventsCount A number of dependent events.
358+
* @return An opaque pointer to the ``sycl::event`` returned by the
359+
* ``sycl::queue::copy`` function.
360+
* @ingroup QueueInterface
361+
*/
362+
DPCTL_API
363+
__dpctl_give DPCTLSyclEventRef
364+
DPCTLQueue_CopyDataWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
365+
void *Dest,
366+
const void *Src,
367+
size_t Count,
368+
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
369+
size_t DepEventsCount);
370+
330371
/*!
331372
* @brief C-API wrapper for ``sycl::queue::prefetch``.
332373
*

0 commit comments

Comments
 (0)