Skip to content

Commit 54f6021

Browse files
Jutholkdvos
andauthored
pullback reorganization (#208)
* pullback reorganization * add quadratic svd_trunc_pullback * Apply suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Apply more suggestions from code review Co-authored-by: Lukas Devos <ldevos98@gmail.com> * some changes from review * fixes and improved numerical stability * one more code suggestion Co-authored-by: Lukas Devos <ldevos98@gmail.com> * mark gauge dependence removal as public * Update src/pullbacks/svd.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * improve error messages * remove unused function * formatting * more unicode --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 03d21cf commit 54f6021

10 files changed

Lines changed: 478 additions & 398 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -240,19 +240,7 @@ for f in (:svd_compact!, :svd_full!)
240240
USVᴴval = something(cache_USVᴴ, USVᴴ.val)
241241
if !isa(A, Const)
242242
minmn = min(size(A.val)...)
243-
if $(f == svd_compact!) # compact
244-
svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ)
245-
else # full
246-
# TODO: revisit this once `svd_pullback` supports `svd_full` output and adjoints
247-
U, S, Vᴴ = USVᴴval
248-
vU = view(U, :, 1:minmn)
249-
vS = Diagonal(view(diagview(S), 1:minmn))
250-
vVᴴ = view(Vᴴ, 1:minmn, :)
251-
vdU = view(dUSVᴴ[1], :, 1:minmn)
252-
vdS = Diagonal(view(diagview(dUSVᴴ[2]), 1:minmn))
253-
vdVᴴ = view(dUSVᴴ[3], 1:minmn, :)
254-
svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
255-
end
243+
svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ)
256244
end
257245
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
258246
return (nothing, nothing, nothing)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -418,28 +418,17 @@ for (f!, f) in (
418418
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
419419
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
420420
A, dA = arrayify(A_dA)
421-
Ac = copy(A)
422421
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
423422
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
424423
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
425424
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
426425
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
426+
Ac = copy(A)
427427
USVᴴc = copy.(USVᴴ)
428428
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
429429
function svd_adjoint(::NoRData)
430430
copy!(A, Ac)
431-
if $(f! == svd_compact!)
432-
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
433-
else # full
434-
minmn = min(size(A)...)
435-
vU = view(U, :, 1:minmn)
436-
vS = Diagonal(diagview(S)[1:minmn])
437-
vVᴴ = view(Vᴴ, 1:minmn, :)
438-
vdU = view(dU, :, 1:minmn)
439-
vdS = Diagonal(diagview(dS)[1:minmn])
440-
vdVᴴ = view(dVᴴ, 1:minmn, :)
441-
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
442-
end
431+
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
443432
copy!(U, USVᴴc[1])
444433
copy!(S, USVᴴc[2])
445434
copy!(Vᴴ, USVᴴc[3])
@@ -448,7 +437,7 @@ for (f!, f) in (
448437
zero!(dVᴴ)
449438
return NoRData(), NoRData(), NoRData(), NoRData()
450439
end
451-
return CoDual(output, dUSVᴴ), svd_adjoint
440+
return USVᴴ_dUSVᴴ, svd_adjoint
452441
end
453442
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
454443
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
@@ -465,18 +454,7 @@ for (f!, f) in (
465454
U, dU = arrayify(U, dU_)
466455
S, dS = arrayify(S, dS_)
467456
Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_)
468-
if $(f == svd_compact)
469-
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
470-
else # full
471-
minmn = min(size(A)...)
472-
vU = view(U, :, 1:minmn)
473-
vS = Diagonal(view(diagview(S), 1:minmn))
474-
vVᴴ = view(Vᴴ, 1:minmn, :)
475-
vdU = view(dU, :, 1:minmn)
476-
vdS = Diagonal(view(diagview(dS), 1:minmn))
477-
vdVᴴ = view(dVᴴ, 1:minmn, :)
478-
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
479-
end
457+
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
480458
zero!(dU)
481459
zero!(dS)
482460
zero!(dVᴴ)

src/MatrixAlgebraKit.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
7272
:svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback!
7373
)
7474
)
75+
eval(
76+
Expr(
77+
:public, :remove_svd_gauge_dependence!,
78+
:remove_eig_gauge_dependence!, :remove_eigh_gauge_dependence!,
79+
:remove_qr_gauge_dependence!, :remove_qr_null_gauge_dependence!,
80+
:remove_lq_gauge_dependence!, :remove_lq_null_gauge_dependence!,
81+
:remove_left_null_gauge_dependence!, :remove_right_null_gauge_dependence!,
82+
)
83+
)
7584
eval(Expr(:public, :is_left_isometric, :is_right_isometric))
7685
end
7786

src/pullbacks/eig.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,25 @@ function eig_vals_pullback!(
201201
ΔDV = (diagonal(ΔD), nothing)
202202
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
203203
end
204+
205+
"""
206+
remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...)
207+
208+
Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The
209+
eigenvectors are only determined up to a scalar factor (or an abitrary linear transformation
210+
across eigenvectors associated with degenerate eigenvalues), so the corresponding components of
211+
`ΔV` are projected out.
212+
"""
213+
function remove_eig_gauge_dependence!(
214+
ΔV, D, V, ind = axes(ΔV, 2);
215+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
216+
)
217+
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
218+
indV = axes(V, 2)[ind]
219+
Vp = view(V, :, indV)
220+
Ddiag = view(diagview(D), indV)
221+
gaugepart = Vp' * ΔV
222+
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
223+
mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1)
224+
return ΔV
225+
end

src/pullbacks/eigh.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,25 @@ function eigh_vals_pullback!(
191191
ΔDV = (diagonal(ΔD), nothing)
192192
return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
193193
end
194+
195+
"""
196+
remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...)
197+
198+
Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix
199+
`V`. The eigenvectors are only determined up to a complex phase (or a unitary transformation
200+
across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian
201+
components of `V' * ΔV` are projected out.
202+
"""
203+
function remove_eigh_gauge_dependence!(
204+
ΔV, D, V, ind = axes(ΔV, 2);
205+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
206+
)
207+
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
208+
indV = axes(V, 2)[ind]
209+
Vp = view(V, :, indV)
210+
Ddiag = view(diagview(D), indV)
211+
gaugepart = project_antihermitian!(Vp' * ΔV)
212+
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
213+
mul!(ΔV, Vp, gaugepart, -1, 1)
214+
return ΔV
215+
end

src/pullbacks/lq.jl

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,45 @@
11
lq_rank(L; kwargs...) = qr_rank(L; kwargs...)
22

3-
function check_lq_cotangents(
3+
function check_and_prepare_lq_cotangents(
44
L, Q, ΔL, ΔQ, p::Int;
55
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
66
)
7-
minmn = min(size(L, 1), size(Q, 2))
7+
m, n = size(L, 1), size(Q, 2)
8+
minmn = min(m, n)
89
Δgauge = abs(zero(eltype(Q)))
10+
Q₁ = view(Q, 1:p, :)
11+
ΔQ₁ = zero!(similar(Q₁))
912
if !iszerotangent(ΔQ)
10-
ΔQ₂ = view(ΔQ, (p + 1):minmn, :)
11-
ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :]
12-
Δgauge_Q = norm(ΔQ₂, Inf)
13-
Q₁ = view(Q, 1:p, :)
14-
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
15-
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
16-
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
13+
size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q"))
14+
ΔQ₁ .= view(ΔQ, 1:p, 1:n)
15+
if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁
16+
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
17+
ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :)
18+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
19+
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
20+
Δgauge_Q = norm(ΔQ₃, Inf)
21+
mul!(ΔQ₁, ΔQ₃Q₁ᴴ', Q₃, -1, 1)
22+
else
23+
ΔQ₂ = view(ΔQ, (p + 1):size(ΔQ, 1), :)
24+
Δgauge_Q = norm(ΔQ₂, Inf)
25+
end
1726
Δgauge = max(Δgauge, Δgauge_Q)
1827
end
1928
if !iszerotangent(ΔL)
20-
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
21-
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
22-
Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf))
29+
size(ΔL) == size(L) || throw(DimensionMismatch("ΔL must have the same size as L"))
30+
ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p))
31+
ΔL₂₁ = view(ΔL, (p + 1):size(ΔL, 1), 1:p)
32+
ΔL₂₂ = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
33+
Δgauge_L = norm(view(ΔL₂₂, lowertriangularind(ΔL₂₂)), Inf)
34+
Δgauge_L = max(Δgauge_L, norm(view(ΔL₂₂, diagind(ΔL₂₂)), Inf))
2335
Δgauge = max(Δgauge, Δgauge_L)
36+
else
37+
ΔL₁₁ = nothing
38+
ΔL₂₁ = nothing
2439
end
2540
Δgauge gauge_atol ||
2641
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
27-
return nothing
42+
return ΔL₁₁, ΔL₂₁, ΔQ₁
2843
end
2944

3045
"""
@@ -53,53 +68,37 @@ function lq_pullback!(
5368
L, Q = LQ
5469
m = size(L, 1)
5570
n = size(Q, 2)
56-
minmn = min(m, n)
5771
p = lq_rank(L; rank_atol)
72+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of L*Q ($m, $n)"))
5873

59-
ΔL, ΔQ = ΔLQ
60-
61-
Q₁ = view(Q, 1:p, :)
6274
L₁₁ = LowerTriangular(view(L, 1:p, 1:p))
75+
L₂₁ = view(L, (p + 1):m, 1:p)
76+
Q₁ = view(Q, 1:p, :)
77+
6378
ΔA₁ = view(ΔA, 1:p, :)
6479
ΔA₂ = view(ΔA, (p + 1):m, :)
6580

66-
check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
81+
ΔL, ΔQ = ΔLQ
82+
ΔL₁₁, ΔL₂₁, ΔQ₁ = check_and_prepare_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
6783

68-
ΔQ̃ = zero!(similar(Q, (p, n)))
69-
if !iszerotangent(ΔQ)
70-
ΔQ₁ = view(ΔQ, 1:p, :)
71-
copy!(ΔQ̃, ΔQ₁)
72-
if minmn < size(Q, 1)
73-
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :)
74-
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
75-
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
76-
ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1)
77-
end
78-
end
7984
if !iszerotangent(ΔL) && m > p
80-
L₂₁ = view(L, (p + 1):m, 1:p)
81-
ΔL₂₁ = view(ΔL, (p + 1):m, 1:p)
82-
ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1)
85+
ΔQ₁ = mul!(ΔQ₁, L₂₁' * ΔL₂₁, Q₁, -1, 1)
8386
# Adding ΔA₂ contribution
8487
ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1)
8588
end
8689

8790
# construct M
8891
M = zero!(similar(L, (p, p)))
8992
if !iszerotangent(ΔL)
90-
ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p))
9193
M = mul!(M, L₁₁', ΔL₁₁, 1, 1)
9294
end
93-
M = mul!(M, ΔQ̃, Q₁', -1, 1)
95+
M = mul!(M, ΔQ₁, Q₁', -1, 1)
9496
view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
9597
if eltype(M) <: Complex
9698
Md = diagview(M)
9799
Md .= real.(Md)
98100
end
99-
ldiv!(L₁₁', M)
100-
ldiv!(L₁₁', ΔQ̃)
101-
ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1)
102-
ΔA₁ .+= ΔQ̃
101+
ΔA₁ .+= ldiv!(L₁₁', mul!(ΔQ₁, M, Q₁, +1, 1))
103102
return ΔA
104103
end
105104

@@ -134,3 +133,51 @@ function lq_null_pullback!(
134133
end
135134
return ΔA
136135
end
136+
137+
138+
"""
139+
remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...)
140+
141+
Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and
142+
`Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely
143+
determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity.
144+
Additionally, columns of `ΔL` beyond the rank are zeroed out.
145+
"""
146+
function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L))
147+
r = MatrixAlgebraKit.lq_rank(L; rank_atol)
148+
minmn = min(size(A)...)
149+
Q₁ = view(Q, 1:r, :)
150+
ΔQ₂ = view(ΔQ, (r + 1):minmn, :)
151+
zero!(ΔQ₂)
152+
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full
153+
if r == minmn
154+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
155+
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁)
156+
else # rank-deficient case, no gauge-invariant information
157+
zero!(ΔQ₃)
158+
end
159+
ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn)
160+
zero!(diagview(ΔL₂₂))
161+
zero!(view(ΔL₂₂, lowertriangularind(ΔL₂₂)))
162+
return ΔL, ΔQ
163+
end
164+
165+
"""
166+
remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
167+
168+
Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null
169+
space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of
170+
the compact LQ factor `Q₁`.
171+
"""
172+
function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
173+
return mul!(ΔNᴴ, ΔNᴴ * Nᴴ', Nᴴ, -1, 1)
174+
end
175+
176+
"""
177+
remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
178+
179+
Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The
180+
null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the
181+
row span of the compact LQ factor `Q₁` of `A`.
182+
"""
183+
remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)

0 commit comments

Comments
 (0)