|
| 1 | +function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) |
| 2 | + |
| 3 | + L, Q = LQ |
| 4 | + dL, dQ = dLQ |
| 5 | + m = size(L, 1) |
| 6 | + n = size(Q, 2) |
| 7 | + minmn = min(m, n) |
| 8 | + Ld = diagview(L) |
| 9 | + p = findlast(>=(rank_atol) ∘ abs, Ld) |
| 10 | + |
| 11 | + if p == minmn && size(L,1) == size(L,2) # full-rank |
| 12 | + invL = inv(L) |
| 13 | + dQ .= invL * (dA - dL * Q) |
| 14 | + dL = invL * dA * Q' |
| 15 | + return (dL, dQ) |
| 16 | + end |
| 17 | + |
| 18 | + n1 = p |
| 19 | + n2 = minmn - p |
| 20 | + n3 = n - minmn |
| 21 | + m1 = p |
| 22 | + m2 = m - p |
| 23 | + |
| 24 | + ##### |
| 25 | + Q1 = view(Q, 1:m1, 1:n) # full rank portion |
| 26 | + Q2 = view(Q, n1+1:n1+n2, 1:n) |
| 27 | + L11 = view(L, 1:m1, 1:n1) |
| 28 | + L21 = view(L, (m1+1):m, 1:n1) |
| 29 | + |
| 30 | + dA1 = view(dA, 1:m1, 1:n) |
| 31 | + dA2 = view(dA, (m1+1):m, 1:n) |
| 32 | + |
| 33 | + dQ1 = view(dQ, 1:n1, 1:n) |
| 34 | + dQ2 = view(dQ, n1+1:n1+n2, 1:n) |
| 35 | + dL11 = view(dL, 1:m1, 1:n1) |
| 36 | + dL21 = view(dL, (m1+1):m, 1:n1) |
| 37 | + dL22 = view(dL, (m1+1):m, n1+1:(n1+n2) ) |
| 38 | + |
| 39 | + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need |
| 40 | + invL11 = inv(L11) |
| 41 | + tmp = invL11 * dA1 * Q1' |
| 42 | + Ltmp = tmp + tmp' |
| 43 | + diagview(Ltmp) ./= 2 |
| 44 | + utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp)) |
| 45 | + dL11 .= L11 * Ltmp |
| 46 | + dQ1 .= invL11 * dA1 - invL11 * dL11 * Q1 |
| 47 | + |
| 48 | + dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1) |
| 49 | + dQ2 .= -(dQ2 * Q1') * Q1 |
| 50 | + if size(Q2, 1) > 0 |
| 51 | + dQ2 .+= Q2 * (Q2' * dQ2) |
| 52 | + end |
| 53 | + if n3 > 0 && size(dQ2, 1) > 0 |
| 54 | + # only present for qr_full or rank-deficient qr_compact |
| 55 | + Q3 = view(Q, (n1+n2+1):n, 1:n) |
| 56 | + dQ2 .+= Q3 * (Q3' * dQ2) |
| 57 | + end |
| 58 | + if !isempty(dL22) |
| 59 | + _, l22 = qr_full(dA2 - L21 * dQ1 - dL12 * Q1, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true)) |
| 60 | + dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2)) |
| 61 | + end |
| 62 | + return (dL, dQ) |
| 63 | +end |
| 64 | + |
| 65 | +#=function lq_pushforward!(dA, A, LQ, dLQ; kwargs...) |
| 66 | + qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...) |
| 67 | +end=# |
| 68 | + |
| 69 | +function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end |
0 commit comments