@@ -331,6 +331,8 @@ def pad(
331331def partition (
332332 a : Array ,
333333 kth : int ,
334+ / ,
335+ axis : int | None = - 1 ,
334336 * ,
335337 xp : ModuleType | None = None ,
336338) -> Array :
@@ -343,6 +345,9 @@ def partition(
343345 Input array.
344346 kth : int
345347 Element index to partition by.
348+ axis : int, optional
349+ Axis along which to partition. The default is -1 (the last axis).
350+ If None, the flattened array is used.
346351 xp : array_namespace, optional
347352 The standard-compatible namespace for `x`. Default: infer.
348353
@@ -354,36 +359,61 @@ def partition(
354359 # Validate inputs.
355360 if xp is None :
356361 xp = array_namespace (a )
357- if a .ndim != 1 :
358- msg = "only 1-dimensional arrays are currently supported"
359- raise NotImplementedError (msg )
362+ if a .ndim < 1 :
363+ msg = "`a` must be at least 1-dimensional"
364+ raise TypeError (msg )
365+ if axis is None :
366+ return partition (xp .reshape (a , - 1 ), kth , axis = 0 , xp = xp )
367+ size = a .shape [axis ]
368+ if size is None :
369+ msg = "Array dimensions must be known"
370+ raise ValueError (msg )
371+ if not (0 <= kth < size ):
372+ msg = f"kth(={ kth } ) out of bounds [0 { size } )"
373+ raise ValueError (msg )
360374
361375 # Delegate where possible.
362- if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
363- return xp .partition (a , kth )
364- if is_jax_namespace (xp ):
365- from jax import numpy
366-
367- return numpy .partition (a , kth )
376+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ) or is_jax_namespace (xp ):
377+ return xp .partition (a , kth , axis = axis )
368378
369379 # Use top-k when possible:
370380 if is_torch_namespace (xp ):
371- from torch import topk
381+ if not (axis == - 1 or axis == a .ndim - 1 ):
382+ a = xp .transpose (a , axis , - 1 )
372383
373- a_left , indices_left = topk (a , kth , largest = False , sorted = False )
384+ # Get smallest `kth` elements along axis
385+ kth += 1 # HACK: we use a non-specified behavior of torch.topk:
386+ # in `a_left`, the element in the last position is the max
387+ a_left , indices = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
388+
389+ # Build a mask to remove the selected elements
374390 mask_right = xp .ones (a .shape , dtype = bool )
375- mask_right [indices_left ] = False
376- return xp .concat ((a_left , a [mask_right ]))
391+ mask_right .scatter_ (dim = - 1 , index = indices , value = False )
392+
393+ # Remaining elements along axis
394+ a_right = a [mask_right ] # 1-d array
395+
396+ # Reshape. This is valid only because we work on the last axis
397+ a_right = xp .reshape (a_right , shape = (* a .shape [:- 1 ], - 1 ))
398+
399+ # Concatenate the two parts along axis
400+ partitioned_array = xp .cat ((a_left , a_right ), dim = - 1 )
401+ if not (axis == - 1 or axis == a .ndim - 1 ):
402+ partitioned_array = xp .transpose (partitioned_array , axis , - 1 )
403+ return partitioned_array
404+
377405 # Note: dask topk/argtopk sort the return values, so it's
378406 # not much more efficient than sorting everything when
379407 # kth is not small compared to x.size
380408
381- return _funcs .partition (a , kth , xp = xp )
409+ return _funcs .partition (a , kth , axis = axis , xp = xp )
382410
383411
384412def argpartition (
385413 a : Array ,
386414 kth : int ,
415+ / ,
416+ axis : int | None = - 1 ,
387417 * ,
388418 xp : ModuleType | None = None ,
389419) -> Array :
@@ -392,10 +422,13 @@ def argpartition(
392422
393423 Parameters
394424 ----------
395- a : 1-dimensional array
425+ a : Array
396426 Input array.
397427 kth : int
398428 Element index to partition by.
429+ axis : int, optional
430+ Axis along which to partition. The default is -1 (the last axis).
431+ If None, the flattened array is used.
399432 xp : array_namespace, optional
400433 The standard-compatible namespace for `x`. Default: infer.
401434
@@ -407,29 +440,46 @@ def argpartition(
407440 # Validate inputs.
408441 if xp is None :
409442 xp = array_namespace (a )
410- if a .ndim != 1 :
411- msg = "only 1-dimensional arrays are currently supported"
412- raise NotImplementedError (msg )
443+ if a .ndim < 1 :
444+ msg = "`a` must be at least 1-dimensional"
445+ raise TypeError (msg )
446+ if axis is None :
447+ return partition (xp .reshape (a , - 1 ), kth , axis = 0 , xp = xp )
448+ size = a .shape [axis ]
449+ if size is None :
450+ msg = "Array dimensions must be known"
451+ raise ValueError (msg )
452+ if not (0 <= kth < size ):
453+ msg = f"kth(={ kth } ) out of bounds [0 { size } )"
454+ raise ValueError (msg )
413455
414456 # Delegate where possible.
415- if is_numpy_namespace (xp ) or is_cupy_namespace (xp ):
416- return xp .argpartition (a , kth )
417- if is_jax_namespace (xp ):
418- from jax import numpy
419-
420- return numpy .argpartition (a , kth )
457+ if is_numpy_namespace (xp ) or is_cupy_namespace (xp ) or is_jax_namespace (xp ):
458+ return xp .argpartition (a , kth , axis = axis )
421459
422460 # Use top-k when possible:
423461 if is_torch_namespace (xp ):
424- from torch import topk
462+ # see `partition` above for commented details of those steps:
463+ if not (axis == - 1 or axis == a .ndim - 1 ):
464+ a = xp .transpose (a , axis , - 1 )
465+
466+ kth += 1 # HACK
467+ _ , indices_left = xp .topk (a , kth , dim = - 1 , largest = False , sorted = False )
468+
469+ mask_right = xp .ones (a .shape , dtype = bool )
470+ mask_right .scatter_ (dim = - 1 , index = indices_left , value = False )
471+
472+ indices_right = xp .nonzero (mask_right )[- 1 ]
473+ indices_right = xp .reshape (indices_right , shape = (* a .shape [:- 1 ], - 1 ))
474+
475+ # Concatenate the two parts along axis
476+ index_array = xp .cat ((indices_left , indices_right ), dim = - 1 )
477+ if not (axis == - 1 or axis == a .ndim - 1 ):
478+ index_array = xp .transpose (index_array , axis , - 1 )
479+ return index_array
425480
426- _ , indices = topk (a , kth , largest = False , sorted = False )
427- mask = xp .ones (a .shape , dtype = bool )
428- mask [indices ] = False
429- indices_above = xp .arange (a .shape [0 ])[mask ]
430- return xp .concat ((indices , indices_above ))
431481 # Note: dask topk/argtopk sort the return values, so it's
432482 # not much more efficient than sorting everything when
433483 # kth is not small compared to x.size
434484
435- return _funcs .argpartition (a , kth , xp = xp )
485+ return _funcs .argpartition (a , kth , axis = axis , xp = xp )
0 commit comments