Skip to content

Commit 9c4aed2

Browse files
authored
Improve dpnp.partition implementation (#2766)
The PR propose to improve implementation and to use `dpnp.sort` call when - input array has number of dimensions > 1 - input array has previously not supported integer dtype - `axis` keyword is passed (previously not supported) - sequence of `kth` is passed (previously not supported) In case of `ndim > 1` previously the implementation from legacy backend was used, which is significantly slow (see performance comparation below). It used a copy of input data into the shared USM memory and included computations on the host. This PR proposes to reuse `dpnp.sort` for all the above cases. While in case when the legacy implementation is stable and fast (for 1D input array), it will remain, because it relays on `std::nth_element` from OneDPL. The benchmark results were collected on PVC with help of the below code: ```python import dpnp, numpy as np from dpnp.tests.helper import generate_random_numpy_array a = generate_random_numpy_array(10**7, dtype=np.float64, seed_value=117) ia = dpnp.array(a) %timeit x = dpnp.partition(ia, 513); x.sycl_queue.wait() ``` Below tables contains data in case of 1D input array (shape=(10**7,)), where the implementation path was kept the same, plus adding support of missing integer dtypes using fallback on the sort function: | Implementation | int32 | uint32 | int64 | uint64 | float32 | float64 | complex64 | complex128 | |--------|--------|--------|--------|--------|--------|--------|--------|--------| | old (legacy backend) | 7.46 ms | not supported | 9.46 ms | not supported | 7.39 ms | 8.92 ms | 10.9 ms | 21.2 ms | | new (backend + sort) | 7.34 ms | 10.8 ms | 9.48 ms | 12.5 ms | 7.37 ms | 8.89 ms | 11 ms | 21.2 ms | The following code was used for 2D input array with shape=(10**4, 10**4): ```python import dpnp, numpy as np from dpnp.tests.helper import generate_random_numpy_array a = generate_random_numpy_array((10**4, 10**4), dtype=np.float64, seed_value=117) ia = dpnp.array(a) %timeit x = dpnp.partition(ia, 1513); x.sycl_queue.wait() ``` In that case the new implementation is fully based on the sort call: | Implementation | int32 | int64 | float32 | float64 | complex64 | complex128 | |--------|--------|--------|--------|--------|--------|--------| | old (legacy backend) | 6.4 s | 6.89 s | 7.36 s | 7.66 s | 8.61 s | 10 s | | new (sort) | 57.4 ms | 64.7 ms | 62.2 ms | 68 ms | 77 ms | 151 ms |
1 parent 9d6d5a5 commit 9c4aed2

File tree

7 files changed

+393
-300
lines changed

7 files changed

+393
-300
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
4949
* Aligned `dpnp.trim_zeros` with NumPy 2.4 to support a tuple of integers passed with `axis` keyword [#2746](https://github.com/IntelPython/dpnp/pull/2746)
5050
* Aligned `strides` property of `dpnp.ndarray` with NumPy and CuPy implementations [#2747](https://github.com/IntelPython/dpnp/pull/2747)
5151
* Extended `dpnp.nan_to_num` to support broadcasting of `nan`, `posinf`, and `neginf` keywords [#2754](https://github.com/IntelPython/dpnp/pull/2754)
52+
* Changed `dpnp.partition` implementation to reuse `dpnp.sort` where it brings the performance benefit [#2766](https://github.com/IntelPython/dpnp/pull/2766)
5253

5354
### Deprecated
5455

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 13 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -70,90 +70,27 @@ DPCTLSyclEventRef dpnp_partition_c(DPCTLSyclQueueRef q_ref,
7070

7171
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
7272

73-
if (ndim == 1) // 1d array with C-contiguous data
74-
{
75-
_DataType *arr = static_cast<_DataType *>(array1_in);
76-
_DataType *result = static_cast<_DataType *>(result1);
73+
_DataType *arr = static_cast<_DataType *>(array1_in);
74+
_DataType *result = static_cast<_DataType *>(result1);
7775

78-
auto policy = oneapi::dpl::execution::make_device_policy<
79-
dpnp_partition_c_kernel<_DataType>>(q);
76+
auto policy = oneapi::dpl::execution::make_device_policy<
77+
dpnp_partition_c_kernel<_DataType>>(q);
8078

81-
// fill the result array with data from input one
82-
q.memcpy(result, arr, size * sizeof(_DataType)).wait();
79+
// fill the result array with data from input one
80+
q.memcpy(result, arr, size * sizeof(_DataType)).wait();
8381

84-
// make a partial sorting such that:
82+
// note, a loop for a multidemension input array (size_ > 1) is an
83+
// experimental and it isn't tested properly as for now
84+
for (size_t i = 0; i < size_; i++) {
85+
_DataType *bufptr = result + i * shape_[0];
86+
87+
// for every slice it makes a partial sorting such that:
8588
// 1. result[0 <= i < kth] <= result[kth]
8689
// 2. result[kth <= i < size] >= result[kth]
8790
// event-blocking call, no need for wait()
88-
std::nth_element(policy, result, result + kth, result + size,
91+
std::nth_element(policy, bufptr, bufptr + kth, bufptr + size,
8992
dpnp_less_comp());
90-
return event_ref;
91-
}
92-
93-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size, true);
94-
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, array2_in, size, true);
95-
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, true, true);
96-
_DataType *arr = input1_ptr.get_ptr();
97-
_DataType *arr2 = input2_ptr.get_ptr();
98-
_DataType *result = result1_ptr.get_ptr();
99-
100-
auto arr_to_result_event = q.memcpy(result, arr, size * sizeof(_DataType));
101-
arr_to_result_event.wait();
102-
103-
_DataType *matrix = new _DataType[shape_[ndim - 1]];
104-
105-
for (size_t i = 0; i < size_; ++i) {
106-
size_t ind_begin = i * shape_[ndim - 1];
107-
size_t ind_end = (i + 1) * shape_[ndim - 1] - 1;
108-
109-
for (size_t j = ind_begin; j < ind_end + 1; ++j) {
110-
size_t ind = j - ind_begin;
111-
matrix[ind] = arr2[j];
112-
}
113-
std::partial_sort(matrix, matrix + shape_[ndim - 1],
114-
matrix + shape_[ndim - 1], dpnp_less_comp());
115-
for (size_t j = ind_begin; j < ind_end + 1; ++j) {
116-
size_t ind = j - ind_begin;
117-
arr2[j] = matrix[ind];
118-
}
11993
}
120-
121-
shape_elem_type *shape = reinterpret_cast<shape_elem_type *>(
122-
sycl::malloc_shared(ndim * sizeof(shape_elem_type), q));
123-
auto memcpy_event = q.memcpy(shape, shape_, ndim * sizeof(shape_elem_type));
124-
125-
memcpy_event.wait();
126-
127-
sycl::range<2> gws(size_, kth + 1);
128-
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
129-
size_t j = global_id[0];
130-
size_t k = global_id[1];
131-
132-
_DataType val = arr2[j * shape[ndim - 1] + k];
133-
134-
for (size_t i = 0; i < static_cast<size_t>(shape[ndim - 1]); ++i) {
135-
if (result[j * shape[ndim - 1] + i] == val) {
136-
_DataType change_val1 = result[j * shape[ndim - 1] + i];
137-
_DataType change_val2 = result[j * shape[ndim - 1] + k];
138-
result[j * shape[ndim - 1] + k] = change_val1;
139-
result[j * shape[ndim - 1] + i] = change_val2;
140-
}
141-
}
142-
};
143-
144-
auto kernel_func = [&](sycl::handler &cgh) {
145-
cgh.depends_on({memcpy_event});
146-
cgh.parallel_for<class dpnp_partition_c_kernel<_DataType>>(
147-
gws, kernel_parallel_for_func);
148-
};
149-
150-
auto event = q.submit(kernel_func);
151-
152-
event.wait();
153-
154-
delete[] matrix;
155-
sycl::free(shape, q);
156-
15794
return event_ref;
15895
}
15996

dpnp/dpnp_array.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,35 +1459,54 @@ def nonzero(self):
14591459

14601460
def partition(self, /, kth, axis=-1, kind="introselect", order=None):
14611461
"""
1462-
Return a partitioned copy of an array.
1462+
Partially sorts the elements in the array in such a way that the value
1463+
of the element in k-th position is in the position it would be in a
1464+
sorted array. In the output array, all elements smaller than the k-th
1465+
element are located to the left of this element and all equal or
1466+
greater are located to its right. The ordering of the elements in the
1467+
two partitions on the either side of the k-th element in the output
1468+
array is undefined.
14631469
1464-
Rearranges the elements in the array in such a way that the value of
1465-
the element in `kth` position is in the position it would be in
1466-
a sorted array.
1470+
Refer to `dpnp.partition` for full documentation.
14671471
1468-
All elements smaller than the `kth` element are moved before this
1469-
element and all equal or greater are moved behind it. The ordering
1470-
of the elements in the two partitions is undefined.
1472+
kth : {int, sequence of ints}
1473+
Element index to partition by. The kth element value will be in its
1474+
final sorted position and all smaller elements will be moved before
1475+
it and all equal or greater elements behind it.
1476+
The order of all elements in the partitions is undefined. If
1477+
provided with a sequence of kth it will partition all elements
1478+
indexed by kth of them into their sorted position at once.
1479+
axis : int, optional
1480+
Axis along which to sort. The default is ``-1``, which means sort
1481+
along the the last axis.
14711482
1472-
Refer to `dpnp.partition` for full documentation.
1483+
Default: ``-1``.
14731484
14741485
See Also
14751486
--------
14761487
:obj:`dpnp.partition` : Return a partitioned copy of an array.
1488+
:obj:`dpnp.argpartition` : Indirect partition.
1489+
:obj:`dpnp.sort` : Full sort.
14771490
14781491
Examples
14791492
--------
14801493
>>> import dpnp as np
14811494
>>> a = np.array([3, 4, 2, 1])
14821495
>>> a.partition(3)
14831496
>>> a
1497+
array([1, 2, 3, 4]) # may vary
1498+
1499+
>>> a.partition((1, 3))
1500+
>>> a
14841501
array([1, 2, 3, 4])
14851502
14861503
"""
14871504

1488-
self._array_obj = dpnp.partition(
1489-
self, kth, axis=axis, kind=kind, order=order
1490-
).get_array()
1505+
if axis is None:
1506+
raise TypeError(
1507+
"'NoneType' object cannot be interpreted as an integer"
1508+
)
1509+
self[...] = dpnp.partition(self, kth, axis=axis, kind=kind, order=order)
14911510

14921511
def prod(
14931512
self,

dpnp/dpnp_iface_sorting.py

Lines changed: 113 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
4040
"""
4141

42+
from collections.abc import Sequence
43+
4244
import dpctl.tensor as dpt
43-
import numpy
4445
from dpctl.tensor._numpy_helper import normalize_axis_index
4546

4647
import dpnp
@@ -51,7 +52,6 @@
5152
)
5253
from .dpnp_array import dpnp_array
5354
from .dpnp_utils import (
54-
call_origin,
5555
map_dtype_to_device,
5656
)
5757

@@ -147,7 +147,7 @@ def argsort(
147147
148148
Limitations
149149
-----------
150-
Parameters `order` is only supported with its default value.
150+
Parameter `order` is only supported with its default value.
151151
Otherwise ``NotImplementedError`` exception will be raised.
152152
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
153153
@@ -201,44 +201,128 @@ def argsort(
201201
)
202202

203203

204-
def partition(x1, kth, axis=-1, kind="introselect", order=None):
204+
def partition(a, kth, axis=-1, kind="introselect", order=None):
205205
"""
206206
Return a partitioned copy of an array.
207207
208208
For full documentation refer to :obj:`numpy.partition`.
209209
210+
Parameters
211+
----------
212+
a : {dpnp.ndarray, usm_ndarray}
213+
Array to be sorted.
214+
kth : {int, sequence of ints}
215+
Element index to partition by. The k-th value of the element will be in
216+
its final sorted position and all smaller elements will be moved before
217+
it and all equal or greater elements behind it. The order of all
218+
elements in the partitions is undefined. If provided with a sequence of
219+
k-th it will partition all elements indexed by k-th of them into their
220+
sorted position at once.
221+
axis : {None, int}, optional
222+
Axis along which to sort. If ``None``, the array is flattened before
223+
sorting. The default is ``-1``, which sorts along the last axis.
224+
225+
Default: ``-1``.
226+
227+
Returns
228+
-------
229+
out : dpnp.ndarray
230+
Array of the same type and shape as `a`.
231+
210232
Limitations
211233
-----------
212-
Input array is supported as :obj:`dpnp.ndarray`.
213-
Input `kth` is supported as :obj:`int`.
214-
Parameters `axis`, `kind` and `order` are supported only with default
215-
values.
234+
Parameters `kind` and `order` are only supported with its default value.
235+
Otherwise ``NotImplementedError`` exception will be raised.
236+
237+
See Also
238+
--------
239+
:obj:`dpnp.ndarray.partition` : Equivalent method.
240+
:obj:`dpnp.argpartition` : Indirect partition.
241+
:obj:`dpnp.sort` : Full sorting.
242+
243+
Examples
244+
--------
245+
>>> import dpnp as np
246+
>>> a = np.array([7, 1, 7, 7, 1, 5, 7, 2, 3, 2, 6, 2, 3, 0])
247+
>>> p = np.partition(a, 4)
248+
>>> p
249+
array([0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 7, 7, 6]) # may vary
250+
251+
``p[4]`` is 2; all elements in ``p[:4]`` are less than or equal to
252+
``p[4]``, and all elements in ``p[5:]`` are greater than or equal to
253+
``p[4]``. The partition is::
254+
255+
[0, 1, 1, 2], [2], [2, 3, 3, 5, 7, 7, 7, 7, 6]
256+
257+
The next example shows the use of multiple values passed to `kth`.
258+
259+
>>> p2 = np.partition(a, (4, 8))
260+
>>> p2
261+
array([0, 1, 1, 2, 2, 2, 3, 3, 5, 6, 7, 7, 7, 7])
262+
263+
``p2[4]`` is 2 and ``p2[8]`` is 5. All elements in ``p2[:4]`` are less
264+
than or equal to ``p2[4]``, all elements in ``p2[5:8]`` are greater than or
265+
equal to ``p2[4]`` and less than or equal to ``p2[8]``, and all elements in
266+
``p2[9:]`` are greater than or equal to ``p2[8]``. The partition is::
267+
268+
[0, 1, 1, 2], [2], [2, 3, 3], [5], [6, 7, 7, 7, 7]
216269
217270
"""
218271

219-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
220-
if x1_desc:
221-
if dpnp.is_cuda_backend(x1_desc.get_array()): # pragma: no cover
222-
raise NotImplementedError(
223-
"Running on CUDA is currently not supported"
224-
)
272+
dpnp.check_supported_arrays_type(a)
225273

226-
if not isinstance(kth, int):
227-
pass
228-
elif x1_desc.ndim == 0:
229-
pass
230-
elif kth >= x1_desc.shape[x1_desc.ndim - 1] or x1_desc.ndim + kth < 0:
231-
pass
232-
elif axis != -1:
233-
pass
234-
elif kind != "introselect":
235-
pass
236-
elif order is not None:
237-
pass
238-
else:
239-
return dpnp_partition(x1_desc, kth, axis, kind, order).get_pyobj()
274+
if kind != "introselect":
275+
raise NotImplementedError(
276+
"`kind` keyword argument is only supported with its default value."
277+
)
278+
if order is not None:
279+
raise NotImplementedError(
280+
"`order` keyword argument is only supported with its default value."
281+
)
240282

241-
return call_origin(numpy.partition, x1, kth, axis, kind, order)
283+
if axis is None:
284+
a = dpnp.ravel(a)
285+
axis = -1
286+
287+
nd = a.ndim
288+
axis = normalize_axis_index(axis, nd)
289+
length = a.shape[axis]
290+
291+
if isinstance(kth, int):
292+
kth = (kth,)
293+
elif not isinstance(kth, Sequence):
294+
raise TypeError(
295+
f"kth must be int or sequence of ints, but got {type(kth)}"
296+
)
297+
elif not all(isinstance(k, int) for k in kth):
298+
raise TypeError("kth is a sequence, but not all elements are integers")
299+
300+
nkth = len(kth)
301+
if nkth == 0 or a.size == 0:
302+
return dpnp.copy(a)
303+
304+
# validate kth
305+
kth = list(kth)
306+
for i in range(nkth):
307+
if kth[i] < 0:
308+
kth[i] += length
309+
310+
if not 0 <= kth[i] < length:
311+
raise ValueError(f"kth(={kth[i]}) out of bounds {length}")
312+
313+
dt = a.dtype
314+
if (
315+
nd > 1
316+
or nkth > 1
317+
or dpnp.issubdtype(dt, dpnp.unsignedinteger)
318+
or dt in (dpnp.int8, dpnp.int16)
319+
or dpnp.is_cuda_backend(a.get_array())
320+
):
321+
# sort is a faster path in case of ndim > 1
322+
return dpnp.sort(a, axis=axis)
323+
324+
desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False)
325+
return dpnp_partition(desc, kth[0], axis, kind, order).get_pyobj()
242326

243327

244328
def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None):

0 commit comments

Comments
 (0)