@@ -326,3 +326,110 @@ def pad(
326326 return xp .nn .functional .pad (x , tuple (pad_width ), value = constant_values ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
327327
328328 return _funcs .pad (x , pad_width , constant_values = constant_values , xp = xp )
329+
330+
331+ def partition (
332+ a : Array ,
333+ kth : int ,
334+ * ,
335+ xp : ModuleType | None = None ,
336+ ) -> Array :
337+ """
338+ Return a partitioned copy of an array.
339+
340+ Parameters
341+ ----------
342+ a : 1-dimensional array
343+ Input array.
344+ kth : int
345+ Element index to partition by.
346+ xp : array_namespace, optional
347+ The standard-compatible namespace for `x`. Default: infer.
348+
349+ Returns
350+ -------
351+ partitioned_array
352+ Array of the same type and shape as a.
353+ """
354+ # Validate inputs.
355+ if xp is None :
356+ xp = array_namespace (a )
357+ if a .ndim != 1 :
358+ msg = "only 1-dimensional arrays are currently supported"
359+ raise NotImplementedError (msg )
360+
361+ # Delegate where possible.
362+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
363+ return xp .partition (a , kth )
364+ if is_jax_namespace (xp ):
365+ from jax import numpy
366+
367+ return numpy .partition (a , kth )
368+
369+ # Use top-k when possible:
370+ if is_torch_namespace (xp ):
371+ from torch import topk
372+
373+ a_left , indices_left = topk (a , kth , largest = False , sorted = False )
374+ mask_right = xp .ones (a .shape , dtype = bool )
375+ mask_right [indices_left ] = False
376+ return xp .concat ((a_left , a [mask_right ]))
377+ # Note: dask topk/argtopk sort the return values, so it's
378+ # not much more efficient than sorting everything when
379+ # kth is not small compared to x.size
380+
381+ return _funcs .partition (a , kth , xp = xp )
382+
383+
384+ def argpartition (
385+ a : Array ,
386+ kth : int ,
387+ * ,
388+ xp : ModuleType | None = None ,
389+ ) -> Array :
390+ """
391+ Perform an indirect partition along the given axis.
392+
393+ Parameters
394+ ----------
395+ a : 1-dimensional array
396+ Input array.
397+ kth : int
398+ Element index to partition by.
399+ xp : array_namespace, optional
400+ The standard-compatible namespace for `x`. Default: infer.
401+
402+ Returns
403+ -------
404+ index_array
405+ Array of indices that partition `a` along the specified axis.
406+ """
407+ # Validate inputs.
408+ if xp is None :
409+ xp = array_namespace (a )
410+ if a .ndim != 1 :
411+ msg = "only 1-dimensional arrays are currently supported"
412+ raise NotImplementedError (msg )
413+
414+ # Delegate where possible.
415+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
416+ return xp .argpartition (a , kth )
417+ if is_jax_namespace (xp ):
418+ from jax import numpy
419+
420+ return numpy .argpartition (a , kth )
421+
422+ # Use top-k when possible:
423+ if is_torch_namespace (xp ):
424+ from torch import topk
425+
426+ _ , indices = topk (a , kth , largest = False , sorted = False )
427+ mask = xp .ones (a .shape , dtype = bool )
428+ mask [indices ] = False
429+ indices_above = xp .arange (a .shape [0 ])[mask ]
430+ return xp .concat ((indices , indices_above ))
431+ # Note: dask topk/argtopk sort the return values, so it's
432+ # not much more efficient than sorting everything when
433+ # kth is not small compared to x.size
434+
435+ return _funcs .argpartition (a , kth , xp = xp )
0 commit comments