|
1 | 1 | lq_rank(L; kwargs...) = qr_rank(L; kwargs...) |
2 | 2 |
|
3 | | -function check_lq_cotangents( |
| 3 | +function check_and_prepare_lq_cotangents( |
4 | 4 | L, Q, ΔL, ΔQ, p::Int; |
5 | 5 | gauge_atol::Real = default_pullback_gauge_atol(ΔQ) |
6 | 6 | ) |
7 | | - minmn = min(size(L, 1), size(Q, 2)) |
| 7 | + m, n = size(L, 1), size(Q, 2) |
| 8 | + minmn = min(m, n) |
8 | 9 | Δgauge = abs(zero(eltype(Q))) |
| 10 | + Q₁ = view(Q, 1:p, :) |
| 11 | + ΔQ₁ = zero!(similar(Q₁)) |
9 | 12 | if !iszerotangent(ΔQ) |
10 | | - ΔQ₂ = view(ΔQ, (p + 1):minmn, :) |
11 | | - ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] |
12 | | - Δgauge_Q = norm(ΔQ₂, Inf) |
13 | | - Q₁ = view(Q, 1:p, :) |
14 | | - ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' |
15 | | - mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) |
16 | | - Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) |
| 13 | + size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q")) |
| 14 | + ΔQ₁ .= view(ΔQ, 1:p, 1:n) |
| 15 | + if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁ |
| 16 | + Q₃ = view(Q, (minmn + 1):size(Q, 1), :) |
| 17 | + ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :) |
| 18 | + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' |
| 19 | + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) |
| 20 | + Δgauge_Q = norm(ΔQ₃, Inf) |
| 21 | + mul!(ΔQ₁, ΔQ₃Q₁ᴴ', Q₃, -1, 1) |
| 22 | + else |
| 23 | + ΔQ₂ = view(ΔQ, (p + 1):size(ΔQ, 1), :) |
| 24 | + Δgauge_Q = norm(ΔQ₂, Inf) |
| 25 | + end |
17 | 26 | Δgauge = max(Δgauge, Δgauge_Q) |
18 | 27 | end |
19 | 28 | if !iszerotangent(ΔL) |
20 | | - ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) |
21 | | - Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf) |
22 | | - Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf)) |
| 29 | + size(ΔL) == size(L) || throw(DimensionMismatch("ΔL must have the same size as L")) |
| 30 | + ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p)) |
| 31 | + ΔL₂₁ = view(ΔL, (p + 1):size(ΔL, 1), 1:p) |
| 32 | + ΔL₂₂ = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) |
| 33 | + Δgauge_L = norm(view(ΔL₂₂, lowertriangularind(ΔL₂₂)), Inf) |
| 34 | + Δgauge_L = max(Δgauge_L, norm(view(ΔL₂₂, diagind(ΔL₂₂)), Inf)) |
23 | 35 | Δgauge = max(Δgauge, Δgauge_L) |
| 36 | + else |
| 37 | + ΔL₁₁ = nothing |
| 38 | + ΔL₂₁ = nothing |
24 | 39 | end |
25 | 40 | Δgauge ≤ gauge_atol || |
26 | 41 | @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
27 | | - return nothing |
| 42 | + return ΔL₁₁, ΔL₂₁, ΔQ₁ |
28 | 43 | end |
29 | 44 |
|
30 | 45 | """ |
@@ -53,53 +68,37 @@ function lq_pullback!( |
53 | 68 | L, Q = LQ |
54 | 69 | m = size(L, 1) |
55 | 70 | n = size(Q, 2) |
56 | | - minmn = min(m, n) |
57 | 71 | p = lq_rank(L; rank_atol) |
| 72 | + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of L*Q ($m, $n)")) |
58 | 73 |
|
59 | | - ΔL, ΔQ = ΔLQ |
60 | | - |
61 | | - Q₁ = view(Q, 1:p, :) |
62 | 74 | L₁₁ = LowerTriangular(view(L, 1:p, 1:p)) |
| 75 | + L₂₁ = view(L, (p + 1):m, 1:p) |
| 76 | + Q₁ = view(Q, 1:p, :) |
| 77 | + |
63 | 78 | ΔA₁ = view(ΔA, 1:p, :) |
64 | 79 | ΔA₂ = view(ΔA, (p + 1):m, :) |
65 | 80 |
|
66 | | - check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol) |
| 81 | + ΔL, ΔQ = ΔLQ |
| 82 | + ΔL₁₁, ΔL₂₁, ΔQ₁ = check_and_prepare_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol) |
67 | 83 |
|
68 | | - ΔQ̃ = zero!(similar(Q, (p, n))) |
69 | | - if !iszerotangent(ΔQ) |
70 | | - ΔQ₁ = view(ΔQ, 1:p, :) |
71 | | - copy!(ΔQ̃, ΔQ₁) |
72 | | - if minmn < size(Q, 1) |
73 | | - ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) |
74 | | - Q₃ = view(Q, (minmn + 1):size(Q, 1), :) |
75 | | - ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' |
76 | | - ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1) |
77 | | - end |
78 | | - end |
79 | 84 | if !iszerotangent(ΔL) && m > p |
80 | | - L₂₁ = view(L, (p + 1):m, 1:p) |
81 | | - ΔL₂₁ = view(ΔL, (p + 1):m, 1:p) |
82 | | - ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1) |
| 85 | + ΔQ₁ = mul!(ΔQ₁, L₂₁' * ΔL₂₁, Q₁, -1, 1) |
83 | 86 | # Adding ΔA₂ contribution |
84 | 87 | ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1) |
85 | 88 | end |
86 | 89 |
|
87 | 90 | # construct M |
88 | 91 | M = zero!(similar(L, (p, p))) |
89 | 92 | if !iszerotangent(ΔL) |
90 | | - ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p)) |
91 | 93 | M = mul!(M, L₁₁', ΔL₁₁, 1, 1) |
92 | 94 | end |
93 | | - M = mul!(M, ΔQ̃, Q₁', -1, 1) |
| 95 | + M = mul!(M, ΔQ₁, Q₁', -1, 1) |
94 | 96 | view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) |
95 | 97 | if eltype(M) <: Complex |
96 | 98 | Md = diagview(M) |
97 | 99 | Md .= real.(Md) |
98 | 100 | end |
99 | | - ldiv!(L₁₁', M) |
100 | | - ldiv!(L₁₁', ΔQ̃) |
101 | | - ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1) |
102 | | - ΔA₁ .+= ΔQ̃ |
| 101 | + ΔA₁ .+= ldiv!(L₁₁', mul!(ΔQ₁, M, Q₁, +1, 1)) |
103 | 102 | return ΔA |
104 | 103 | end |
105 | 104 |
|
@@ -134,3 +133,51 @@ function lq_null_pullback!( |
134 | 133 | end |
135 | 134 | return ΔA |
136 | 135 | end |
| 136 | + |
| 137 | + |
| 138 | +""" |
| 139 | + remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) |
| 140 | +
|
| 141 | +Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and |
| 142 | +`Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely |
| 143 | +determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. |
| 144 | +Additionally, the diagonal and lower-triangular entries of the block of `ΔL` beyond the rank are zeroed out. |
| 145 | +""" |
| 146 | +function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) |
| 147 | + r = MatrixAlgebraKit.lq_rank(L; rank_atol) |
| 148 | + minmn = min(size(A)...) |
| 149 | + Q₁ = view(Q, 1:r, :) |
| 150 | + ΔQ₂ = view(ΔQ, (r + 1):minmn, :) |
| 151 | + zero!(ΔQ₂) |
| 152 | + ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full |
| 153 | + if r == minmn |
| 154 | + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' |
| 155 | + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) |
| 156 | + else # rank-deficient case, no gauge-invariant information |
| 157 | + zero!(ΔQ₃) |
| 158 | + end |
| 159 | + ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) |
| 160 | + zero!(diagview(ΔL₂₂)) |
| 161 | + zero!(view(ΔL₂₂, lowertriangularind(ΔL₂₂))) |
| 162 | + return ΔL, ΔQ |
| 163 | +end |
| 164 | + |
| 165 | +""" |
| 166 | + remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) |
| 167 | +
|
| 168 | +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the LQ null space `Nᴴ`. The null |
| 169 | +space is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of |
| 170 | +the compact LQ factor `Q₁`. |
| 171 | +""" |
| 172 | +function remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) |
| 173 | + return mul!(ΔNᴴ, ΔNᴴ * Nᴴ', Nᴴ, -1, 1) |
| 174 | +end |
| 175 | + |
| 176 | +""" |
| 177 | + remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) |
| 178 | +
|
| 179 | +Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null space `Nᴴ`. The |
| 180 | +null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the |
| 181 | +row span of the compact LQ factor `Q₁` of `A`. |
| 182 | +""" |
| 183 | +remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) |
0 commit comments