Skip to content

Commit 7b267ff

Browse files
committed
add remove_f_gauge_dependence! implementations
slight formatting
1 parent 518af61 commit 7b267ff

4 files changed

Lines changed: 68 additions & 91 deletions

File tree

src/factorizations/pullbacks.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ for pullback! in (:qr_null_pullback!, :lq_null_pullback!)
2424
return Δt
2525
end
2626
end
27-
2827
_notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t))
2928

3029
for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
@@ -51,8 +50,57 @@ for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_
5150
foreachblock(Δt, t) do c, (Δb, b)
5251
Fc = block.(F, Ref(c))
5352
ΔFc = block.(ΔF, Ref(c))
54-
return MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
53+
MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
54+
return nothing
5555
end
5656
return Δt
5757
end
5858
end
59+
60+
for f in (:qr, :lq)
61+
remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!)
62+
remove_f_null_gauge_dependence! = Symbol(:remove_, f, :_null_gauge_dependence!)
63+
@eval function MAK.$remove_f_gauge_dependence!(
64+
ΔF₁::AbstractTensorMap, ΔF₂::AbstractTensorMap, A, F₁, F₂;
65+
kwargs...
66+
)
67+
foreachblock(ΔF₁, ΔF₂, A, F₁, F₂) do _, (Δf₁, Δf₂, a, f₁, f₂)
68+
MAK.$remove_f_gauge_dependence!(Δf₁, Δf₂, a, f₁, f₂)
69+
return nothing
70+
end
71+
return ΔF₁, ΔF₂
72+
end
73+
# Already captured by MAK implementation
74+
# @eval function MAK.$remove_f_null_gauge_dependence!(ΔN::AbstractTensorMap, A, N; kwargs...)
75+
# foreachblock(ΔN, A, N) do _, (Δn, a, n)
76+
# $remove_f_gauge_dependence!(Δn, a, n)
77+
# end
78+
# return ΔN
79+
# end
80+
end
81+
82+
for f in (:eig, :eigh)
83+
remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!)
84+
@eval function MAK.$remove_f_gauge_dependence!(
85+
ΔV::AbstractTensorMap, D, V, inds = _notrunc_ind(ΔV);
86+
kwargs...
87+
)
88+
foreachblock(ΔV, D, V) do c, (Δv, d, v)
89+
haskey(inds, c) || return nothing
90+
ind = inds[c]
91+
MAK.$remove_f_gauge_dependence!(Δv, d, v, ind; kwargs...)
92+
return nothing
93+
end
94+
return ΔV
95+
end
96+
end
97+
function MAK.remove_svd_gauge_dependence!(
98+
ΔU::AbstractTensorMap, ΔVᴴ::AbstractTensorMap, U, S, Vᴴ;
99+
kwargs...
100+
)
101+
foreachblock(ΔU, ΔVᴴ, U, S, Vᴴ) do c, (Δu, Δvᴴ, u, s, vʰ)
102+
MAK.remove_svd_gauge_dependence!(Δu, Δvᴴ, u, s, vᴴ)
103+
return nothing
104+
end
105+
return ΔU, ΔVᴴ
106+
end

test/chainrules/factorizations.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using LinearAlgebra
99
using Zygote
1010
using MatrixAlgebraKit
1111
using MatrixAlgebraKit: diagview
12-
12+
using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
13+
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!
1314

1415
# Tests
1516
# -----
@@ -52,7 +53,7 @@ for V in spacelist
5253
@test_logs (:warn, r"^`qr") match_mode = :any full_pb((ΔQ, ΔR))
5354
end
5455

55-
remove_qrgauge_dependence!(ΔQ, t, Q)
56+
remove_qr_gauge_dependence!(ΔQ, ΔR, t, Q, R)
5657

5758
test_ad_rrule(qr_full, t; fkwargs, atol, rtol, output_tangent = (ΔQ, ΔR))
5859
test_ad_rrule(
@@ -90,7 +91,7 @@ for V in spacelist
9091
# @test_logs (:warn, r"^`lq") match_mode = :any full_pb((ΔL, ΔQ))
9192
end
9293

93-
remove_lqgauge_dependence!(ΔQ, t, Q)
94+
remove_lq_gauge_dependence!(ΔL, ΔQ, t, L, Q)
9495

9596
test_ad_rrule(lq_full, t; fkwargs, atol, rtol, output_tangent = (ΔL, ΔQ))
9697
test_ad_rrule(
@@ -114,7 +115,7 @@ for V in spacelist
114115
Δv = rand_tangent(v)
115116
Δd = rand_tangent(d)
116117
Δd2 = randn!(similar(d, space(d)))
117-
remove_eiggauge_dependence!(Δv, d, v)
118+
remove_eig_gauge_dependence!(Δv, d, v)
118119

119120
test_ad_rrule(eig_full, t; output_tangent = (Δd, Δv), atol, rtol)
120121
test_ad_rrule(first eig_full, t; output_tangent = Δd, atol, rtol)
@@ -126,7 +127,7 @@ for V in spacelist
126127
Δv = rand_tangent(v)
127128
Δd = rand_tangent(d)
128129
Δd2 = randn!(similar(d, space(d)))
129-
remove_eighgauge_dependence!(Δv, d, v)
130+
remove_eigh_gauge_dependence!(Δv, d, v)
130131

131132
# necessary for FiniteDifferences to not complain
132133
eigh_full′ = eigh_full project_hermitian
@@ -155,7 +156,7 @@ for V in spacelist
155156
USVᴴ = svd_compact(t)
156157
ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ)
157158
ΔS2 = randn!(similar(ΔS, space(ΔS)))
158-
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)
159+
ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)
159160

160161
# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol)
161162
# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol)
@@ -170,7 +171,7 @@ for V in spacelist
170171
trunc = truncspace(V_trunc)
171172
USVᴴ_trunc = svd_trunc(t; trunc)
172173
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
173-
remove_svdgauge_dependence!(
174+
remove_svd_gauge_dependence!(
174175
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
175176
)
176177
test_ad_rrule(

test/mooncake/factorizations.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using TensorKit
33
using TensorOperations
44
using VectorInterface: Zero, One
55
using MatrixAlgebraKit
6+
using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
7+
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!
68
using Mooncake
79
using Random
810

@@ -25,7 +27,7 @@ eltypes = (Float64, ComplexF64)
2527
# qr_full/qr_null requires being careful with gauges
2628
QR = qr_full(A)
2729
ΔQR = Mooncake.randn_tangent(rng, QR)
28-
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
30+
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
2931
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
3032
# TODO:
3133
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
@@ -37,7 +39,7 @@ eltypes = (Float64, ComplexF64)
3739
# qr_full/qr_null requires being careful with gauges
3840
QR = qr_full(A)
3941
ΔQR = Mooncake.randn_tangent(rng, QR)
40-
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
42+
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
4143
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
4244
# TODO:
4345
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
@@ -51,7 +53,7 @@ eltypes = (Float64, ComplexF64)
5153
# qr_full/qr_null requires being careful with gauges
5254
LQ = lq_full(A)
5355
ΔLQ = Mooncake.randn_tangent(rng, LQ)
54-
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
56+
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
5557
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
5658
# TODO:
5759
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
@@ -63,7 +65,7 @@ eltypes = (Float64, ComplexF64)
6365
# qr_full/qr_null requires being careful with gauges
6466
LQ = lq_full(A)
6567
ΔLQ = Mooncake.randn_tangent(rng, LQ)
66-
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
68+
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
6769
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
6870
# TODO:
6971
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
@@ -73,13 +75,13 @@ eltypes = (Float64, ComplexF64)
7375
for t in (randn(T, V[1] V[1]), rand(T, V[1] V[2] V[1] V[2]))
7476
DV = eig_full(t)
7577
ΔDV = Mooncake.randn_tangent(rng, DV)
76-
remove_eiggauge_dependence!(ΔDV[2], DV...)
78+
remove_eig_gauge_dependence!(ΔDV[2], DV...)
7779
Mooncake.TestUtils.test_rule(rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)
7880

7981
th = project_hermitian(t)
8082
DV = eigh_full(th)
8183
ΔDV = Mooncake.randn_tangent(rng, DV)
82-
remove_eighgauge_dependence!(ΔDV[2], DV...)
84+
remove_eigh_gauge_dependence!(ΔDV[2], DV...)
8385
Mooncake.TestUtils.test_rule(rng, eigh_full project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)
8486
end
8587
end
@@ -88,7 +90,7 @@ eltypes = (Float64, ComplexF64)
8890
for t in (randn(T, V[1] V[1]), randn(T, V[1] V[2] (V[3] V[4] V[5])'))
8991
USVᴴ = svd_compact(t)
9092
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
91-
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
93+
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
9294
Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)
9395

9496
# USVᴴ = svd_full(t)
@@ -101,7 +103,7 @@ eltypes = (Float64, ComplexF64)
101103
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
102104
USVᴴtrunc = svd_trunc(t, alg)
103105
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
104-
remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
106+
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
105107
Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)
106108
end
107109
end

test/setup.jl

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ export random_fusion
77
export sectorlist, fast_sectorlist
88
# export dim_isapprox
99
export default_spacelist, factorization_spacelist, ad_spacelist
10-
export remove_qrgauge_dependence!, remove_lqgauge_dependence!
11-
export remove_eiggauge_dependence!, remove_eighgauge_dependence!, remove_svdgauge_dependence!
1210
export test_ad_rrule
1311
export _isunitary, _isone
1412

@@ -398,78 +396,6 @@ function ad_spacelist(fast_tests::Bool)
398396
return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4)
399397
end
400398

401-
# Gauge-fixing tangents for AD factorization tests
402-
# -------------------------------------------------
403-
function remove_qrgauge_dependence!(ΔQ, t, Q)
404-
for (c, b) in blocks(ΔQ)
405-
m, n = size(block(t, c))
406-
minmn = min(m, n)
407-
Qc = block(Q, c)
408-
Q1 = view(Qc, 1:m, 1:minmn)
409-
ΔQ2 = view(b, :, (minmn + 1):m)
410-
mul!(ΔQ2, Q1, Q1' * ΔQ2)
411-
end
412-
return ΔQ
413-
end
414-
function remove_lqgauge_dependence!(ΔQ, t, Q)
415-
for (c, b) in blocks(ΔQ)
416-
m, n = size(block(t, c))
417-
minmn = min(m, n)
418-
Qc = block(Q, c)
419-
Q1 = view(Qc, 1:minmn, 1:n)
420-
ΔQ2 = view(b, (minmn + 1):n, :)
421-
mul!(ΔQ2, ΔQ2 * Q1', Q1)
422-
end
423-
return ΔQ
424-
end
425-
function remove_eiggauge_dependence!(
426-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
427-
)
428-
gaugepart = V' * ΔV
429-
for (c, b) in blocks(gaugepart)
430-
Dc = diagview(block(D, c))
431-
# for some reason this fails only on tests, and I cannot reproduce it in an
432-
# interactive session.
433-
# b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
434-
for j in axes(b, 2), i in axes(b, 1)
435-
abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0)
436-
end
437-
end
438-
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
439-
return ΔV
440-
end
441-
function remove_eighgauge_dependence!(
442-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
443-
)
444-
gaugepart = project_antihermitian!(V' * ΔV)
445-
for (c, b) in blocks(gaugepart)
446-
Dc = diagview(block(D, c))
447-
# for some reason this fails only on tests, and I cannot reproduce it in an
448-
# interactive session.
449-
# b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
450-
for j in axes(b, 2), i in axes(b, 1)
451-
abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0)
452-
end
453-
end
454-
mul!(ΔV, V, gaugepart, -1, 1)
455-
return ΔV
456-
end
457-
function remove_svdgauge_dependence!(
458-
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S)
459-
)
460-
gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ')
461-
for (c, b) in blocks(gaugepart)
462-
Sd = diagview(block(S, c))
463-
# for some reason this fails only on tests, and I cannot reproduce it in an
464-
# interactive session.
465-
# b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0
466-
for j in axes(b, 2), i in axes(b, 1)
467-
abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0)
468-
end
469-
end
470-
mul!(ΔU, U, gaugepart, -1, 1)
471-
return ΔU, ΔVᴴ
472-
end
473399

474400
# ChainRules test utilities
475401
# -------------------------

0 commit comments

Comments
 (0)