@@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt
33using MatrixAlgebraKit
44using MatrixAlgebraKit: @algdef , Algorithm, check_input
55using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
6- using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
6+ using MatrixAlgebraKit: diagview, sign_safe
77using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
195195MatrixAlgebraKit. _ind_intersect (A:: CuVector{Int} , B:: CuVector{Int} ) =
196196 MatrixAlgebraKit. _ind_intersect (collect (A), collect (B))
197197
198- MatrixAlgebraKit. default_pullback_rank_atol (A:: AnyCuArray ) = eps (norm (CuArray (A), Inf ))^ (3 / 4 )
199- MatrixAlgebraKit. default_pullback_gauge_atol (A:: AnyCuArray ) = MatrixAlgebraKit. iszerotangent (A) ? 0 : eps (norm (CuArray (A), Inf ))^ (3 / 4 )
200- function MatrixAlgebraKit. default_pullback_gauge_atol (A:: AnyCuArray , As... )
201- As′ = filter (! MatrixAlgebraKit. iszerotangent, (A, As... ))
202- return isempty (As′) ? 0 : eps (norm (CuArray .(As′), Inf ))^ (3 / 4 )
203- end
204-
205198function _sylvester (A:: AnyCuMatrix , B:: AnyCuMatrix , C:: AnyCuMatrix )
206199 # https://github.com/JuliaGPU/CUDA.jl/issues/3021
207200 # to add native sylvester to CUDA
0 commit comments