Skip to content

Commit 564a788

Browse files
author
Katharine Hyatt
committed
Move randomized SVD to svd_trunc
1 parent 562afe2 commit 564a788

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

src/implementations/svd.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
175175
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
176176
const GPU_Randomized = Union{CUSOLVER_Randomized}
177177

178-
function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, ::CUSOLVER_Randomized)
178+
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{CUSOLVER_Randomized})
179179
m, n = size(A)
180180
minmn = min(m, n)
181181
U, S, Vᴴ = USVᴴ
@@ -189,7 +189,7 @@ function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, ::CUSOLV
189189
return nothing
190190
end
191191

192-
function initialize_output(::typeof(svd_compact!), A::AbstractMatrix, ::CUSOLVER_Randomized)
192+
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{CUSOLVER_Randomized})
193193
m, n = size(A)
194194
minmn = min(m, n)
195195
U = similar(A, (m, m))
@@ -231,6 +231,15 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
231231
return USVᴴ
232232
end
233233

234+
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
235+
check_input(svd_trunc!, A, USVᴴ, alg)
236+
U, S, Vᴴ = USVᴴ
237+
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.kwargs...)
238+
# TODO: make this controllable using a `gaugefix` keyword argument
239+
gaugefix!(Val(:compact), U, S, Vᴴ, m, n)
240+
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
241+
end
242+
234243
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
235244
check_input(svd_compact!, A, USVᴴ, alg)
236245
U, S, Vᴴ = USVᴴ
@@ -242,8 +251,6 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
242251
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
243252
elseif alg isa GPU_Jacobi
244253
_gpu_gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...)
245-
elseif alg isa GPU_Randomized
246-
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.kwargs...)
247254
else
248255
throw(ArgumentError("Unsupported SVD algorithm"))
249256
end

0 commit comments

Comments
 (0)