@@ -6,40 +6,25 @@ function check_qr_cotangents(
66 gauge_atol:: Real = default_pullback_gauge_atol (ΔQ)
77 )
88 minmn = min (size (Q, 1 ), size (R, 2 ))
9- if minmn > p # case where A is rank-deficient
10- Δgauge = abs (zero (eltype (Q)))
11- if ! iszerotangent (ΔQ)
12- # in this case the number of Householder reflections will
13- # change upon small variations, and all of the remaining
14- # columns of ΔQ should be zero for a gauge-invariant
15- # cost function
16- ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q, 2 ))
17- Δgauge_Q = norm (ΔQ2, Inf )
18- Δgauge = max (Δgauge, Δgauge_Q)
19- end
20- if ! iszerotangent (ΔR)
21- ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): size (R, 2 ))
22- Δgauge_R = norm (ΔR22, Inf )
23- Δgauge = max (Δgauge, Δgauge_R)
24- end
25- Δgauge ≤ gauge_atol ||
26- @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
9+ Δgauge = abs (zero (eltype (Q)))
10+ if ! iszerotangent (ΔQ)
11+ ΔQ₂ = view (ΔQ, :, (p + 1 ): minmn)
12+ ΔQ₃ = ΔQ[:, (minmn + 1 ): size (Q, 2 )] # extra columns in the case of qr_full
13+ Δgauge_Q = norm (ΔQ₂, Inf )
14+ Q₁ = view (Q, :, 1 : p)
15+ Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
16+ mul! (ΔQ₃, Q₁, Q₁ᴴΔQ₃, - 1 , 1 )
17+ Δgauge_Q = max (Δgauge_Q, norm (ΔQ₃, Inf ))
18+ Δgauge = max (Δgauge, Δgauge_Q)
19+ end
20+ if ! iszerotangent (ΔR)
21+ ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): size (R, 2 ))
22+ Δgauge_R = norm (view (ΔR22, uppertriangularind (ΔR22)), Inf )
23+ Δgauge = max (Δgauge, Δgauge_R)
2724 end
28- return
29- end
30-
31- function check_qr_full_cotangents (Q1, ΔQ2, Q1dΔQ2; gauge_atol:: Real = default_pullback_gauge_atol (ΔQ2))
32- # in the case where A is full rank, but there are more columns in Q than in A
33- # (the case of `qr_full`), there is gauge-invariant information in the
34- # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
35- # matrix. As the number of Householder reflections is fixed in the full-rank
36- # case, Q is expected to rotate smoothly (we might even be able to predict
37- # how the full Q2 will change), but this we omit for now, and we consider
38- # Q2' * ΔQ2 as a gauge dependent quantity.
39- Δgauge = norm (mul! (copy (ΔQ2), Q1, Q1dΔQ2, - 1 , 1 ), Inf )
4025 Δgauge ≤ gauge_atol ||
41- @warn " `qr_full ` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
42- return
26+ @warn " `qr ` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
27+ return nothing
4328end
4429
4530"""
@@ -69,27 +54,28 @@ function qr_pullback!(
6954 Q, R = QR
7055 m = size (Q, 1 )
7156 n = size (R, 2 )
57+ minmn = min (m, n)
7258 Rd = diagview (R)
7359 p = qr_rank (R; rank_atol)
7460
7561 ΔQ, ΔR = ΔQR
7662
7763 Q1 = view (Q, :, 1 : p)
78- R11 = view (R, 1 : p, 1 : p)
64+ R11 = UpperTriangular ( view (R, 1 : p, 1 : p) )
7965 ΔA1 = view (ΔA, :, 1 : p)
8066 ΔA2 = view (ΔA, :, (p + 1 ): n)
8167
8268 check_qr_cotangents (Q, R, ΔQ, ΔR, p; gauge_atol)
8369
8470 ΔQ̃ = zero! (similar (Q, (m, p)))
8571 if ! iszerotangent (ΔQ)
86- copy! (ΔQ̃, view (ΔQ, :, 1 : p) )
87- if p < size (Q, 2 )
88- Q2 = view (Q, :, (p + 1 ) : size (Q, 2 ) )
89- ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q , 2 ))
90- Q1dΔQ2 = Q1 ' * ΔQ2
91- check_qr_full_cotangents (Q1, ΔQ2, Q1dΔQ2; gauge_atol)
92- ΔQ̃ = mul! (ΔQ̃, Q2, Q1dΔQ2 ' , - 1 , 1 )
72+ ΔQ₁ = view (ΔQ, :, 1 : p)
73+ copy! (ΔQ̃, ΔQ₁ )
74+ if minmn < size (Q, 2 )
75+ ΔQ3 = view (ΔQ, :, (minmn + 1 ): size (ΔQ , 2 )) # extra columns in the case of qr_full
76+ Q3 = view (Q, :, (minmn + 1 ) : size (Q, 2 ))
77+ Q1ᴴΔQ3 = Q1 ' * ΔQ3
78+ ΔQ̃ = mul! (ΔQ̃, Q3, Q1ᴴΔQ3 ' , - 1 , 1 )
9379 end
9480 end
9581 if ! iszerotangent (ΔR) && n > p
@@ -103,7 +89,7 @@ function qr_pullback!(
10389 # construct M
10490 M = zero! (similar (R, (p, p)))
10591 if ! iszerotangent (ΔR)
106- ΔR11 = view (ΔR, 1 : p, 1 : p)
92+ ΔR11 = UpperTriangular ( view (ΔR, 1 : p, 1 : p) )
10793 M = mul! (M, ΔR11, R11' , 1 , 1 )
10894 end
10995 M = mul! (M, Q1' , ΔQ̃, - 1 , 1 )
@@ -112,8 +98,8 @@ function qr_pullback!(
11298 Md = diagview (M)
11399 Md .= real .(Md)
114100 end
115- rdiv! (M, UpperTriangular ( R11) ' )
116- rdiv! (ΔQ̃, UpperTriangular ( R11) ' )
101+ rdiv! (M, R11' ) # R11 is upper triangular
102+ rdiv! (ΔQ̃, R11' )
117103 ΔA1 = mul! (ΔA1, Q1, M, + 1 , 1 )
118104 ΔA1 .+ = ΔQ̃
119105 return ΔA
0 commit comments