|
39 | 39 |
|
40 | 40 | """ |
41 | 41 |
|
| 42 | +from collections.abc import Sequence |
| 43 | + |
42 | 44 | import dpctl.tensor as dpt |
43 | | -import numpy |
44 | 45 | from dpctl.tensor._numpy_helper import normalize_axis_index |
45 | 46 |
|
46 | 47 | import dpnp |
|
51 | 52 | ) |
52 | 53 | from .dpnp_array import dpnp_array |
53 | 54 | from .dpnp_utils import ( |
54 | | - call_origin, |
55 | 55 | map_dtype_to_device, |
56 | 56 | ) |
57 | 57 |
|
@@ -147,7 +147,7 @@ def argsort( |
147 | 147 |
|
148 | 148 | Limitations |
149 | 149 | ----------- |
150 | | - Parameters `order` is only supported with its default value. |
| 150 | + Parameter `order` is only supported with its default value. |
151 | 151 | Otherwise ``NotImplementedError`` exception will be raised. |
152 | 152 | Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported. |
153 | 153 |
|
@@ -201,44 +201,121 @@ def argsort( |
201 | 201 | ) |
202 | 202 |
|
203 | 203 |
|
204 | | -def partition(x1, kth, axis=-1, kind="introselect", order=None): |
| 204 | +def partition(a, kth, axis=-1, kind="introselect", order=None): |
205 | 205 | """ |
206 | 206 | Return a partitioned copy of an array. |
207 | 207 |
|
208 | 208 | For full documentation refer to :obj:`numpy.partition`. |
209 | 209 |
|
| 210 | + Parameters |
| 211 | + ---------- |
| 212 | + a : {dpnp.ndarray, usm_ndarray} |
| 213 | + Array to be sorted. |
| 214 | + kth : {int, sequence of ints} |
| 215 | + Element index to partition by. The k-th value of the element will be in |
| 216 | + its final sorted position and all smaller elements will be moved before |
| 217 | + it and all equal or greater elements behind it. The order of all |
| 218 | + elements in the partitions is undefined. If provided with a sequence of |
| 219 | + k-th it will partition all elements indexed by k-th of them into their |
| 220 | + sorted position at once. |
| 221 | + axis : {None, int}, optional |
| 222 | + Axis along which to sort. If ``None``, the array is flattened before |
| 223 | + sorting. The default is ``-1``, which sorts along the last axis. |
| 224 | +
|
| 225 | + Default: ``-1``. |
| 226 | +
|
| 227 | + Returns |
| 228 | + ------- |
| 229 | + out : dpnp.ndarray |
| 230 | + Array of the same type and shape as `a`. |
| 231 | +
|
210 | 232 | Limitations |
211 | 233 | ----------- |
212 | | - Input array is supported as :obj:`dpnp.ndarray`. |
213 | | - Input `kth` is supported as :obj:`int`. |
214 | | - Parameters `axis`, `kind` and `order` are supported only with default |
215 | | - values. |
| 234 | + Parameters `kind` and `order` are only supported with its default value. |
| 235 | + Otherwise ``NotImplementedError`` exception will be raised. |
| 236 | +
|
| 237 | + See Also |
| 238 | + -------- |
| 239 | + :obj:`dpnp.ndarray.partition` : Equivalent method. |
| 240 | + :obj:`dpnp.argpartition` : Indirect partition. |
| 241 | + :obj:`dpnp.sort` : Full sorting. |
| 242 | +
|
| 243 | + Examples |
| 244 | + -------- |
| 245 | + >>> import dpnp as np |
| 246 | + >>> a = np.array([7, 1, 7, 7, 1, 5, 7, 2, 3, 2, 6, 2, 3, 0]) |
| 247 | + >>> p = np.partition(a, 4) |
| 248 | + >>> p |
| 249 | + array([0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 7, 7, 6]) # may vary |
| 250 | +
|
| 251 | + ``p[4]`` is 2; all elements in ``p[:4]`` are less than or equal to |
| 252 | + ``p[4]``, and all elements in ``p[5:]`` are greater than or equal to |
| 253 | + ``p[4]``. The partition is:: |
| 254 | +
|
| 255 | + [0, 1, 1, 2], [2], [2, 3, 3, 5, 7, 7, 7, 7, 6] |
| 256 | +
|
| 257 | + The next example shows the use of multiple values passed to `kth`. |
| 258 | +
|
| 259 | + >>> p2 = np.partition(a, (4, 8)) |
| 260 | + >>> p2 |
| 261 | + array([0, 1, 1, 2, 2, 2, 3, 3, 5, 6, 7, 7, 7, 7]) |
| 262 | +
|
| 263 | + ``p2[4]`` is 2 and ``p2[8]`` is 5. All elements in ``p2[:4]`` are less |
| 264 | + than or equal to ``p2[4]``, all elements in ``p2[5:8]`` are greater than or |
| 265 | + equal to ``p2[4]`` and less than or equal to ``p2[8]``, and all elements in |
| 266 | + ``p2[9:]`` are greater than or equal to ``p2[8]``. The partition is:: |
| 267 | +
|
| 268 | + [0, 1, 1, 2], [2], [2, 3, 3], [5], [6, 7, 7, 7, 7] |
216 | 269 |
|
217 | 270 | """ |
218 | 271 |
|
219 | | - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
220 | | - if x1_desc: |
221 | | - if dpnp.is_cuda_backend(x1_desc.get_array()): # pragma: no cover |
222 | | - raise NotImplementedError( |
223 | | - "Running on CUDA is currently not supported" |
224 | | - ) |
| 272 | + dpnp.check_supported_arrays_type(a) |
225 | 273 |
|
226 | | - if not isinstance(kth, int): |
227 | | - pass |
228 | | - elif x1_desc.ndim == 0: |
229 | | - pass |
230 | | - elif kth >= x1_desc.shape[x1_desc.ndim - 1] or x1_desc.ndim + kth < 0: |
231 | | - pass |
232 | | - elif axis != -1: |
233 | | - pass |
234 | | - elif kind != "introselect": |
235 | | - pass |
236 | | - elif order is not None: |
237 | | - pass |
238 | | - else: |
239 | | - return dpnp_partition(x1_desc, kth, axis, kind, order).get_pyobj() |
| 274 | + if kind != "introselect": |
| 275 | + raise NotImplementedError( |
| 276 | + "`kind` keyword argument is only supported with its default value." |
| 277 | + ) |
| 278 | + if order is not None: |
| 279 | + raise NotImplementedError( |
| 280 | + "`order` keyword argument is only supported with its default value." |
| 281 | + ) |
| 282 | + |
| 283 | + if axis is None: |
| 284 | + a = dpnp.ravel(a) |
| 285 | + axis = -1 |
| 286 | + |
| 287 | + nd = a.ndim |
| 288 | + axis = normalize_axis_index(axis, nd) |
| 289 | + length = a.shape[axis] |
| 290 | + |
| 291 | + if isinstance(kth, int): |
| 292 | + kth = (kth,) |
| 293 | + elif not isinstance(kth, Sequence): |
| 294 | + raise TypeError( |
| 295 | + f"kth must be int or sequence of ints, but got {type(kth)}" |
| 296 | + ) |
| 297 | + elif not all(isinstance(k, int) for k in kth): |
| 298 | + raise TypeError("kth is a sequence, but not all elements are integers") |
| 299 | + |
| 300 | + nkth = len(kth) |
| 301 | + if nkth == 0 or a.size == 0: |
| 302 | + return dpnp.copy(a) |
| 303 | + |
| 304 | + # validate kth |
| 305 | + kth = list(kth) |
| 306 | + for i in range(nkth): |
| 307 | + if kth[i] < 0: |
| 308 | + kth[i] += length |
| 309 | + |
| 310 | + if not 0 <= kth[i] < length: |
| 311 | + raise ValueError(f"kth(={kth[i]}) out of bounds {length}") |
| 312 | + |
| 313 | + if a.ndim > 1 or nkth > 1 or dpnp.is_cuda_backend(a.get_array()): |
| 314 | + # sort is a faster path in case of ndim > 1 |
| 315 | + return dpnp.sort(a, axis=axis) |
240 | 316 |
|
241 | | - return call_origin(numpy.partition, x1, kth, axis, kind, order) |
| 317 | + desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False) |
| 318 | + return dpnp_partition(desc, kth[0], axis, kind, order).get_pyobj() |
242 | 319 |
|
243 | 320 |
|
244 | 321 | def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None): |
|
0 commit comments