@@ -4,6 +4,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltyp
# The compact, values-only and truncated SVD variants forward to the `svd_full`
# method of `copy_input`, either directly or through `svd_compact`.
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A)
copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A)
copy_input(::typeof(svd_trunc_with_err), A) = copy_input(svd_compact, A)

# A `Diagonal` input is duplicated with a plain `copy`.
copy_input(::typeof(svd_full), A::Diagonal) = copy(A)
910
# Output for the truncated SVD is obtained from the `svd_compact!` method of
# `initialize_output`, called with the algorithm wrapped inside `alg`.
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
    inner_alg = alg.alg
    return initialize_output(svd_compact!, A, inner_alg)
end
# The error-reporting truncated SVD takes its output from the `svd_compact!`
# method of `initialize_output`, called with the algorithm wrapped inside `alg`.
function initialize_output(
        ::typeof(svd_trunc_with_err!), A, alg::TruncatedAlgorithm
    )
    inner_alg = alg.alg
    return initialize_output(svd_compact!, A, inner_alg)
end
9599
96100function initialize_output (:: typeof (svd_full!), A:: Diagonal , :: DiagonalAlgorithm )
97101 TA = eltype (A)
@@ -206,19 +210,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206210 return S
207211end
208212
# Truncated SVD: compute the compact SVD with the wrapped algorithm `alg.alg`,
# then truncate it according to `alg.trunc`.
# Returns only the truncated factors; the second value returned by `truncate`
# is not needed here.
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
    compact = svd_compact!(A, USVᴴ, alg.alg)
    U, S, Vᴴ = compact
    truncated = first(truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc))
    return truncated
end
214218
# Truncated SVD that also reports a truncation error.
# The compact SVD is computed with the wrapped algorithm `alg.alg` and truncated
# according to `alg.trunc`. The error is then obtained by calling
# `truncation_error!` on the diagonal view of the untruncated `S` together with
# the second value returned by `truncate`.
# Returns the truncated factors followed by the error as one flat tuple.
function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm)
    U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
    truncated, kept = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
    # Must run after `truncate`: the `!` suggests it may overwrite `S`.
    err = truncation_error!(diagview(S), kept)
    return (truncated..., err)
end
224225
@@ -287,6 +288,22 @@ function check_input(
287288 return nothing
288289end
289290
# Validate the output containers `USVᴴ = (U, S, Vᴴ)` passed to
# `svd_trunc_with_err!` for a `CUSOLVER_Randomized` algorithm.
# For an `m × n` input `A` the checks require:
#   * `U`  : an `AbstractMatrix` of size `(m, m)`
#   * `S`  : a `Diagonal` of size `(min(m, n), min(m, n))`
#   * `Vᴴ` : an `AbstractMatrix` of size `(n, n)`
# `@check_scalar` presumably compares element types against `A` (through `real`
# for `S`); confirm against the macro definition.
# Returns `nothing` when every check passes.
function check_input(
        ::typeof(svd_trunc_with_err!), A::AbstractMatrix, USVᴴ,
        alg::CUSOLVER_Randomized
    )
    m = size(A, 1)
    n = size(A, 2)
    minmn = min(m, n)
    U, S, Vᴴ = USVᴴ
    @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
    # Left factor
    @check_size(U, (m, m))
    @check_scalar(U, A)
    # Singular values
    @check_size(S, (minmn, minmn))
    @check_scalar(S, A, real)
    # Right factor
    @check_size(Vᴴ, (n, n))
    @check_scalar(Vᴴ, A)
    return nothing
end
306+
290307function initialize_output (
291308 :: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized}
292309 )
@@ -298,6 +315,17 @@ function initialize_output(
298315 return (U, S, Vᴴ)
299316end
300317
# Allocate output containers for `svd_trunc_with_err!` with a truncated
# `CUSOLVER_Randomized` algorithm. For an `m × n` input `A` this returns
# `(U, S, Vᴴ)`: an `m × m` array and an `n × n` array, both created with
# `similar(A, dims)`, and a `Diagonal` over `min(m, n)` entries whose element
# type is `real(eltype(A))`. The contents are not initialised here.
function initialize_output(
        ::typeof(svd_trunc_with_err!), A::AbstractMatrix,
        alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}
    )
    nrows, ncols = size(A)
    nvals = min(nrows, ncols)
    left = similar(A, (nrows, nrows))
    svals = similar(A, real(eltype(A)), (nvals,))
    right = similar(A, (ncols, ncols))
    return (left, Diagonal(svals), right)
end
328+
301329function _gpu_gesvd! (
302330 A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix
303331 )
@@ -372,22 +400,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
372400 return USVᴴ
373401end
374402
# Truncated SVD for a truncated `GPU_Randomized` algorithm.
#
# `USVᴴ` must destructure into `(U, S, Vᴴ)`, where `S` exposes its storage as
# `S.diag`. The inputs are validated with `check_input`, the factorisation is
# run by `_gpu_Xgesvdr!` with the keyword arguments stored in `alg.alg.kwargs`,
# and the result is truncated according to `alg.trunc`.
#
# Returns the truncated factors `(Utr, Str, Vᴴtr)` only, like the generic
# `svd_trunc!` method. No truncation error is computed in this method, so none
# is returned; use `svd_trunc_with_err!` when the error is required.
function svd_trunc!(
        A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}
    )
    U, S, Vᴴ = USVᴴ
    check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
    _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

    # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
    (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

    # Gauge fixing is optional and controlled by the `:fixgauge` keyword,
    # falling back to `default_fixgauge()` when it is absent.
    do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
    do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)

    # `ϵ` is never bound in this method; returning it would throw `UndefVarError`.
    return Utr, Str, Vᴴtr
end
416+
417+ function svd_trunc_with_err! (A:: AbstractMatrix , USVᴴ, alg:: TruncatedAlgorithm{<:GPU_Randomized} )
418+ U, S, Vᴴ = USVᴴ
419+ check_input (svd_trunc!, A, (U, S, Vᴴ), alg. alg)
420+ _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. alg. kwargs... )
421+
422+ # TODO : make sure that truncation is based on maxrank, otherwise this might be wrong
423+ (Utr, Str, Vᴴtr), _ = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
424+
425+ # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
426+ normS = norm (diagview (Str))
427+ normA = norm (A)
428+ # equivalent to sqrt(normA^2 - normS^2)
429+ # but may be more accurate
430+ ϵ = sqrt ((normA + normS) * (normA - normS))
391431
392432 do_gauge_fix = get (alg. alg. kwargs, :fixgauge , default_fixgauge ()):: Bool
393433 do_gauge_fix && gaugefix! (svd_trunc!, Utr, Vᴴtr)
0 commit comments