1+ from __future__ import annotations
2+
13import unittest
24
35import numpy
46import pytest
57
68import dpnp as cupy
9+
10+ # from cupyx import cusolver
11+ # from cupy.cuda import driver
12+ # from cupy.cuda import runtime
13+ # from cupy.linalg import _util
714from dpnp .tests .helper import (
15+ LTS_VERSION ,
816 has_support_aspect64 ,
9- is_cpu_device ,
17+ is_lts_driver ,
1018)
1119from dpnp .tests .third_party .cupy import testing
1220from dpnp .tests .third_party .cupy .testing import _condition
1321
22+ # import cupyx
23+
1424
1525def random_matrix (shape , dtype , scale , sym = False ):
1626 m , n = shape [- 2 :]
@@ -95,6 +105,8 @@ def test_decomposition(self, dtype):
95105 ]
96106 )
97107 def test_batched_decomposition (self , dtype ):
108+ # if not cusolver.check_availability("potrfBatched"):
109+ # pytest.skip("potrfBatched is not available")
98110 Ab1 = random_matrix ((3 , 5 , 5 ), dtype , scale = (10 , 10000 ), sym = True )
99111 self .check_L (Ab1 )
100112 Ab2 = random_matrix ((2 , 2 , 5 , 5 ), dtype , scale = (10 , 10000 ), sym = True )
@@ -134,9 +146,6 @@ def check_L(self, array):
134146 with pytest .raises (xp .linalg .LinAlgError ):
135147 xp .linalg .cholesky (a )
136148
137- # TODO: remove skipif when MKLD-17318 is resolved
138- # _potrf does not raise an error with singular matrices on CPU.
139- @pytest .mark .skipif (is_cpu_device (), reason = "MKLD-17318" )
140149 @testing .for_dtypes (
141150 [
142151 numpy .int32 ,
@@ -163,6 +172,10 @@ class TestQRDecomposition(unittest.TestCase):
163172
164173 @testing .for_dtypes ("fdFD" )
165174 def check_mode (self , array , mode , dtype ):
175+ # if runtime.is_hip and driver.get_build_version() < 307:
176+ # if dtype in (numpy.complex64, numpy.complex128):
177+ # pytest.skip("ungqr unsupported")
178+
166179 a_cpu = numpy .asarray (array , dtype = dtype )
167180 a_gpu = cupy .asarray (array , dtype = dtype )
168181 result_gpu = cupy .linalg .qr (a_gpu , mode = mode )
@@ -189,13 +202,19 @@ def test_mode(self):
189202 self .check_mode (numpy .random .randn (3 , 3 ), mode = self .mode )
190203 self .check_mode (numpy .random .randn (5 , 4 ), mode = self .mode )
191204
205+ @pytest .mark .skipif (
206+ is_lts_driver (version = LTS_VERSION .V1_6 ), reason = "SAT-8375"
207+ )
192208 @testing .with_requires ("numpy>=1.22" )
193209 @testing .fix_random ()
194210 def test_mode_rank3 (self ):
195211 self .check_mode (numpy .random .randn (3 , 2 , 4 ), mode = self .mode )
196212 self .check_mode (numpy .random .randn (4 , 3 , 3 ), mode = self .mode )
197213 self .check_mode (numpy .random .randn (2 , 5 , 4 ), mode = self .mode )
198214
215+ @pytest .mark .skipif (
216+ is_lts_driver (version = LTS_VERSION .V1_6 ), reason = "SAT-8375"
217+ )
199218 @testing .with_requires ("numpy>=1.22" )
200219 @testing .fix_random ()
201220 def test_mode_rank4 (self ):
0 commit comments