@@ -70,90 +70,25 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref,
7070
7171 sycl::queue q = *(reinterpret_cast <sycl::queue *>(q_ref));
7272
73- if (ndim == 1 ) // 1d array with C-contiguous data
74- {
75- _DataType *arr = static_cast <_DataType *>(array1_in);
76- _DataType *result = static_cast <_DataType *>(result1);
73+ _DataType *arr = static_cast <_DataType *>(array1_in);
74+ _DataType *result = static_cast <_DataType *>(result1);
7775
78- auto policy = oneapi::dpl::execution::make_device_policy<
79- dpnp_partition_c_kernel<_DataType>>(q);
76+ auto policy = oneapi::dpl::execution::make_device_policy<
77+ dpnp_partition_c_kernel<_DataType>>(q);
8078
81- // fill the result array with data from input one
82- q.memcpy (result, arr, size * sizeof (_DataType)).wait ();
79+ // fill the result array with data from input one
80+ q.memcpy (result, arr, size * sizeof (_DataType)).wait ();
8381
84- // make a partial sorting such that:
82+ for (size_t i = 0 ; i < size_; i++) {
83+ _DataType *bufptr = result + i * shape_[0 ];
84+
85+ // for every slice it makes a partial sorting such that:
8586 // 1. result[0 <= i < kth] <= result[kth]
8687 // 2. result[kth <= i < size] >= result[kth]
8788 // event-blocking call, no need for wait()
88- std::nth_element (policy, result, result + kth, result + size,
89+ std::nth_element (policy, bufptr, bufptr + kth, bufptr + size,
8990 dpnp_less_comp ());
90- return event_ref;
91- }
92-
93- DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref, array1_in, size, true );
94- DPNPC_ptr_adapter<_DataType> input2_ptr (q_ref, array2_in, size, true );
95- DPNPC_ptr_adapter<_DataType> result1_ptr (q_ref, result1, size, true , true );
96- _DataType *arr = input1_ptr.get_ptr ();
97- _DataType *arr2 = input2_ptr.get_ptr ();
98- _DataType *result = result1_ptr.get_ptr ();
99-
100- auto arr_to_result_event = q.memcpy (result, arr, size * sizeof (_DataType));
101- arr_to_result_event.wait ();
102-
103- _DataType *matrix = new _DataType[shape_[ndim - 1 ]];
104-
105- for (size_t i = 0 ; i < size_; ++i) {
106- size_t ind_begin = i * shape_[ndim - 1 ];
107- size_t ind_end = (i + 1 ) * shape_[ndim - 1 ] - 1 ;
108-
109- for (size_t j = ind_begin; j < ind_end + 1 ; ++j) {
110- size_t ind = j - ind_begin;
111- matrix[ind] = arr2[j];
112- }
113- std::partial_sort (matrix, matrix + shape_[ndim - 1 ],
114- matrix + shape_[ndim - 1 ], dpnp_less_comp ());
115- for (size_t j = ind_begin; j < ind_end + 1 ; ++j) {
116- size_t ind = j - ind_begin;
117- arr2[j] = matrix[ind];
118- }
11991 }
120-
121- shape_elem_type *shape = reinterpret_cast <shape_elem_type *>(
122- sycl::malloc_shared (ndim * sizeof (shape_elem_type), q));
123- auto memcpy_event = q.memcpy (shape, shape_, ndim * sizeof (shape_elem_type));
124-
125- memcpy_event.wait ();
126-
127- sycl::range<2 > gws (size_, kth + 1 );
128- auto kernel_parallel_for_func = [=](sycl::id<2 > global_id) {
129- size_t j = global_id[0 ];
130- size_t k = global_id[1 ];
131-
132- _DataType val = arr2[j * shape[ndim - 1 ] + k];
133-
134- for (size_t i = 0 ; i < static_cast <size_t >(shape[ndim - 1 ]); ++i) {
135- if (result[j * shape[ndim - 1 ] + i] == val) {
136- _DataType change_val1 = result[j * shape[ndim - 1 ] + i];
137- _DataType change_val2 = result[j * shape[ndim - 1 ] + k];
138- result[j * shape[ndim - 1 ] + k] = change_val1;
139- result[j * shape[ndim - 1 ] + i] = change_val2;
140- }
141- }
142- };
143-
144- auto kernel_func = [&](sycl::handler &cgh) {
145- cgh.depends_on ({memcpy_event});
146- cgh.parallel_for <class dpnp_partition_c_kernel <_DataType>>(
147- gws, kernel_parallel_for_func);
148- };
149-
150- auto event = q.submit (kernel_func);
151-
152- event.wait ();
153-
154- delete[] matrix;
155- sycl::free (shape, q);
156-
15792 return event_ref;
15893}
15994
0 commit comments