Skip to content

Commit c90f0b2

Browse files
committed
Update third party tests
1 parent 3967a79 commit c90f0b2

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

dpnp/tests/third_party/cupy/sorting_tests/test_sort.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import unittest
24

35
import numpy
@@ -455,7 +457,6 @@ def test_sort_complex_nan(self, xp, dtype):
455457
}
456458
)
457459
)
458-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
459460
class TestPartition(unittest.TestCase):
460461

461462
def partition(self, a, kth, axis=-1):
@@ -478,17 +479,13 @@ def test_partition_zero_dim(self):
478479
@testing.for_all_dtypes()
479480
@testing.numpy_cupy_equal()
480481
def test_partition_one_dim(self, xp, dtype):
481-
flag = xp.issubdtype(dtype, xp.unsignedinteger)
482-
if flag or dtype in [xp.int8, xp.int16]:
483-
pytest.skip("dpnp.partition() does not support new integer dtypes.")
484482
a = testing.shaped_random((self.length,), xp, dtype)
485483
kth = 2
486484
x = self.partition(a, kth)
487485
assert xp.all(x[0:kth] <= x[kth : kth + 1])
488486
assert xp.all(x[kth : kth + 1] <= x[kth + 1 :])
489487
return x[kth]
490488

491-
@pytest.mark.skip("multidimensional case doesn't work properly")
492489
@testing.for_all_dtypes()
493490
@testing.numpy_cupy_array_equal()
494491
def test_partition_multi_dim(self, xp, dtype):
@@ -505,6 +502,12 @@ def test_partition_multi_dim(self, xp, dtype):
505502
def test_partition_non_contiguous(self, xp):
506503
a = testing.shaped_random((self.length,), xp)[::-1]
507504
kth = 2
505+
# if not self.external:
506+
# if xp is cupy:
507+
# with self.assertRaises(NotImplementedError):
508+
# return self.partition(a, kth)
509+
# return 0 # dummy
510+
# else:
508511
x = self.partition(a, kth)
509512
assert xp.all(x[0:kth] <= x[kth : kth + 1])
510513
assert xp.all(x[kth : kth + 1] <= x[kth + 1 :])
@@ -607,7 +610,7 @@ def test_partition_invalid_negative_axis2(self):
607610
}
608611
)
609612
)
610-
@pytest.mark.skip("not fully supported yet")
613+
@pytest.mark.skip("not supported yet")
611614
class TestArgpartition(unittest.TestCase):
612615

613616
def argpartition(self, a, kth, axis=-1):

0 commit comments

Comments
 (0)