1+ from __future__ import annotations
2+
13import unittest
24
35import numpy
46import pytest
57
68import dpnp as cupy
9+
10+ # from cupyx import cusolver
11+ # from cupy.cuda import driver
12+ # from cupy.cuda import runtime
13+ # from cupy.linalg import _util
714from dpnp .tests .helper import (
15+ LTS_VERSION ,
816 has_support_aspect64 ,
9- is_cpu_device ,
17+ is_lts_driver ,
18+ is_win_platform ,
1019)
1120from dpnp .tests .third_party .cupy import testing
1221from dpnp .tests .third_party .cupy .testing import _condition
1322
23+ # import cupyx
24+
1425
1526def random_matrix (shape , dtype , scale , sym = False ):
1627 m , n = shape [- 2 :]
@@ -95,6 +106,8 @@ def test_decomposition(self, dtype):
95106 ]
96107 )
97108 def test_batched_decomposition (self , dtype ):
109+ # if not cusolver.check_availability("potrfBatched"):
110+ # pytest.skip("potrfBatched is not available")
98111 Ab1 = random_matrix ((3 , 5 , 5 ), dtype , scale = (10 , 10000 ), sym = True )
99112 self .check_L (Ab1 )
100113 Ab2 = random_matrix ((2 , 2 , 5 , 5 ), dtype , scale = (10 , 10000 ), sym = True )
@@ -134,9 +147,6 @@ def check_L(self, array):
134147 with pytest .raises (xp .linalg .LinAlgError ):
135148 xp .linalg .cholesky (a )
136149
137- # TODO: remove skipif when MKLD-17318 is resolved
138- # _potrf does not raise an error with singular matrices on CPU.
139- @pytest .mark .skipif (is_cpu_device (), reason = "MKLD-17318" )
140150 @testing .for_dtypes (
141151 [
142152 numpy .int32 ,
@@ -163,6 +173,10 @@ class TestQRDecomposition(unittest.TestCase):
163173
164174 @testing .for_dtypes ("fdFD" )
165175 def check_mode (self , array , mode , dtype ):
176+ # if runtime.is_hip and driver.get_build_version() < 307:
177+ # if dtype in (numpy.complex64, numpy.complex128):
178+ # pytest.skip("ungqr unsupported")
179+
166180 a_cpu = numpy .asarray (array , dtype = dtype )
167181 a_gpu = cupy .asarray (array , dtype = dtype )
168182 result_gpu = cupy .linalg .qr (a_gpu , mode = mode )
@@ -189,13 +203,21 @@ def test_mode(self):
189203 self .check_mode (numpy .random .randn (3 , 3 ), mode = self .mode )
190204 self .check_mode (numpy .random .randn (5 , 4 ), mode = self .mode )
191205
206+ @pytest .mark .skipif (
207+ not is_win_platform () and is_lts_driver (version = LTS_VERSION .V1_6 ),
208+ reason = "SAT-8375" ,
209+ )
192210 @testing .with_requires ("numpy>=1.22" )
193211 @testing .fix_random ()
194212 def test_mode_rank3 (self ):
195213 self .check_mode (numpy .random .randn (3 , 2 , 4 ), mode = self .mode )
196214 self .check_mode (numpy .random .randn (4 , 3 , 3 ), mode = self .mode )
197215 self .check_mode (numpy .random .randn (2 , 5 , 4 ), mode = self .mode )
198216
217+ @pytest .mark .skipif (
218+ not is_win_platform () and is_lts_driver (version = LTS_VERSION .V1_6 ),
219+ reason = "SAT-8375" ,
220+ )
199221 @testing .with_requires ("numpy>=1.22" )
200222 @testing .fix_random ()
201223 def test_mode_rank4 (self ):
0 commit comments