Skip to content

Commit bd67b44

Browse files
committed
Update backend implementation of dpnp.partition to be based on std::nth_element from oneDPL
1 parent 2da3758 commit bd67b44

1 file changed

Lines changed: 11 additions & 76 deletions

File tree

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 11 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -70,90 +70,25 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref,
7070

7171
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
7272

73-
if (ndim == 1) // 1d array with C-contiguous data
74-
{
75-
_DataType *arr = static_cast<_DataType *>(array1_in);
76-
_DataType *result = static_cast<_DataType *>(result1);
73+
_DataType *arr = static_cast<_DataType *>(array1_in);
74+
_DataType *result = static_cast<_DataType *>(result1);
7775

78-
auto policy = oneapi::dpl::execution::make_device_policy<
79-
dpnp_partition_c_kernel<_DataType>>(q);
76+
auto policy = oneapi::dpl::execution::make_device_policy<
77+
dpnp_partition_c_kernel<_DataType>>(q);
8078

81-
// fill the result array with data from input one
82-
q.memcpy(result, arr, size * sizeof(_DataType)).wait();
79+
// fill the result array with data from input one
80+
q.memcpy(result, arr, size * sizeof(_DataType)).wait();
8381

84-
// make a partial sorting such that:
82+
for (size_t i = 0; i < size_; i++) {
83+
_DataType *bufptr = result + i * shape_[0];
84+
85+
// for every slice it makes a partial sorting such that:
8586
// 1. result[0 <= i < kth] <= result[kth]
8687
// 2. result[kth <= i < size] >= result[kth]
8788
// event-blocking call, no need for wait()
88-
std::nth_element(policy, result, result + kth, result + size,
89+
std::nth_element(policy, bufptr, bufptr + kth, bufptr + size,
8990
dpnp_less_comp());
90-
return event_ref;
91-
}
92-
93-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size, true);
94-
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, array2_in, size, true);
95-
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, true, true);
96-
_DataType *arr = input1_ptr.get_ptr();
97-
_DataType *arr2 = input2_ptr.get_ptr();
98-
_DataType *result = result1_ptr.get_ptr();
99-
100-
auto arr_to_result_event = q.memcpy(result, arr, size * sizeof(_DataType));
101-
arr_to_result_event.wait();
102-
103-
_DataType *matrix = new _DataType[shape_[ndim - 1]];
104-
105-
for (size_t i = 0; i < size_; ++i) {
106-
size_t ind_begin = i * shape_[ndim - 1];
107-
size_t ind_end = (i + 1) * shape_[ndim - 1] - 1;
108-
109-
for (size_t j = ind_begin; j < ind_end + 1; ++j) {
110-
size_t ind = j - ind_begin;
111-
matrix[ind] = arr2[j];
112-
}
113-
std::partial_sort(matrix, matrix + shape_[ndim - 1],
114-
matrix + shape_[ndim - 1], dpnp_less_comp());
115-
for (size_t j = ind_begin; j < ind_end + 1; ++j) {
116-
size_t ind = j - ind_begin;
117-
arr2[j] = matrix[ind];
118-
}
11991
}
120-
121-
shape_elem_type *shape = reinterpret_cast<shape_elem_type *>(
122-
sycl::malloc_shared(ndim * sizeof(shape_elem_type), q));
123-
auto memcpy_event = q.memcpy(shape, shape_, ndim * sizeof(shape_elem_type));
124-
125-
memcpy_event.wait();
126-
127-
sycl::range<2> gws(size_, kth + 1);
128-
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
129-
size_t j = global_id[0];
130-
size_t k = global_id[1];
131-
132-
_DataType val = arr2[j * shape[ndim - 1] + k];
133-
134-
for (size_t i = 0; i < static_cast<size_t>(shape[ndim - 1]); ++i) {
135-
if (result[j * shape[ndim - 1] + i] == val) {
136-
_DataType change_val1 = result[j * shape[ndim - 1] + i];
137-
_DataType change_val2 = result[j * shape[ndim - 1] + k];
138-
result[j * shape[ndim - 1] + k] = change_val1;
139-
result[j * shape[ndim - 1] + i] = change_val2;
140-
}
141-
}
142-
};
143-
144-
auto kernel_func = [&](sycl::handler &cgh) {
145-
cgh.depends_on({memcpy_event});
146-
cgh.parallel_for<class dpnp_partition_c_kernel<_DataType>>(
147-
gws, kernel_parallel_for_func);
148-
};
149-
150-
auto event = q.submit(kernel_func);
151-
152-
event.wait();
153-
154-
delete[] matrix;
155-
sycl::free(shape, q);
156-
15792
return event_ref;
15893
}
15994

0 commit comments

Comments
 (0)