@@ -6,12 +6,13 @@ using ChainRulesCore, Zygote
66using Accessors
77using PEPSKit
88
9- using MatrixAlgebraKit: TruncatedAlgorithm, diagview
9+ using MatrixAlgebraKit: TruncatedAlgorithm, diagview, svd_trunc_no_error
1010
1111# Gauge-invariant loss function
12- function lossfun (A, alg, R = randn (space (A)), trunc = notrunc ())
12+ function lossfun (svd_trunc_f, A, alg, R = randn (space (A)), trunc = notrunc ())
1313 alg = @set alg. fwd_alg = TruncatedAlgorithm (alg. fwd_alg, trunc)
14- U, S, V, = svd_trunc (A, alg)
14+ USV = svd_trunc_f (A, alg)
15+ U, S, V = USV[1 : 3 ] # avoid looking at ϵ if present
1516 return real (dot (R, U * V)) + dot (S, S) # Overlap with random tensor R is gauge-invariant and differentiable, also for m≠n
1617end
1718
@@ -28,29 +29,29 @@ full_alg = SVDAdjoint(; rrule_alg = (; alg = :FullPullback, degeneracy_atol = 1.
2829trunc_alg = SVDAdjoint (; rrule_alg = (; alg = :TruncPullback , degeneracy_atol = 1.0e-13 ))
2930iter_alg = SVDAdjoint (; fwd_alg = (; alg = :GKL ))
3031
31- @testset " Non-truncated SVD" begin
32- l_full, g_full = withgradient (A -> lossfun (A, full_alg, R), r)
33- l_trunc, g_trunc = withgradient (A -> lossfun (A, trunc_alg, R), r)
34- l_iter, g_iter = withgradient (A -> lossfun (A, iter_alg, R), r)
32+ @testset " Non-truncated SVD $f " for f in (svd_trunc, svd_trunc_no_error)
33+ l_full, g_full = withgradient (A -> lossfun (f, A, full_alg, R), r)
34+ l_trunc, g_trunc = withgradient (A -> lossfun (f, A, trunc_alg, R), r)
35+ l_iter, g_iter = withgradient (A -> lossfun (f, A, iter_alg, R), r)
3536
3637 @test l_full ≈ l_trunc ≈ l_iter
3738 @test g_full[1 ] ≈ g_trunc[1 ] rtol = rtol
3839 @test g_full[1 ] ≈ g_iter[1 ] rtol = rtol
3940 @test g_trunc[1 ] ≈ g_iter[1 ] rtol = rtol
4041end
4142
42- @testset " Truncated SVD with χ=$χ " begin
43- l_full, g_full = withgradient (A -> lossfun (A, full_alg, R, trunc), r)
44- l_trunc, g_trunc = withgradient (A -> lossfun (A, trunc_alg, R, trunc), r)
45- l_iter, g_iter = withgradient (A -> lossfun (A, iter_alg, R, trunc), r)
43+ @testset " Truncated SVD $f with χ=$χ " for f in (svd_trunc, svd_trunc_no_error)
44+ l_full, g_full = withgradient (A -> lossfun (f, A, full_alg, R, trunc), r)
45+ l_trunc, g_trunc = withgradient (A -> lossfun (f, A, trunc_alg, R, trunc), r)
46+ l_iter, g_iter = withgradient (A -> lossfun (f, A, iter_alg, R, trunc), r)
4647
4748 @test l_full ≈ l_trunc ≈ l_iter
4849 @test g_full[1 ] ≈ g_trunc[1 ] rtol = rtol
4950 @test g_full[1 ] ≈ g_iter[1 ] rtol = rtol
5051 @test g_trunc[1 ] ≈ g_iter[1 ] rtol = rtol
5152end
5253
53- @testset " Truncated SVD broadening for $(alg. rrule_alg) " for alg in [full_alg, trunc_alg]
54+ @testset " Truncated SVD broadening for $f , $ (alg. rrule_alg)" for f in (svd_trunc, svd_trunc_no_error), alg in [full_alg, trunc_alg]
5455 u, s, v, = svd_compact (r)
5556 s. data[1 : 2 : m] .= s. data[2 : 2 : m] # make every singular value two-fold degenerate
5657 r_degen = u * s * v
5960 small_broadening_alg = @set full_alg. rrule_alg. degeneracy_atol = 1.0e-13
6061
6162 l_only_cutoff, g_only_cutoff = withgradient (
62- A -> lossfun (A, full_alg, R, trunc), r_degen
63+ A -> lossfun (f, A, full_alg, R, trunc), r_degen
6364 ) # cutoff sets degenerate difference to zero
6465 l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient ( # degenerate singular value differences lead to divergent contributions
65- A -> lossfun (A, no_broadening_no_cutoff_alg, R, trunc), r_degen,
66+ A -> lossfun (f, A, no_broadening_no_cutoff_alg, R, trunc), r_degen,
6667 )
6768 l_small_broadening, g_small_broadening = withgradient ( # broadening smoothens divergent contributions
68- A -> lossfun (A, small_broadening_alg, R, trunc), r_degen,
69+ A -> lossfun (f, A, small_broadening_alg, R, trunc), r_degen,
6970 )
7071
7172 @test l_only_cutoff ≈ l_no_broadening_no_cutoff ≈ l_small_broadening
@@ -79,23 +80,23 @@ symm_trspace = truncspace(Z2Space(0 => symm_m ÷ 2, 1 => symm_n ÷ 3))
7980symm_r = randn (dtype, symm_space, symm_space)
8081symm_R = randn (dtype, space (symm_r))
8182
82- @testset " IterSVD of symmetric tensors" begin
83- l_full, g_full = withgradient (A -> lossfun (A, full_alg, symm_R), symm_r)
84- l_trunc, g_trunc = withgradient (A -> lossfun (A, trunc_alg, symm_R), symm_r)
85- l_iter, g_iter = withgradient (A -> lossfun (A, iter_alg, symm_R), symm_r)
83+ @testset " IterSVD of symmetric tensors $f " for f in (svd_trunc, svd_trunc_no_error)
84+ l_full, g_full = withgradient (A -> lossfun (f, A, full_alg, symm_R), symm_r)
85+ l_trunc, g_trunc = withgradient (A -> lossfun (f, A, trunc_alg, symm_R), symm_r)
86+ l_iter, g_iter = withgradient (A -> lossfun (f, A, iter_alg, symm_R), symm_r)
8687 @test l_full ≈ l_trunc ≈ l_iter
8788 @test g_full[1 ] ≈ g_trunc[1 ] rtol = rtol
8889 @test g_full[1 ] ≈ g_iter[1 ] rtol = rtol
8990 @test g_trunc[1 ] ≈ g_iter[1 ] rtol = rtol
9091
9192 l_full_tr, g_full_tr = withgradient (
92- A -> lossfun (A, full_alg, symm_R, symm_trspace), symm_r
93+ A -> lossfun (f, A, full_alg, symm_R, symm_trspace), symm_r
9394 )
9495 l_trunc_tr, g_trunc_tr = withgradient (
95- A -> lossfun (A, trunc_alg, symm_R, symm_trspace), symm_r
96+ A -> lossfun (f, A, trunc_alg, symm_R, symm_trspace), symm_r
9697 )
9798 l_iter_tr, g_iter_tr = withgradient (
98- A -> lossfun (A, iter_alg, symm_R, symm_trspace), symm_r
99+ A -> lossfun (f, A, iter_alg, symm_R, symm_trspace), symm_r
99100 )
100101 @test l_full_tr ≈ l_trunc_tr ≈ l_iter_tr
101102 @test g_full_tr[1 ] ≈ g_trunc_tr[1 ] rtol = rtol
@@ -104,14 +105,14 @@ symm_R = randn(dtype, space(symm_r))
104105
105106 iter_alg_fallback = @set iter_alg. fwd_alg. fallback_threshold = 0.4 # Do dense decomposition in one block, sparse one in the other
106107 l_iter_fb, g_iter_fb = withgradient (
107- A -> lossfun (A, iter_alg_fallback, symm_R, symm_trspace), symm_r
108+ A -> lossfun (f, A, iter_alg_fallback, symm_R, symm_trspace), symm_r
108109 )
109110 @test l_iter_fb ≈ l_trunc_tr ≈ l_full_tr
110111 @test g_full_tr[1 ] ≈ g_iter_fb[1 ] rtol = rtol
111112 @test g_trunc_tr[1 ] ≈ g_iter_fb[1 ] rtol = rtol
112113end
113114
114- @testset " Truncated symmetric SVD broadening for $(alg. rrule_alg) " for alg in [full_alg, trunc_alg]
115+ @testset " Truncated symmetric SVD broadening for $f , $ (alg. rrule_alg)" for f in (svd_trunc, svd_trunc_no_error), alg in [full_alg, trunc_alg]
115116 u, s, v, = svd_compact (symm_r)
116117 # make every singular value in the 0-sector three-fold degenerate
117118 b0 = diagview (block (s, Z2Irrep (0 )))
@@ -126,14 +127,14 @@ end
126127 small_broadening_alg = @set alg. rrule_alg. degeneracy_atol = 1.0e-13
127128
128129 l_only_cutoff, g_only_cutoff = withgradient (
129- A -> lossfun (A, alg, symm_R, symm_trspace), symm_r_degen
130+ A -> lossfun (f, A, alg, symm_R, symm_trspace), symm_r_degen
130131 ) # cutoff sets degenerate difference to zero
131132 l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient ( # degenerate singular value differences lead to divergent contributions
132- A -> lossfun (A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
133+ A -> lossfun (f, A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
133134 symm_r_degen,
134135 )
135136 l_small_broadening, g_small_broadening = withgradient ( # broadening smoothens divergent contributions
136- A -> lossfun (A, small_broadening_alg, symm_R, symm_trspace),
137+ A -> lossfun (f, A, small_broadening_alg, symm_R, symm_trspace),
137138 symm_r_degen,
138139 )
139140
0 commit comments