Skip to content

Commit 5c5e25b

Browse files
authored
Add some basic support for svd_trunc_no_error (#390)
* Add some basic support for svd_trunc_no_error * Import svd_trunc_no_error into test
1 parent d14759b commit 5c5e25b

2 files changed

Lines changed: 177 additions & 32 deletions

File tree

src/utility/svd.jl

Lines changed: 149 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,24 @@ function MatrixAlgebraKit.svd_trunc!(t::AdjointTensorMap, alg::SVDAdjoint)
133133
return adjoint(vt), adjoint(s), adjoint(u), ϵ
134134
end
135135

136+
"""
    svd_trunc_no_error(t, alg::SVDAdjoint)
    svd_trunc_no_error!(t, alg::SVDAdjoint)

Entry points for `svd_trunc_no_error(!)` that dispatch on the `SVDAdjoint` algorithm.
Dispatching here is required because the adjoint (reverse rule) may be customized
depending on `alg`; for instance, `IterSVD` uses the adjoint for a truncated SVD
obtained from `KrylovKit.svdsolve`.

In contrast to `svd_trunc(!)`, the `_no_error(!)` variants do not compute the
truncation error. The non-mutating method operates on `copy(t)`.
"""
function MatrixAlgebraKit.svd_trunc_no_error(t, alg::SVDAdjoint)
    return svd_trunc_no_error!(copy(t), alg)
end
146+
# Unwrap the `SVDAdjoint` and forward to the decomposition algorithm stored in `fwd_alg`.
MatrixAlgebraKit.svd_trunc_no_error!(t, alg::SVDAdjoint) =
    svd_trunc_no_error!(t, alg.fwd_alg)
149+
# SVD of an adjoint tensor map: decompose the parent `t'` and return the adjoints of its
# factors in reversed order.
function MatrixAlgebraKit.svd_trunc_no_error!(t::AdjointTensorMap, alg::SVDAdjoint)
    u_parent, s_parent, vt_parent = svd_trunc_no_error!(t', alg)
    return vt_parent', s_parent', u_parent'
end
153+
136154
#
137155
## Forward algorithms
138156
#
@@ -172,22 +190,24 @@ deterministic_start_vector(t::AbstractMatrix) = ones(scalartype(t), size(t, 1))
172190

173191
# Compute SVD data block-wise using KrylovKit algorithm
174192
# TODO: redefine _empty_svdtensors, _create_svdtensors
175-
function MatrixAlgebraKit.svd_trunc!(f, alg::TruncatedAlgorithm{<:IterSVD})
193+
# Truncated SVD with an `IterSVD` forward algorithm, without computing the truncation error.
# Returns the three factors `(U, S, V)`.
function MatrixAlgebraKit.svd_trunc_no_error!(f, alg::TruncatedAlgorithm{<:IterSVD})
    if isempty(blocksectors(f))
        # early return; the specified algorithm doesn't matter here
        U, S, V = MatrixAlgebraKit.initialize_output(svd_compact!, f, DefaultAlgorithm())
        return U, S, V
    end

    # block-wise decomposition, then assemble the factors from the block data
    svd_data, dims = _compute_svddata!(f, alg.alg, alg.trunc)
    U, S, V = _create_svdtensors(f, svd_data, dims)
    return U, S, V
end
186205

206+
# Truncated SVD with an `IterSVD` forward algorithm, including the truncation error.
# Returns `(U, S, Vᴴ, truncation_error)`; the factors come from `svd_trunc_no_error!`.
function MatrixAlgebraKit.svd_trunc!(f, alg::TruncatedAlgorithm{<:IterSVD})
    U, S, Vᴴ = svd_trunc_no_error!(f, alg)

    # The truncation scheme must be read from `alg`: a bare `trunc` is not bound in this
    # function and would resolve to the global function `Base.trunc`, which makes the
    # `NoTruncation` check always false.
    skip_error = alg.trunc isa NoTruncation || isempty(blocksectors(f))
    truncation_error = skip_error ? abs(zero(scalartype(f))) : norm(U * S * Vᴴ - f)
    return U, S, Vᴴ, truncation_error
end
192212

193213
# Copy from TensorKit v0.14 internal functions
@@ -295,6 +315,36 @@ function ChainRulesCore.rrule(
295315
return (Ũ, S̃, Ṽ⁺, truncerror), svd_trunc!_full_pullback
296316
end
297317

318+
# svd_trunc_no_error! rrule wrapping MatrixAlgebraKit's svd_pullback!
319+
# https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/blob/b76c7bb60014ecfead6925d0df6cb4b8d7c2668a/src/pullbacks/svd.jl#L33
320+
function ChainRulesCore.rrule(
        ::typeof(svd_trunc_no_error!),
        t::AbstractTensorMap,
        alg::SVDAdjoint{F, R}
    ) where {F <: TruncatedAlgorithm{<:MatrixAlgebraKit.Algorithm}, R <: FullPullback}
    # Reverse rule for `svd_trunc_no_error!` with a `FullPullback`: the compact
    # decomposition is computed first and truncated afterwards, so that the pullback can
    # be handed the untruncated factors together with the kept indices `inds`.
    # TODO: filter out any decomposition algorithm that doesn't give access to the full spectrum
    fwd = alg.fwd_alg
    rrule_alg = alg.rrule_alg

    # requires access to the full decomposition
    U, S, V⁺ = svd_compact!(t, fwd.alg)
    USV⁺ = (U, S, V⁺)
    (Ũ, S̃, Ṽ⁺), inds = truncate(svd_trunc!, USV⁺, fwd.trunc)

    gtol = _get_pullback_gauge_tol(rrule_alg.verbosity)

    function svd_trunc_no_error!_full_pullback(ΔUSV′)
        cotangents = unthunk.(ΔUSV′)
        Δt = svd_pullback!(
            zeros(scalartype(t), space(t)), t, USV⁺, cotangents, inds;
            gauge_atol = gtol(cotangents),
            degeneracy_atol = rrule_alg.degeneracy_atol,
        )
        return NoTangent(), Δt, NoTangent()
    end
    # all-zero cotangents: nothing to propagate
    function svd_trunc_no_error!_full_pullback(
            ::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}
        )
        return NoTangent(), ZeroTangent(), NoTangent()
    end

    return (Ũ, S̃, Ṽ⁺), svd_trunc_no_error!_full_pullback
end
347+
298348
# svd_trunc! rrule wrapping MatrixAlgebraKit's svd_trunc_pullback! (also works for IterSVD)
299349
# https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/blob/b76c7bb60014ecfead6925d0df6cb4b8d7c2668a/src/pullbacks/svd.jl#L143
300350
function ChainRulesCore.rrule(
@@ -320,6 +370,31 @@ function ChainRulesCore.rrule(
320370
return (U, S, V⁺, ϵ), svd_trunc!_trunc_pullback
321371
end
322372

373+
# svd_trunc_no_error! rrule wrapping MatrixAlgebraKit's svd_trunc_pullback! (also works for IterSVD)
374+
# https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/blob/b76c7bb60014ecfead6925d0df6cb4b8d7c2668a/src/pullbacks/svd.jl#L143
375+
function ChainRulesCore.rrule(
        ::typeof(svd_trunc_no_error!),
        t,
        alg::SVDAdjoint{F, R},
    ) where {F, R <: TruncPullback}
    # Reverse rule for `svd_trunc_no_error!` with a `TruncPullback`: only the truncated
    # factors are computed (via the non-mutating `svd_trunc_no_error`) and passed on to
    # `svd_trunc_pullback!`.
    rrule_alg = alg.rrule_alg
    USV⁺ = svd_trunc_no_error(t, alg)
    U, S, V⁺ = USV⁺
    gtol = _get_pullback_gauge_tol(rrule_alg.verbosity)

    function svd_trunc_no_error!_trunc_pullback(ΔUSV′)
        cotangents = unthunk.(ΔUSV′)
        Δf = svd_trunc_pullback!(
            zeros(scalartype(t), space(t)), t, (U, S, V⁺), cotangents;
            gauge_atol = gtol(cotangents),
            degeneracy_atol = rrule_alg.degeneracy_atol,
        )
        return NoTangent(), Δf, NoTangent()
    end
    # all-zero cotangents: nothing to propagate
    function svd_trunc_no_error!_trunc_pullback(
            ::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}
        )
        return NoTangent(), ZeroTangent(), NoTangent()
    end

    return (U, S, V⁺), svd_trunc_no_error!_trunc_pullback
end
397+
323398
# KrylovKit rrule compatible with TensorMaps & function handles
324399
function ChainRulesCore.rrule(
325400
::typeof(svd_trunc!),
@@ -388,3 +463,72 @@ function ChainRulesCore.rrule(
388463

389464
return (U, S, V, ϵ), svd_trunc!_itersvd_pullback
390465
end
466+
467+
# KrylovKit rrule compatible with TensorMaps & function handles
468+
function ChainRulesCore.rrule(
        ::typeof(svd_trunc_no_error!),
        f,
        alg::SVDAdjoint{F, R}
    ) where {F, R <: Union{GMRES, BiCGStab, Arnoldi}}
    # Reverse rule for `svd_trunc_no_error!` when `alg.rrule_alg` is a `GMRES`, `BiCGStab`
    # or `Arnoldi` solver. The cotangent of `f` is assembled block by block from the
    # pullback data computed by `KrylovKitCRCExt`.
    U, S, V = svd_trunc_no_error(f, alg)

    # update rrule_alg tolerance to be compatible with smallest singular value
    smallest_sval = minimum(minimum(diag(b)) for (_, b) in blocks(S))
    tol_lower = eps(scalartype(S))^(3 / 4)
    tol_upper = 1.0e-2 * smallest_sval
    base_rrule_alg = alg.rrule_alg
    rrule_alg = @set base_rrule_alg.tol = clamp(base_rrule_alg.tol, tol_lower, tol_upper)

    function svd_trunc_no_error!_itersvd_pullback(ΔUSVi)
        Δf = similar(f)
        ΔU, ΔS, ΔV, = unthunk.(ΔUSVi)
        T = scalartype(f)

        for (c, Δfc) in blocks(Δf)
            Uc, Sc, Vc = block(U, c), block(S, c), block(V, c)
            ΔUc, ΔSc, ΔVc = block(ΔU, c), block(ΔS, c), block(ΔV, c)

            # singular values of this block and their vectors, column by column
            svals = view(Sc, diagind(Sc))
            n_vals = length(svals)
            lvecs = Vector{Vector{T}}(eachcol(Uc))
            rvecs = Vector{Vector{T}}(eachcol(Vc'))

            # Dummy objects only used for warnings
            # Only num. converged is used
            minimal_info = KrylovKit.ConvergenceInfo(n_vals, nothing, nothing, -1, -1)
            # Tolerance is used for gauge sensitivity, verbosity is used for warnings
            minimal_alg = GKL(; tol = rrule_alg.tol, verbosity = 1)

            if ΔUc isa AbstractZero && ΔVc isa AbstractZero
                # Handle ZeroTangent singular vectors
                Δlvecs = fill(ZeroTangent(), n_vals)
                Δrvecs = fill(ZeroTangent(), n_vals)
            else
                Δlvecs = Vector{Vector{T}}(eachcol(ΔUc))
                Δrvecs = Vector{Vector{T}}(eachcol(ΔVc'))
            end

            # a zero cotangent for the singular values is replaced by explicit zeros
            Δsvals = if ΔSc isa AbstractZero
                fill(zero(Sc[1]), n_vals)
            else
                view(ΔSc, diagind(ΔSc))
            end

            fc = block(f, c)
            xs, ys = KrylovKitCRCExt.compute_svdsolve_pullback_data(
                Δsvals, Δlvecs, Δrvecs,
                svals, lvecs, rvecs,
                minimal_info, fc, :LR, minimal_alg, rrule_alg,
            )
            ∂fc = KrylovKitCRCExt.construct∂f_svd(
                HasReverseMode(), fc, lvecs, rvecs, xs, ys
            )
            copyto!(Δfc, ∂fc)
        end
        return NoTangent(), Δf, NoTangent()
    end
    # all-zero cotangents: nothing to propagate
    function svd_trunc_no_error!_itersvd_pullback(
            ::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}
        )
        return NoTangent(), ZeroTangent(), NoTangent()
    end

    return (U, S, V), svd_trunc_no_error!_itersvd_pullback
end

test/utility/svd_wrapper.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ using ChainRulesCore, Zygote
66
using Accessors
77
using PEPSKit
88

9-
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
9+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, svd_trunc_no_error
1010

1111
# Gauge-invariant loss function
12-
function lossfun(A, alg, R = randn(space(A)), trunc = notrunc())
12+
function lossfun(svd_trunc_f, A, alg, R = randn(space(A)), trunc = notrunc())
    # `svd_trunc_f` is the decomposition routine under test; it is called as
    # `svd_trunc_f(A, alg)` and only its first three outputs are used.
    truncated_alg = @set alg.fwd_alg = TruncatedAlgorithm(alg.fwd_alg, trunc)
    factors = svd_trunc_f(A, truncated_alg)
    U, S, V = factors[1], factors[2], factors[3] # avoid looking at ϵ if present
    # Overlap with random tensor R is gauge-invariant and differentiable, also for m≠n
    overlap = real(dot(R, U * V))
    return overlap + dot(S, S)
end
1718

@@ -28,29 +29,29 @@ full_alg = SVDAdjoint(; rrule_alg = (; alg = :FullPullback, degeneracy_atol = 1.
2829
trunc_alg = SVDAdjoint(; rrule_alg = (; alg = :TruncPullback, degeneracy_atol = 1.0e-13))
2930
iter_alg = SVDAdjoint(; fwd_alg = (; alg = :GKL))
3031

31-
@testset "Non-truncated SVD" begin
32-
l_full, g_full = withgradient(A -> lossfun(A, full_alg, R), r)
33-
l_trunc, g_trunc = withgradient(A -> lossfun(A, trunc_alg, R), r)
34-
l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, R), r)
32+
# Loss values and gradients must agree between the full, truncated and iterative
# pullbacks, for both `svd_trunc` and `svd_trunc_no_error`.
@testset "Non-truncated SVD $f" for f in (svd_trunc, svd_trunc_no_error)
    l_full, g_full = withgradient(A -> lossfun(f, A, full_alg, R), r)
    l_trunc, g_trunc = withgradient(A -> lossfun(f, A, trunc_alg, R), r)
    l_iter, g_iter = withgradient(A -> lossfun(f, A, iter_alg, R), r)

    @test l_full ≈ l_trunc ≈ l_iter
    @test g_full[1] ≈ g_trunc[1] rtol = rtol
    @test g_full[1] ≈ g_iter[1] rtol = rtol
    @test g_trunc[1] ≈ g_iter[1] rtol = rtol
end
4142

42-
@testset "Truncated SVD with χ=" begin
43-
l_full, g_full = withgradient(A -> lossfun(A, full_alg, R, trunc), r)
44-
l_trunc, g_trunc = withgradient(A -> lossfun(A, trunc_alg, R, trunc), r)
45-
l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, R, trunc), r)
43+
# Same comparison as the non-truncated case, but with the truncation scheme `trunc`.
# NOTE(review): the testset title ends in "χ=" with no value; an interpolation may have
# been lost here — confirm against the repository.
@testset "Truncated SVD $f with χ=" for f in (svd_trunc, svd_trunc_no_error)
    l_full, g_full = withgradient(A -> lossfun(f, A, full_alg, R, trunc), r)
    l_trunc, g_trunc = withgradient(A -> lossfun(f, A, trunc_alg, R, trunc), r)
    l_iter, g_iter = withgradient(A -> lossfun(f, A, iter_alg, R, trunc), r)

    @test l_full ≈ l_trunc ≈ l_iter
    @test g_full[1] ≈ g_trunc[1] rtol = rtol
    @test g_full[1] ≈ g_iter[1] rtol = rtol
    @test g_trunc[1] ≈ g_iter[1] rtol = rtol
end
5253

53-
@testset "Truncated SVD broadening for $(alg.rrule_alg)" for alg in [full_alg, trunc_alg]
54+
@testset "Truncated SVD broadening for $f, $(alg.rrule_alg)" for f in (svd_trunc, svd_trunc_no_error), alg in [full_alg, trunc_alg]
5455
u, s, v, = svd_compact(r)
5556
s.data[1:2:m] .= s.data[2:2:m] # make every singular value two-fold degenerate
5657
r_degen = u * s * v
@@ -59,13 +60,13 @@ end
5960
small_broadening_alg = @set full_alg.rrule_alg.degeneracy_atol = 1.0e-13
6061

6162
l_only_cutoff, g_only_cutoff = withgradient(
62-
A -> lossfun(A, full_alg, R, trunc), r_degen
63+
A -> lossfun(f, A, full_alg, R, trunc), r_degen
6364
) # cutoff sets degenerate difference to zero
6465
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient( # degenerate singular value differences lead to divergent contributions
65-
A -> lossfun(A, no_broadening_no_cutoff_alg, R, trunc), r_degen,
66+
A -> lossfun(f, A, no_broadening_no_cutoff_alg, R, trunc), r_degen,
6667
)
6768
l_small_broadening, g_small_broadening = withgradient( # broadening smoothens divergent contributions
68-
A -> lossfun(A, small_broadening_alg, R, trunc), r_degen,
69+
A -> lossfun(f, A, small_broadening_alg, R, trunc), r_degen,
6970
)
7071

7172
@test l_only_cutoff l_no_broadening_no_cutoff l_small_broadening
@@ -79,23 +80,23 @@ symm_trspace = truncspace(Z2Space(0 => symm_m ÷ 2, 1 => symm_n ÷ 3))
7980
symm_r = randn(dtype, symm_space, symm_space)
8081
symm_R = randn(dtype, space(symm_r))
8182

82-
@testset "IterSVD of symmetric tensors" begin
83-
l_full, g_full = withgradient(A -> lossfun(A, full_alg, symm_R), symm_r)
84-
l_trunc, g_trunc = withgradient(A -> lossfun(A, trunc_alg, symm_R), symm_r)
85-
l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, symm_R), symm_r)
83+
@testset "IterSVD of symmetric tensors $f" for f in (svd_trunc, svd_trunc_no_error)
84+
l_full, g_full = withgradient(A -> lossfun(f, A, full_alg, symm_R), symm_r)
85+
l_trunc, g_trunc = withgradient(A -> lossfun(f, A, trunc_alg, symm_R), symm_r)
86+
l_iter, g_iter = withgradient(A -> lossfun(f, A, iter_alg, symm_R), symm_r)
8687
@test l_full l_trunc l_iter
8788
@test g_full[1] g_trunc[1] rtol = rtol
8889
@test g_full[1] g_iter[1] rtol = rtol
8990
@test g_trunc[1] g_iter[1] rtol = rtol
9091

9192
l_full_tr, g_full_tr = withgradient(
92-
A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r
93+
A -> lossfun(f, A, full_alg, symm_R, symm_trspace), symm_r
9394
)
9495
l_trunc_tr, g_trunc_tr = withgradient(
95-
A -> lossfun(A, trunc_alg, symm_R, symm_trspace), symm_r
96+
A -> lossfun(f, A, trunc_alg, symm_R, symm_trspace), symm_r
9697
)
9798
l_iter_tr, g_iter_tr = withgradient(
98-
A -> lossfun(A, iter_alg, symm_R, symm_trspace), symm_r
99+
A -> lossfun(f, A, iter_alg, symm_R, symm_trspace), symm_r
99100
)
100101
@test l_full_tr l_trunc_tr l_iter_tr
101102
@test g_full_tr[1] g_trunc_tr[1] rtol = rtol
@@ -104,14 +105,14 @@ symm_R = randn(dtype, space(symm_r))
104105

105106
iter_alg_fallback = @set iter_alg.fwd_alg.fallback_threshold = 0.4 # Do dense decomposition in one block, sparse one in the other
106107
l_iter_fb, g_iter_fb = withgradient(
107-
A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace), symm_r
108+
A -> lossfun(f, A, iter_alg_fallback, symm_R, symm_trspace), symm_r
108109
)
109110
@test l_iter_fb l_trunc_tr l_full_tr
110111
@test g_full_tr[1] g_iter_fb[1] rtol = rtol
111112
@test g_trunc_tr[1] g_iter_fb[1] rtol = rtol
112113
end
113114

114-
@testset "Truncated symmetric SVD broadening for $(alg.rrule_alg)" for alg in [full_alg, trunc_alg]
115+
@testset "Truncated symmetric SVD broadening for $f, $(alg.rrule_alg)" for f in (svd_trunc, svd_trunc_no_error), alg in [full_alg, trunc_alg]
115116
u, s, v, = svd_compact(symm_r)
116117
# make every singular value in the 0-sector three-fold degenerate
117118
b0 = diagview(block(s, Z2Irrep(0)))
@@ -126,14 +127,14 @@ end
126127
small_broadening_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-13
127128

128129
l_only_cutoff, g_only_cutoff = withgradient(
129-
A -> lossfun(A, alg, symm_R, symm_trspace), symm_r_degen
130+
A -> lossfun(f, A, alg, symm_R, symm_trspace), symm_r_degen
130131
) # cutoff sets degenerate difference to zero
131132
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient( # degenerate singular value differences lead to divergent contributions
132-
A -> lossfun(A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
133+
A -> lossfun(f, A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
133134
symm_r_degen,
134135
)
135136
l_small_broadening, g_small_broadening = withgradient( # broadening smoothens divergent contributions
136-
A -> lossfun(A, small_broadening_alg, symm_R, symm_trspace),
137+
A -> lossfun(f, A, small_broadening_alg, symm_R, symm_trspace),
137138
symm_r_degen,
138139
)
139140

0 commit comments

Comments
 (0)