Skip to content

Commit 81b12e7

Browse files
committed
Mooncake forward rules
1 parent eceef30 commit 81b12e7

13 files changed

Lines changed: 592 additions & 155 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 232 additions & 76 deletions
Large diffs are not rendered by default.

src/MatrixAlgebraKit.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,11 @@ include("pullbacks/eigh.jl")
113113
include("pullbacks/svd.jl")
114114
include("pullbacks/polar.jl")
115115

116+
include("pushforwards/qr.jl")
117+
include("pushforwards/lq.jl")
118+
include("pushforwards/eig.jl")
119+
include("pushforwards/eigh.jl")
120+
include("pushforwards/polar.jl")
121+
include("pushforwards/svd.jl")
122+
116123
end

src/common/view.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# diagind: provided by LinearAlgebra.jl
2-
diagview(D::Diagonal) = D.diag
2+
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
44

55
# triangularind

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg)
22+
#check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix

src/pullbacks/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ function eig_pullback!(
4646
Δgauge < gauge_atol ||
4747
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
4848

49-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
49+
VᴴΔV ./= conj.(transpose(D) .- D)
50+
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))
5051

5152
if !iszerotangent(ΔDmat)
5253
ΔDvec = diagview(ΔDmat)

src/pushforwards/eig.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
    eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)

Forward-mode rule for an eigenvalue decomposition `DV = (D, V)`.

With `M = V⁻¹ * ΔA * V`, the diagonal of `M` is written into the diagonal of `ΔD`.
Unless `iszerotangent(ΔV)` holds, `ΔV` is overwritten with `V * (F .* M)`, where
`F[i, j] = 1 / (D[j, j] - D[i, i])` off the diagonal and `F[i, i] = 0`.

`A` and `kwargs` are accepted but not used. Returns `ΔDV`.

Note: no safeguard is applied for coinciding eigenvalues; those produce non-finite
off-diagonal entries in `F`.
"""
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    D, V = DV
    ΔD, ΔV = ΔDV
    # Tangent of A expressed in the eigenbasis; solve instead of forming inv(V).
    iVΔAV = (V \ ΔA) * V
    diagview(ΔD) .= diagview(iVΔAV)
    if !iszerotangent(ΔV)
        λ = diagview(D)
        F = 1 ./ (transpose(λ) .- λ)
        # The diagonal of F is 1/0 at this point; overwrite it with zeros.
        fill!(diagview(F), zero(eltype(F)))
        Kdot = F .* iVΔAV
        mul!(ΔV, V, Kdot)
    end
    return ΔDV
end
14+
15+
# Placeholder for the pushforward of the truncated eigenvalue decomposition.
# The body is empty: every argument is ignored and `nothing` is returned.
function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    return nothing
end

src/pushforwards/eigh.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
    eigh_pushforward!(dA, A, DV, dDV; kwargs...)

Forward-mode rule for a Hermitian eigenvalue decomposition `DV = (D, V)`.

With `∂K = (V \\ dA) * V`, the real part of the diagonal of `∂K` is written into the
diagonal of `dD`. `dV` is overwritten with `V * (F .* ∂K)`, where
`F[i, j] = 1 / (D[j, j] - D[i, i])` off the diagonal and `F[i, i] = 0`.

`A` and `kwargs` are accepted but not used. Returns `(dD, dV)`.

Note: no safeguard is applied for coinciding eigenvalues; those produce non-finite
off-diagonal entries in `F`.
"""
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
    D, V = DV
    dD, dV = dDV
    buffer = V \ dA
    ∂K = buffer * V
    # `diagview` works for `Diagonal` and general matrices and avoids copying the diagonal.
    diagview(dD) .= real.(diagview(∂K))
    λ = diagview(D)
    gaps = transpose(λ) .- λ
    F = one(eltype(gaps)) ./ gaps
    # The diagonal of F is 1/0 at this point; overwrite it with zeros.
    diagview(F) .= zero(eltype(F))
    ∂K .*= F
    # Reuse `buffer` for the product before copying it into the output tangent.
    mul!(buffer, V, ∂K)
    copyto!(dV, buffer)
    return (dD, dV)
end

src/pushforwards/lq.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
    lq_pushforward!(dA, A, LQ, dLQ; tol, rank_atol = tol, gauge_atol = tol)

Forward-mode rule for an LQ decomposition `LQ = (L, Q)`; the tangents are written into
the buffers `dLQ = (dL, dQ)`, which are returned as `(dL, dQ)`.

The numerical rank `p` is the index of the last diagonal entry of `L` whose magnitude
is at least `rank_atol`. The leading `p × p` block `L11` and the first `p` rows `Q1`
are differentiated from `L11 \\ dA1 * Q1'`; the remaining blocks are filled from
that result.

# Keywords
- `tol`: defaults to `default_pullback_gauge_atol(LQ[1])`.
- `rank_atol`: threshold on `abs` of the diagonal of `L` used to determine `p`.
- `gauge_atol`: accepted but not used in this function.

`A` is accepted but not used.
"""
function lq_pushforward!(
        dA, A, LQ, dLQ;
        tol::Real = default_pullback_gauge_atol(LQ[1]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    L, Q = LQ
    dL, dQ = dLQ
    m = size(L, 1)
    n = size(Q, 2)
    minmn = min(m, n)
    p = findlast(x -> abs(x) >= rank_atol, diagview(L))

    # The former full-rank shortcut read `dL` before computing it and then rebound `dL`
    # instead of writing into the buffer. For square full-rank input the blocks below
    # reduce to `L11 == L` and `Q1 == Q`, so the general path covers that case.

    n1 = p
    n2 = minmn - p
    n3 = n - minmn
    m1 = p

    Q1 = view(Q, 1:m1, 1:n) # full rank portion
    Q2 = view(Q, (n1 + 1):(n1 + n2), 1:n)
    L11 = view(L, 1:m1, 1:n1)
    L21 = view(L, (m1 + 1):m, 1:n1)

    dA1 = view(dA, 1:m1, 1:n)
    dA2 = view(dA, (m1 + 1):m, 1:n)

    dQ1 = view(dQ, 1:n1, 1:n)
    dQ2 = view(dQ, (n1 + 1):(n1 + n2), 1:n)
    dL11 = view(dL, 1:m1, 1:n1)
    dL21 = view(dL, (m1 + 1):m, 1:n1)
    dL22 = view(dL, (m1 + 1):m, (n1 + 1):(n1 + n2))

    # Forward rule for Q1 and L11 -- for a full-rank LQ this is all that is needed.
    tmp = (L11 \ dA1) * Q1'
    Ltmp = tmp + tmp'
    diagview(Ltmp) ./= 2
    # Keep only the lower triangle, mirroring the zeroing of the lower triangle in
    # `qr_pushforward!`; assumes `uppertriangularind` excludes the diagonal -- TODO confirm.
    utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp))
    utLtmp .= zero(eltype(Ltmp))
    dL11 .= L11 * Ltmp
    dQ1 .= L11 \ (dA1 - dL11 * Q1)

    dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1)
    # NOTE(review): the right-hand side reads `dQ2`, i.e. whatever the output buffer held
    # on entry -- confirm this is intended.
    dQ2 .= -(dQ2 * Q1') * Q1
    if size(Q2, 1) > 0
        dQ2 .+= Q2 * (Q2' * dQ2)
    end
    if n3 > 0 && size(dQ2, 1) > 0
        # NOTE(review): `Q3` takes rows of `Q` beyond `minmn`, and `Q3 * (Q3' * dQ2)` is
        # only dimensionally consistent when `n3 == n2` -- verify.
        Q3 = view(Q, (n1 + n2 + 1):n, 1:n)
        dQ2 .+= Q3 * (Q3' * dQ2)
    end
    if !isempty(dL22)
        # `dL21` replaces the previously undefined name `dL12`.
        residual = dA2 - L21 * dQ1 - dL21 * Q1
        _, l22 = qr_full(residual, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
        dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2))
    end
    return (dL, dQ)
end
64+
65+
#=function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
66+
qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
67+
end=#
68+
69+
# Placeholder for the pushforward of the LQ null-space computation.
# The body is empty: every argument is ignored and `nothing` is returned.
function lq_null_pushforward!(
        dA, A, LQ, dLQ;
        tol::Real=default_pullback_gauge_atol(LQ[1]),
        rank_atol::Real=tol,
        gauge_atol::Real=tol
    )
    return nothing
end

src/pushforwards/polar.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
    left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)

Forward-mode rule for a left polar decomposition `WP = (W, P)`; the tangents are
written into the buffers `ΔWP = (ΔW, ΔP)`, which are returned as `(ΔW, ΔP)`.

With `S = W' * ΔA`, the matrix `Kdot` solves `P * Kdot + Kdot * P = S - S'` and
`Ldot = (I - W * W') * ΔA * P⁻¹`. The outputs are `ΔW = W * Kdot + Ldot` and
`ΔP = S - Kdot * P`.

`A` and `kwargs` are accepted but not used.
"""
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
    W, P = WP
    ΔW, ΔP = ΔWP
    aWdA = adjoint(W) * ΔA
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the negated right-hand side.
    Kdot = sylvester(P, P, -(aWdA - adjoint(aWdA)))
    # Component of ΔA outside the column space of W.
    projector = Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)
    Ldot = (projector * ΔA) / P
    ΔW .= W * Kdot + Ldot
    ΔP .= aWdA - Kdot * P
    return (ΔW, ΔP)
end
11+
12+
"""
    right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)

Forward-mode rule for a right polar decomposition `PWᴴ = (P, Wᴴ)`; the tangents are
written into the buffers `ΔPWᴴ = (ΔP, ΔWᴴ)`.

With `S = ΔA * Wᴴ'`, the matrix `Kdot` solves `P * Kdot + Kdot * P = S - S'` and
`Ldot = P⁻¹ * ΔA * (I - Wᴴ' * Wᴴ)`. The outputs are `ΔWᴴ = Kdot * Wᴴ + Ldot` and
`ΔP = S - P * Kdot`.

`A` and `kwargs` are accepted but not used.

Returns `(ΔWᴴ, ΔP)`.
NOTE(review): this order is the reverse of the input tuple `(ΔP, ΔWᴴ)`; it is kept
unchanged for compatibility -- confirm it is intended.
"""
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
    P, Wᴴ = PWᴴ
    ΔP, ΔWᴴ = ΔPWᴴ
    dAW = ΔA * adjoint(Wᴴ)
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the negated right-hand side.
    Kdot = sylvester(P, P, -(dAW - adjoint(dAW)))
    # Component of ΔA outside the row space of Wᴴ.
    projector = Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ
    Ldot = P \ (ΔA * projector)
    ΔWᴴ .= Kdot * Wᴴ + Ldot
    ΔP .= dAW - P * Kdot
    return (ΔWᴴ, ΔP)
end

src/pushforwards/qr.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
    qr_pushforward!(dA, A, QR, dQR; tol, rank_atol = tol, gauge_atol = tol)

Forward-mode rule for a QR decomposition `QR = (Q, R)`; the tangents are written into
the buffers `dQR = (dQ, dR)`, which are returned as `(dQ, dR)`.

The numerical rank `p` is the index of the last diagonal entry of `R` whose magnitude
is at least `rank_atol`. The first `p` columns `Q1` and the leading `p × p` block
`R11` are differentiated from `Q1' * dA1 / R11`; the remaining blocks are filled from
that result.

# Keywords
- `tol`: defaults to `default_pullback_gauge_atol(QR[2])`.
- `rank_atol`: threshold on `abs` of the diagonal of `R` used to determine `p`.
- `gauge_atol`: accepted but not used in this function.
"""
function qr_pushforward!(
        dA, A, QR, dQR;
        tol::Real = default_pullback_gauge_atol(QR[2]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    Q, R = QR
    m = size(A, 1)
    n = size(A, 2)
    minmn = min(m, n)
    p = findlast(x -> abs(x) >= rank_atol, diagview(R))

    m1 = p
    m2 = minmn - p
    m3 = m - minmn
    n1 = p

    Q1 = view(Q, 1:m, 1:m1) # full rank portion
    Q2 = view(Q, 1:m, (m1 + 1):(m2 + m1))
    R11 = view(R, 1:m1, 1:n1)
    R12 = view(R, 1:m1, (n1 + 1):n)

    dA1 = view(dA, 1:m, 1:n1)
    dA2 = view(dA, 1:m, (n1 + 1):n)

    dQ, dR = dQR
    dQ1 = view(dQ, 1:m, 1:m1)
    dQ2 = view(dQ, 1:m, (m1 + 1):(m2 + m1))
    dR11 = view(dR, 1:m1, 1:n1)
    dR12 = view(dR, 1:m1, (n1 + 1):n)
    dR22 = view(dR, (m1 + 1):(m1 + m2), (n1 + 1):n)

    # Forward rule for Q1 and R11 -- for a full-rank QR this is all that is needed.
    tmp = (Q1' * dA1) / R11
    Rtmp = tmp + tmp'
    diagview(Rtmp) ./= 2
    # Keep only the upper triangle (diagonal already halved above).
    ltRtmp = view(Rtmp, lowertriangularind(Rtmp))
    ltRtmp .= zero(eltype(Rtmp))
    dR11 .= Rtmp * R11
    dQ1 .= (dA1 - Q1 * dR11) / R11
    dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
    if size(Q2, 2) > 0
        # NOTE(review): this uses `Q1' * Q2` rather than a tangent of Q1 -- confirm the
        # intended formula.
        dQ2 .= -Q1 * (Q1' * Q2)
        dQ2 .+= Q2 * (Q2' * dQ2)
    end
    if m3 > 0 && size(Q, 2) > minmn
        # only present for qr_full or rank-deficient qr_compact
        Q3 = view(Q, :, (minmn + 1):m)
        # Take every trailing column of dQ; the former `minmn + 1 < size(dQ, 2)` guard
        # yielded an empty buffer when exactly one extra column was present.
        dQ3 = view(dQ, :, (minmn + 1):size(dQ, 2))
        # NOTE(review): the tangent is set to `Q3` itself -- confirm this is intended.
        dQ3 .= Q3
    end
    if !isempty(dR22)
        residual = dA2 - dQ1 * R12 - Q1 * dR12
        _, r22 = qr_full(residual, LAPACK_HouseholderQR(; positive=true))
        dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
    end
    return (dQ, dR)
end
58+
59+
# Placeholder for the pushforward of the QR null-space computation.
# The body is empty: every argument is ignored and `nothing` is returned.
function qr_null_pushforward!(
        dA, A, N, dN;
        tol::Real=default_pullback_gauge_atol(N),
        rank_atol::Real=tol,
        gauge_atol::Real=tol
    )
    return nothing
end

0 commit comments

Comments
 (0)