Skip to content

Commit 590ef91

Browse files
committed
Semi-working LQ
1 parent 81b12e7 commit 590ef91

4 files changed

Lines changed: 44 additions & 110 deletions

File tree

src/pushforwards/eigh.jl

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,3 +13,5 @@ function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
1313
copyto!(dV, ∂V)
1414
return (dD, dV)
1515
end
16+
17+
function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...) end

src/pushforwards/lq.jl

Lines changed: 4 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -1,69 +1,7 @@
11
function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol)
2-
3-
L, Q = LQ
4-
dL, dQ = dLQ
5-
m = size(L, 1)
6-
n = size(Q, 2)
7-
minmn = min(m, n)
8-
Ld = diagview(L)
9-
p = findlast(>=(rank_atol) ∘ abs, Ld)
10-
11-
if p == minmn && size(L,1) == size(L,2) # full-rank
12-
invL = inv(L)
13-
dQ .= invL * (dA - dL * Q)
14-
dL = invL * dA * Q'
15-
return (dL, dQ)
16-
end
17-
18-
n1 = p
19-
n2 = minmn - p
20-
n3 = n - minmn
21-
m1 = p
22-
m2 = m - p
23-
24-
#####
25-
Q1 = view(Q, 1:m1, 1:n) # full rank portion
26-
Q2 = view(Q, n1+1:n1+n2, 1:n)
27-
L11 = view(L, 1:m1, 1:n1)
28-
L21 = view(L, (m1+1):m, 1:n1)
29-
30-
dA1 = view(dA, 1:m1, 1:n)
31-
dA2 = view(dA, (m1+1):m, 1:n)
32-
33-
dQ1 = view(dQ, 1:n1, 1:n)
34-
dQ2 = view(dQ, n1+1:n1+n2, 1:n)
35-
dL11 = view(dL, 1:m1, 1:n1)
36-
dL21 = view(dL, (m1+1):m, 1:n1)
37-
dL22 = view(dL, (m1+1):m, n1+1:(n1+n2) )
38-
39-
# fwd rule for Q1 and R11 -- for a non-rank-deficient QR, this is all we need
40-
invL11 = inv(L11)
41-
tmp = invL11 * dA1 * Q1'
42-
Ltmp = tmp + tmp'
43-
diagview(Ltmp) ./= 2
44-
utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp))
45-
dL11 .= L11 * Ltmp
46-
dQ1 .= invL11 * dA1 - invL11 * dL11 * Q1
47-
48-
dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1)
49-
dQ2 .= -(dQ2 * Q1') * Q1
50-
if size(Q2, 1) > 0
51-
dQ2 .+= Q2 * (Q2' * dQ2)
52-
end
53-
if n3 > 0 && size(dQ2, 1) > 0
54-
# only present for qr_full or rank-deficient qr_compact
55-
Q3 = view(Q, (n1+n2+1):n, 1:n)
56-
dQ2 .+= Q3 * (Q3' * dQ2)
57-
end
58-
if !isempty(dL22)
59-
_, l22 = qr_full(dA2 - L21 * dQ1 - dL12 * Q1, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
60-
dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2))
61-
end
62-
return (dL, dQ)
2+
qr_pushforward!(adjoint(dA), adjoint(A), adjoint.(reverse(LQ)), adjoint.(reverse(dLQ)); tol, rank_atol, gauge_atol)
633
end
644

65-
#=function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
66-
qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
67-
end=#
68-
69-
function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end
5+
function lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol::Real=default_pullback_gauge_atol(Nᴴ), rank_atol::Real=tol, gauge_atol::Real=tol)
6+
iszero(min(size(Nᴴ)...)) && return # nothing to do
7+
end

src/pushforwards/qr.jl

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,4 +56,6 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=default_pullback_gauge_atol(Q
5656
return (dQ, dR)
5757
end
5858

59-
function qr_null_pushforward!(dA, A, N, dN; tol::Real=default_pullback_gauge_atol(N), rank_atol::Real=tol, gauge_atol::Real=tol) end
59+
function qr_null_pushforward!(dA, A, N, dN; tol::Real=default_pullback_gauge_atol(N), rank_atol::Real=tol, gauge_atol::Real=tol)
60+
iszero(min(size(N)...)) && return # nothing to do
61+
end

test/mooncake.jl

Lines changed: 35 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,6 @@ end
179179
end
180180
end
181181

182-
#=
183182
@timedtestset "LQ AD Rules with eltype $T" for T in ETs
184183
rng = StableRNG(12345)
185184
m = 19
@@ -193,57 +192,50 @@ end
193192
)
194193
@testset "lq_compact" begin
195194
L, Q = lq_compact(A, alg)
196-
Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive = false, atol = atol, rtol = rtol)
195+
Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; atol = atol, rtol = rtol)
197196
test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg)
198-
ΔL = randn(rng, T, m, minmn)
199-
ΔQ = randn(rng, T, minmn, n)
200-
dL = make_mooncake_tangent(ΔL)
201-
dQ = make_mooncake_tangent(ΔQ)
202-
dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ)
203-
Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ)
204197
end
205198
@testset "lq_null" begin
206199
L, Q = lq_compact(A, alg)
207-
ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
208-
Nᴴ = randn(rng, T, max(0, n - minmn), n)
209-
dNᴴ = make_mooncake_tangent(ΔNᴴ)
210-
Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol)
200+
ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
201+
Nᴴ = randn(rng, T, max(0, n - minmn), n)
202+
dNᴴ = make_mooncake_tangent(ΔNᴴ)
203+
Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, atol = atol, rtol = rtol)
211204
test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg)
212205
end
213206
@testset "lq_full" begin
214207
L, Q = lq_full(A, alg)
215-
Q1 = view(Q, 1:minmn, 1:n)
216-
ΔQ = randn(rng, T, n, n)
217-
ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
208+
Q1 = view(Q, 1:minmn, 1:n)
209+
ΔQ = randn(rng, T, n, n)
210+
ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
218211
mul!(ΔQ2, ΔQ2 * Q1', Q1)
219-
ΔL = randn(rng, T, m, n)
220-
dL = make_mooncake_tangent(ΔL)
221-
dQ = make_mooncake_tangent(ΔQ)
222-
dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
223-
Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
212+
ΔL = randn(rng, T, m, n)
213+
dL = make_mooncake_tangent(ΔL)
214+
dQ = make_mooncake_tangent(ΔQ)
215+
dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
216+
Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
224217
test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg)
225218
end
226219
@testset "lq_compact - rank-deficient A" begin
227-
r = minmn - 5
228-
Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
220+
r = minmn - 5
221+
Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
229222
L, Q = lq_compact(Ard, alg)
230-
ΔL = randn(rng, T, m, minmn)
231-
ΔQ = randn(rng, T, minmn, n)
232-
Q1 = view(Q, 1:r, 1:n)
233-
Q2 = view(Q, (r + 1):minmn, 1:n)
234-
ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
223+
ΔL = randn(rng, T, m, minmn)
224+
ΔQ = randn(rng, T, minmn, n)
225+
Q1 = view(Q, 1:r, 1:n)
226+
Q2 = view(Q, (r + 1):minmn, 1:n)
227+
ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
235228
ΔQ2 .= 0
236229
view(ΔL, :, (r + 1):minmn) .= 0
237-
dL = make_mooncake_tangent(ΔL)
238-
dQ = make_mooncake_tangent(ΔQ)
239-
dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
240-
Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
230+
dL = make_mooncake_tangent(ΔL)
231+
dQ = make_mooncake_tangent(ΔQ)
232+
dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
233+
Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
241234
test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg)
242235
end
243236
end
244237
end
245238
end
246-
=#
247239

248240
@timedtestset "EIG AD Rules with eltype $T" for T in ETs
249241
rng = StableRNG(12345)
@@ -283,7 +275,7 @@ end
283275
dDtrunc = make_mooncake_tangent(ΔDtrunc)
284276
dVtrunc = make_mooncake_tangent(ΔVtrunc)
285277
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
286-
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
278+
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol)
287279
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
288280
end
289281
truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real))
@@ -295,7 +287,7 @@ end
295287
dDtrunc = make_mooncake_tangent(ΔDtrunc)
296288
dVtrunc = make_mooncake_tangent(ΔVtrunc)
297289
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
298-
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
290+
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol)
299291
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
300292
end
301293
end
@@ -357,11 +349,11 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
357349
LAPACK_MultipleRelativelyRobustRepresentations(),
358350
)
359351
@testset "eigh_full" begin
360-
Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
352+
Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
361353
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
362354
end
363355
@testset "eigh_vals" begin
364-
Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
356+
Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; is_primitive = false, atol = atol, rtol = rtol)
365357
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
366358
end
367359
@testset "eigh_trunc" begin
@@ -375,7 +367,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
375367
dDtrunc = make_mooncake_tangent(ΔDtrunc)
376368
dVtrunc = make_mooncake_tangent(ΔVtrunc)
377369
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
378-
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
370+
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
379371
test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
380372
end
381373
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -387,7 +379,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
387379
dDtrunc = make_mooncake_tangent(ΔDtrunc)
388380
dVtrunc = make_mooncake_tangent(ΔVtrunc)
389381
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
390-
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
382+
Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
391383
test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
392384
end
393385
end
@@ -523,12 +515,12 @@ right_orth_lq(X) = right_orth(X; alg = :lq)
523515
right_orth_polar(X) = right_orth(X; alg = :polar)
524516
right_null_lq(X) = right_null(X; alg = :lq)
525517

526-
MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
527-
MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
528-
MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
529-
MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
518+
MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
519+
MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
520+
MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
521+
MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
530522
MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A)
531-
MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
523+
MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
532524

533525
@timedtestset "Orth and null with eltype $T" for T in ETs
534526
rng = StableRNG(12345)

0 commit comments

Comments
 (0)