Skip to content

Commit 24eee77

Browse files
committed
Update dpctl sycl kernel example
1 parent ad83e95 commit 24eee77

File tree

3 files changed

+43
-36
lines changed

3 files changed

+43
-36
lines changed

examples/pybind11/use_dpctl_sycl_kernel/example.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
# coding: utf-8
1818

19+
import numpy as np
1920
import use_kernel as eg
2021

2122
import dpctl
23+
import dpctl.memory as dpmem
2224
import dpctl.program as dppr
23-
import dpctl.tensor as dpt
2425

2526
# create execution queue, targeting default selected device
2627
q = dpctl.SyclQueue()
@@ -38,10 +39,18 @@
3839
assert krn.num_args == 2
3940

4041
# Construct the argument, and allocate memory for the result
41-
x = dpt.arange(0, stop=13, step=1, dtype="i4", sycl_queue=q)
42-
y = dpt.empty_like(x)
42+
x = np.arange(0, stop=13, step=1, dtype="i4")
43+
y = np.empty_like(x)
44+
x_dev = dpmem.MemoryUSMDevice(x.nbytes, queue=q)
45+
y_dev = dpmem.MemoryUSMDevice(y.nbytes, queue=q)
4346

44-
eg.submit_custom_kernel(q, krn, src=x, dst=y)
47+
# Copy input data to the device
48+
q.memcpy(dest=x_dev, src=x, count=x.nbytes)
49+
50+
eg.submit_custom_kernel(q, krn, src=x_dev, dst=y_dev)
51+
52+
# Copy result data back to host
53+
q.memcpy(dest=y, src=y_dev, count=y.nbytes)
4554

4655
# output the result
47-
print(dpt.asnumpy(y))
56+
print(y)

examples/pybind11/use_dpctl_sycl_kernel/tests/test_user_kernel.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import use_kernel as uk
2424

2525
import dpctl
26-
import dpctl.program as dpm
27-
import dpctl.tensor as dpt
26+
import dpctl.memory as dpmem
27+
import dpctl.program as dppr
2828

2929

3030
def _get_spv_path():
@@ -45,7 +45,7 @@ def test_kernel_can_be_found():
4545
q = dpctl.SyclQueue()
4646
except dpctl.SyclQueueCreationError:
4747
pytest.skip("Could not create default queue")
48-
pr = dpm.create_program_from_spirv(q, il, "")
48+
pr = dppr.create_program_from_spirv(q, il, "")
4949
assert pr.has_sycl_kernel("double_it")
5050

5151

@@ -57,14 +57,20 @@ def test_kernel_submit_through_extension():
5757
q = dpctl.SyclQueue()
5858
except dpctl.SyclQueueCreationError:
5959
pytest.skip("Could not create default queue")
60-
pr = dpm.create_program_from_spirv(q, il, "")
60+
pr = dppr.create_program_from_spirv(q, il, "")
6161
krn = pr.get_sycl_kernel("double_it")
6262
assert krn.num_args == 2
6363

64-
x = dpt.arange(0, stop=13, step=1, dtype="i4", sycl_queue=q)
65-
y = dpt.zeros_like(x)
64+
x = np.arange(0, stop=13, step=1, dtype="i4")
65+
y = np.empty_like(x)
6666

67-
q.wait()
68-
uk.submit_custom_kernel(q, krn, x, y, [])
67+
x_usm = dpmem.MemoryUSMDevice(x.nbytes, queue=q)
68+
y_usm = dpmem.MemoryUSMDevice(y.nbytes, queue=q)
6969

70-
assert np.array_equal(dpt.asnumpy(y), np.arange(0, 26, step=2, dtype="i4"))
70+
ev = q.memcpy_async(dest=x_usm, src=x, count=x_usm.nbytes)
71+
72+
uk.submit_custom_kernel(q, krn, x_usm, y_usm, [ev])
73+
74+
q.memcpy(dest=y, src=y_usm, count=y.nbytes)
75+
76+
assert np.array_equal(y, np.arange(0, 26, step=2, dtype="i4"))

examples/pybind11/use_dpctl_sycl_kernel/use_kernel/_example.cpp

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,45 +37,37 @@ namespace py = pybind11;
3737

3838
void submit_custom_kernel(sycl::queue &q,
3939
sycl::kernel &krn,
40-
dpctl::tensor::usm_ndarray x,
41-
dpctl::tensor::usm_ndarray y,
40+
dpctl::memory::usm_memory x,
41+
dpctl::memory::usm_memory y,
4242
const std::vector<sycl::event> &depends = {})
4343
{
44-
if (x.get_ndim() != 1 || !x.is_c_contiguous() || y.get_ndim() != 1 ||
45-
!y.is_c_contiguous())
46-
{
47-
throw py::value_error(
48-
"src and dst arguments must be 1D and contiguous.");
49-
}
44+
const std::size_t nbytes_x = x.get_nbytes();
45+
const std::size_t nbytes_y = y.get_nbytes();
5046

51-
auto const &api = dpctl::detail::dpctl_capi::get();
52-
if (x.get_typenum() != api.UAR_INT32_ || y.get_typenum() != api.UAR_INT32_)
53-
{
54-
throw py::value_error(
55-
"src and dst arguments must have int32 element data types.");
47+
if (nbytes_x != nbytes_y) {
48+
throw py::value_error("src and dst arguments must have equal nbytes.");
49+
}
50+
if (nbytes_x % sizeof(std::int32_t) != 0) {
51+
throw py::value_error("src and dst must be interpretable as int32 "
52+
"(nbytes must be a multiple of 4).");
5653
}
5754

58-
size_t n_x = x.get_size();
59-
size_t n_y = y.get_size();
55+
auto *x_data = reinterpret_cast<std::int32_t *>(x.get_pointer());
56+
auto *y_data = reinterpret_cast<std::int32_t *>(y.get_pointer());
6057

61-
if (n_x != n_y) {
62-
throw py::value_error("src and dst arguments must have equal size.");
63-
}
58+
const std::size_t n_elems = nbytes_x / sizeof(std::int32_t);
6459

6560
if (!dpctl::utils::queues_are_compatible(q, {x.get_queue(), y.get_queue()}))
6661
{
6762
throw std::runtime_error(
6863
"Execution queue is not compatible with allocation queues");
6964
}
7065

71-
void *x_data = x.get_data<void>();
72-
void *y_data = y.get_data<void>();
73-
7466
sycl::event e = q.submit([&](sycl::handler &cgh) {
7567
cgh.depends_on(depends);
7668
cgh.set_arg(0, x_data);
7769
cgh.set_arg(1, y_data);
78-
cgh.parallel_for(sycl::range<1>(n_x), krn);
70+
cgh.parallel_for(sycl::range<1>(n_elems), krn);
7971
});
8072

8173
e.wait();

0 commit comments

Comments
 (0)