Skip to content

Commit e663374

Browse files
committed
more careful about gaugefixing svd_trunc
1 parent e0956f1 commit e663374

2 files changed

Lines changed: 25 additions & 35 deletions

File tree

src/common/gauge.jl

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,53 +12,42 @@ is real and positive.
1212
function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix)
1313
for j in axes(V, 2)
1414
v = view(V, :, j)
15-
s = conj(sign(_argmaxabs(v)))
16-
@inbounds v .*= s
15+
s = sign(_argmaxabs(v))
16+
@inbounds v .*= conj(s)
1717
end
1818
return V
1919
end
2020

2121
function gaugefix!(::typeof(svd_full!), U, Vᴴ)
22-
m, n = size(U, 1), size(Vᴴ, 2)
22+
m, n = size(U, 2), size(Vᴴ, 1)
2323
for j in 1:max(m, n)
2424
if j <= min(m, n)
2525
u = view(U, :, j)
2626
v = view(Vᴴ, j, :)
27-
s = conj(sign(_argmaxabs(u)))
28-
u .*= s
29-
v .*= conj(s)
27+
s = sign(_argmaxabs(u))
28+
u .*= conj(s)
29+
v .*= s
3030
elseif j <= m
3131
u = view(U, :, j)
32-
s = conj(sign(_argmaxabs(u)))
33-
u .*= s
32+
s = sign(_argmaxabs(u))
33+
u .*= conj(s)
3434
else
3535
v = view(Vᴴ, j, :)
36-
s = conj(sign(_argmaxabs(v)))
37-
v .*= s
36+
s = sign(_argmaxabs(v))
37+
v .*= conj(s)
3838
end
3939
end
4040
return (U, Vᴴ)
4141
end
4242

43-
function gaugefix!(::typeof(svd_compact!), U, Vᴴ)
44-
for j in 1:size(U, 2)
45-
u = view(U, :, j)
46-
v = view(Vᴴ, j, :)
47-
s = conj(sign(_argmaxabs(u)))
48-
u .*= s
49-
v .*= conj(s)
50-
end
51-
return (U, Vᴴ)
52-
end
53-
54-
function gaugefix!(::typeof(svd_trunc!), U, Vᴴ)
55-
m, n = size(U, 1), size(Vᴴ, 2)
56-
for j in 1:min(m, n)
43+
function gaugefix!(::Union{typeof(svd_compact!), typeof(svd_trunc!)}, U, Vᴴ)
44+
@assert axes(U, 2) == axes(Vᴴ, 1)
45+
for j in axes(U, 2)
5746
u = view(U, :, j)
5847
v = view(Vᴴ, j, :)
59-
s = conj(sign(_argmaxabs(u)))
60-
u .*= s
61-
v .*= conj(s)
48+
s = sign(_argmaxabs(u))
49+
u .*= conj(s)
50+
v .*= s
6251
end
6352
return (U, Vᴴ)
6453
end

src/implementations/svd.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Inputs
1+
# Input
22
# ------
33
copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A)
44
copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A)
@@ -362,15 +362,16 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
362362
U, S, Vᴴ = USVᴴ
363363
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
364364

365-
do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool
366-
do_gauge_fix && gaugefix!(svd_trunc!, U, Vᴴ)
367-
368365
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
369-
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
370-
Strunc = diagview(USVᴴtrunc[2])
366+
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
367+
371368
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
372-
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
373-
return USVᴴtrunc..., ϵ
369+
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
370+
371+
do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool
372+
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
373+
374+
return Utr, Str, Vᴴtr, ϵ
374375
end
375376

376377
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)

0 commit comments

Comments
 (0)