Skip to content

Commit 2da3758

Browse files
committed
Use dpnp.sort() for nd > 1 which speeds up dpnp.partition and resolves the computation issue
1 parent af6205f commit 2da3758

1 file changed

Lines changed: 106 additions & 29 deletions

File tree

dpnp/dpnp_iface_sorting.py

Lines changed: 106 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
4040
"""
4141

42+
from collections.abc import Sequence
43+
4244
import dpctl.tensor as dpt
43-
import numpy
4445
from dpctl.tensor._numpy_helper import normalize_axis_index
4546

4647
import dpnp
@@ -51,7 +52,6 @@
5152
)
5253
from .dpnp_array import dpnp_array
5354
from .dpnp_utils import (
54-
call_origin,
5555
map_dtype_to_device,
5656
)
5757

@@ -147,7 +147,7 @@ def argsort(
147147
148148
Limitations
149149
-----------
150-
Parameters `order` is only supported with its default value.
150+
Parameter `order` is only supported with its default value.
151151
Otherwise ``NotImplementedError`` exception will be raised.
152152
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.
153153
@@ -201,44 +201,121 @@ def argsort(
201201
)
202202

203203

204-
def partition(x1, kth, axis=-1, kind="introselect", order=None):
204+
def partition(a, kth, axis=-1, kind="introselect", order=None):
205205
"""
206206
Return a partitioned copy of an array.
207207
208208
For full documentation refer to :obj:`numpy.partition`.
209209
210+
Parameters
211+
----------
212+
a : {dpnp.ndarray, usm_ndarray}
213+
Array to be sorted.
214+
kth : {int, sequence of ints}
215+
Element index to partition by. The k-th value of the element will be in
216+
its final sorted position and all smaller elements will be moved before
217+
it and all equal or greater elements behind it. The order of all
218+
elements in the partitions is undefined. If provided with a sequence of
219+
k-th it will partition all elements indexed by k-th of them into their
220+
sorted position at once.
221+
axis : {None, int}, optional
222+
Axis along which to sort. If ``None``, the array is flattened before
223+
sorting. The default is ``-1``, which sorts along the last axis.
224+
225+
Default: ``-1``.
226+
227+
Returns
228+
-------
229+
out : dpnp.ndarray
230+
Array of the same type and shape as `a`.
231+
210232
Limitations
211233
-----------
212-
Input array is supported as :obj:`dpnp.ndarray`.
213-
Input `kth` is supported as :obj:`int`.
214-
Parameters `axis`, `kind` and `order` are supported only with default
215-
values.
234+
Parameters `kind` and `order` are only supported with its default value.
235+
Otherwise ``NotImplementedError`` exception will be raised.
236+
237+
See Also
238+
--------
239+
:obj:`dpnp.ndarray.partition` : Equivalent method.
240+
:obj:`dpnp.argpartition` : Indirect partition.
241+
:obj:`dpnp.sort` : Full sorting.
242+
243+
Examples
244+
--------
245+
>>> import dpnp as np
246+
>>> a = np.array([7, 1, 7, 7, 1, 5, 7, 2, 3, 2, 6, 2, 3, 0])
247+
>>> p = np.partition(a, 4)
248+
>>> p
249+
array([0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 7, 7, 6]) # may vary
250+
251+
``p[4]`` is 2; all elements in ``p[:4]`` are less than or equal to
252+
``p[4]``, and all elements in ``p[5:]`` are greater than or equal to
253+
``p[4]``. The partition is::
254+
255+
[0, 1, 1, 2], [2], [2, 3, 3, 5, 7, 7, 7, 7, 6]
256+
257+
The next example shows the use of multiple values passed to `kth`.
258+
259+
>>> p2 = np.partition(a, (4, 8))
260+
>>> p2
261+
array([0, 1, 1, 2, 2, 2, 3, 3, 5, 6, 7, 7, 7, 7])
262+
263+
``p2[4]`` is 2 and ``p2[8]`` is 5. All elements in ``p2[:4]`` are less
264+
than or equal to ``p2[4]``, all elements in ``p2[5:8]`` are greater than or
265+
equal to ``p2[4]`` and less than or equal to ``p2[8]``, and all elements in
266+
``p2[9:]`` are greater than or equal to ``p2[8]``. The partition is::
267+
268+
[0, 1, 1, 2], [2], [2, 3, 3], [5], [6, 7, 7, 7, 7]
216269
217270
"""
218271

219-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
220-
if x1_desc:
221-
if dpnp.is_cuda_backend(x1_desc.get_array()): # pragma: no cover
222-
raise NotImplementedError(
223-
"Running on CUDA is currently not supported"
224-
)
272+
dpnp.check_supported_arrays_type(a)
225273

226-
if not isinstance(kth, int):
227-
pass
228-
elif x1_desc.ndim == 0:
229-
pass
230-
elif kth >= x1_desc.shape[x1_desc.ndim - 1] or x1_desc.ndim + kth < 0:
231-
pass
232-
elif axis != -1:
233-
pass
234-
elif kind != "introselect":
235-
pass
236-
elif order is not None:
237-
pass
238-
else:
239-
return dpnp_partition(x1_desc, kth, axis, kind, order).get_pyobj()
274+
if kind != "introselect":
275+
raise NotImplementedError(
276+
"`kind` keyword argument is only supported with its default value."
277+
)
278+
if order is not None:
279+
raise NotImplementedError(
280+
"`order` keyword argument is only supported with its default value."
281+
)
282+
283+
if axis is None:
284+
a = dpnp.ravel(a)
285+
axis = -1
286+
287+
nd = a.ndim
288+
axis = normalize_axis_index(axis, nd)
289+
length = a.shape[axis]
290+
291+
if isinstance(kth, int):
292+
kth = (kth,)
293+
elif not isinstance(kth, Sequence):
294+
raise TypeError(
295+
f"kth must be int or sequence of ints, but got {type(kth)}"
296+
)
297+
elif not all(isinstance(k, int) for k in kth):
298+
raise TypeError("kth is a sequence, but not all elements are integers")
299+
300+
nkth = len(kth)
301+
if nkth == 0 or a.size == 0:
302+
return dpnp.copy(a)
303+
304+
# validate kth
305+
kth = list(kth)
306+
for i in range(nkth):
307+
if kth[i] < 0:
308+
kth[i] += length
309+
310+
if not 0 <= kth[i] < length:
311+
raise ValueError(f"kth(={kth[i]}) out of bounds {length}")
312+
313+
if a.ndim > 1 or nkth > 1 or dpnp.is_cuda_backend(a.get_array()):
314+
# sort is a faster path in case of ndim > 1
315+
return dpnp.sort(a, axis=axis)
240316

241-
return call_origin(numpy.partition, x1, kth, axis, kind, order)
317+
desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False)
318+
return dpnp_partition(desc, kth[0], axis, kind, order).get_pyobj()
242319

243320

244321
def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None):

0 commit comments

Comments
 (0)