3636from array_api_extra ._lib ._utils ._compat import (
3737 device as get_device ,
3838)
39- from array_api_extra ._lib ._utils ._compat import (
40- is_pydata_sparse_namespace ,
41- )
4239from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
4340from array_api_extra ._lib ._utils ._typing import Array , Device
4441from array_api_extra .testing import lazy_xp_function
@@ -1344,7 +1341,7 @@ def _assert_valid_partition(
13441341 def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
13451342 return partition (xp .asarray (x ), k , axis = axis )
13461343
1347- def test_1d (self , xp : ModuleType ):
1344+ def _test_1d (self , xp : ModuleType ):
13481345 rng = np .random .default_rng ()
13491346 for n in [2 , 3 , 4 , 5 , 7 , 10 , 20 , 50 , 100 , 1_000 ]:
13501347 k = int (rng .integers (n ))
@@ -1355,8 +1352,7 @@ def test_1d(self, xp: ModuleType):
13551352 y = self ._partition (x2 , k , xp )
13561353 self ._assert_valid_partition (x2 , k , y , xp )
13571354
1358- @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1359- def test_nd (self , xp : ModuleType , ndim : int ):
1355+ def _test_nd (self , xp : ModuleType , ndim : int ):
13601356 rng = np .random .default_rng ()
13611357
13621358 for n in [2 , 3 , 5 , 10 , 20 , 100 ]:
@@ -1375,20 +1371,28 @@ def test_nd(self, xp: ModuleType, ndim: int):
13751371 y = self ._partition (z , k , xp , axis = None )
13761372 self ._assert_valid_partition (z , k , y , xp , axis = None )
13771373
1378- def test_input_validation (self , xp : ModuleType ):
1374+ def _test_input_validation (self , xp : ModuleType ):
13791375 with pytest .raises (TypeError ):
13801376 _ = self ._partition (np .asarray (1 ), 1 , xp )
13811377 with pytest .raises (ValueError , match = "out of bounds" ):
13821378 _ = self ._partition (np .asarray ([1 , 2 ]), 3 , xp )
13831379
1380+ def test_1d (self , xp : ModuleType ):
1381+ self ._test_1d (xp )
1382+
1383+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1384+ def test_nd (self , xp : ModuleType , ndim : int ):
1385+ self ._test_nd (xp , ndim )
1386+
1387+ def test_input_validation (self , xp : ModuleType ):
1388+ self ._test_input_validation (xp )
1389+
13841390
13851391@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
13861392class TestArgpartition (TestPartition ):
13871393 @classmethod
13881394 @override
13891395 def _partition (cls , x : np .ndarray , k : int , xp : ModuleType , axis : int | None = - 1 ):
1390- if is_pydata_sparse_namespace (xp ):
1391- pytest .xfail (reason = "Sparse backend has no argsort" )
13921396 arr = xp .asarray (x )
13931397 indices = argpartition (arr , k , axis = axis )
13941398 if axis is None :
@@ -1398,3 +1402,16 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1
13981402 if not hasattr (xp , "take_along_axis" ):
13991403 pytest .skip ("TODO: find an alternative to take_along_axis" )
14001404 return xp .take_along_axis (arr , indices , axis = axis )
1405+
1406+ @override
1407+ def test_1d (self , xp : ModuleType ):
1408+ self ._test_1d (xp )
1409+
1410+ @pytest .mark .parametrize ("ndim" , [2 , 3 , 4 ])
1411+ @override
1412+ def test_nd (self , xp : ModuleType , ndim : int ):
1413+ self ._test_nd (xp , ndim )
1414+
1415+ @override
1416+ def test_input_validation (self , xp : ModuleType ):
1417+ self ._test_input_validation (xp )
0 commit comments