@@ -175,7 +175,7 @@ const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
175175const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
176176const GPU_Randomized = Union{CUSOLVER_Randomized}
177177
178- function check_input (:: typeof (svd_compact !), A:: AbstractMatrix , USVᴴ, :: CUSOLVER_Randomized )
178+ function check_input (:: typeof (svd_trunc !), A:: AbstractMatrix , USVᴴ, alg :: TruncatedAlgorithm{ CUSOLVER_Randomized} )
179179 m, n = size (A)
180180 minmn = min (m, n)
181181 U, S, Vᴴ = USVᴴ
@@ -189,7 +189,7 @@ function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, ::CUSOLV
189189 return nothing
190190end
191191
192- function initialize_output (:: typeof (svd_compact !), A:: AbstractMatrix , :: CUSOLVER_Randomized )
192+ function initialize_output (:: typeof (svd_trunc !), A:: AbstractMatrix , alg :: TruncatedAlgorithm{ CUSOLVER_Randomized} )
193193 m, n = size (A)
194194 minmn = min (m, n)
195195 U = similar (A, (m, m))
@@ -231,6 +231,15 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
231231 return USVᴴ
232232end
233233
234+ function svd_trunc! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
235+ check_input (svd_trunc!, A, USVᴴ, alg)
236+ U, S, Vᴴ = USVᴴ
237+ _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. kwargs... )
238+ # TODO : make this controllable using a `gaugefix` keyword argument
239+ gaugefix! (Val (:compact ), U, S, Vᴴ, m, n)
240+ return truncate! (svd_trunc!, USVᴴ′, alg. trunc)
241+ end
242+
234243function MatrixAlgebraKit. svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
235244 check_input (svd_compact!, A, USVᴴ, alg)
236245 U, S, Vᴴ = USVᴴ
@@ -242,8 +251,6 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
242251 _gpu_Xgesvdp! (A, S. diag, U, Vᴴ; alg. kwargs... )
243252 elseif alg isa GPU_Jacobi
244253 _gpu_gesvdj! (A, S. diag, U, Vᴴ; alg. kwargs... )
245- elseif alg isa GPU_Randomized
246- _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. kwargs... )
247254 else
248255 throw (ArgumentError (" Unsupported SVD algorithm" ))
249256 end
0 commit comments