-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathqr.jl
More file actions
153 lines (139 loc) · 5.45 KB
/
qr.jl
File metadata and controls
153 lines (139 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Numerical rank of the triangular factor `R`: the index of the last diagonal entry whose
# magnitude is at least `rank_atol`, or `0` when no diagonal entry reaches that threshold.
function qr_rank(R; rank_atol = default_pullback_rank_atol(R))
    lastidx = findlast(r -> abs(r) >= rank_atol, diagview(R))
    return isnothing(lastidx) ? 0 : lastidx
end
# Gauge check for a numerically rank-deficient QR decomposition.
#
# When the rank `p` is smaller than `min(size(Q, 1), size(R, 2))`, only the first `p`
# columns of `Q` and the first `p` rows of `R` are well defined. This routine takes the
# largest absolute entry of the columns `p+1:end` of `ΔQ` and of the trailing block
# `ΔR[p+1:minmn, p+1:end]`. It emits a warning if that value is not `≤ gauge_atol`.
# Zero tangents (as decided by `iszerotangent`) are skipped. Returns `nothing`.
function check_qr_cotangents(
        Q, R, ΔQ, ΔR, p::Int;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
    )
    minmn = min(size(Q, 1), size(R, 2))
    # Full rank: nothing to check here.
    p < minmn || return nothing

    Δgauge = abs(zero(eltype(Q)))
    if !iszerotangent(ΔQ)
        # For a rank-deficient matrix the number of Householder reflections can change
        # under small perturbations. A gauge-invariant cost function should therefore
        # leave every column of ΔQ beyond the rank equal to zero.
        ΔQ_tail = view(ΔQ, :, (p + 1):size(Q, 2))
        Δgauge = max(Δgauge, norm(ΔQ_tail, Inf))
    end
    if !iszerotangent(ΔR)
        ΔR_tail = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
        Δgauge = max(Δgauge, norm(ΔR_tail, Inf))
    end

    # Written as `!(… ≤ …)` so that a NaN gauge value also triggers the warning.
    if !(Δgauge ≤ gauge_atol)
        @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
# Gauge check for the extra columns `Q2` of `Q` beyond the rank (e.g. for `qr_full`).
#
# `Q1dΔQ2` is expected to equal `Q1' * ΔQ2`. Because `Q` is unitary, the projection of
# `ΔQ2` onto the column space of `Q1` carries gauge-invariant information.
#
# In the full-rank case the number of Householder reflections is fixed, so `Q2` should
# vary smoothly, and one could in principle predict how it changes. That is not
# exploited here. Instead, the remainder `ΔQ2 - Q1 * Q1dΔQ2` is treated as gauge
# dependent. A warning is emitted when its largest absolute entry is not `≤ gauge_atol`.
# `ΔQ2` itself is not modified. Returns `nothing`.
function check_qr_full_cotangents(
        Q1, ΔQ2, Q1dΔQ2;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)
    )
    residual = copy(ΔQ2)
    mul!(residual, Q1, Q1dΔQ2, -1, 1)  # residual = ΔQ2 - Q1 * Q1dΔQ2
    Δgauge = norm(residual, Inf)
    # Written as `!(… ≤ …)` so that a NaN gauge value also triggers the warning.
    if !(Δgauge ≤ gauge_atol)
        @warn "`qr` full cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
"""
    qr_pullback!(
        ΔA, A, QR, ΔQR;
        rank_atol::Real = default_pullback_rank_atol(QR[2]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
    )

Adds the pullback from the QR decomposition of `A` to `ΔA`, given the output `QR` and
cotangent `ΔQR` of `qr_compact(A; positive = true)` or `qr_full(A; positive = true)`.

The rank `r` of the original matrix `A ≈ Q * R` is determined by `rank_atol`: it is the
index of the last diagonal entry of `R` with magnitude at least `rank_atol`. It may be
less than the minimum of the number of rows and columns. In that case only the first `r`
columns of `Q` and the first `r` rows of `R` are well defined. The cotangents `ΔQ` and
`ΔR` should then be nonzero only in the first `r` columns and rows, respectively. If
nonzero values in the remaining columns or rows exceed `gauge_atol`, a warning is
printed.

`ΔA` is updated in place and returned.
"""
function qr_pullback!(
        ΔA::AbstractMatrix, A, QR, ΔQR;
        rank_atol::Real = default_pullback_rank_atol(QR[2]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
    )
    Q, R = QR
    ΔQ, ΔR = ΔQR
    m = size(Q, 1)
    n = size(R, 2)
    # Numerical rank. Forward the caller's `rank_atol`; previously this keyword was
    # accepted but silently ignored.
    p = qr_rank(R; rank_atol)
    Q1 = view(Q, :, 1:p)
    R11 = view(R, 1:p, 1:p)
    ΔA1 = view(ΔA, :, 1:p)

    check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

    # ΔQ̃ (m × p) accumulates the effective cotangent of Q1.
    ΔQ̃ = zero!(similar(Q, (m, p)))
    if !iszerotangent(ΔQ)
        copy!(ΔQ̃, view(ΔQ, :, 1:p))
        if p < size(Q, 2)
            # Q has columns beyond the rank, either from rank deficiency or from `qr_full`.
            # The component of ΔQ2 along Q1 is folded back into ΔQ̃. The rest is only
            # checked for gauge dependence.
            Q2 = view(Q, :, (p + 1):size(Q, 2))
            ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
            Q1dΔQ2 = Q1' * ΔQ2
            check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol)
            ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
        end
    end
    if !iszerotangent(ΔR) && n > p
        # Columns p+1:n of R (the R12 block). ΔR12 contributes directly to ΔA2 and,
        # via R12, to the effective cotangent ΔQ̃.
        R12 = view(R, 1:p, (p + 1):n)
        ΔR12 = view(ΔR, 1:p, (p + 1):n)
        ΔQ̃ = mul!(ΔQ̃, Q1, ΔR12 * R12', -1, 1)
        ΔA2 = view(ΔA, :, (p + 1):n)
        mul!(ΔA2, Q1, ΔR12, 1, 1)
    end

    # Build M = ΔR11 * R11' - Q1' * ΔQ̃ (p × p).
    M = zero!(similar(R, (p, p)))
    if !iszerotangent(ΔR)
        ΔR11 = view(ΔR, 1:p, 1:p)
        M = mul!(M, ΔR11, R11', 1, 1)
    end
    M = mul!(M, Q1', ΔQ̃, -1, 1)
    # Replace the lower triangle with the conjugated upper triangle, and (for complex
    # eltypes) drop the imaginary part of the diagonal. This is intended to make M
    # Hermitian. NOTE(review): it assumes `lowertriangularind`/`uppertriangularind`
    # enumerate mirrored (i, j)/(j, i) pairs in matching order — confirm.
    view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
    if eltype(M) <: Complex
        Md = diagview(M)
        Md .= real.(Md)
    end
    # ΔA1 += (Q1 * M + ΔQ̃) / R11'
    rdiv!(M, UpperTriangular(R11)')
    rdiv!(ΔQ̃, UpperTriangular(R11)')
    ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
    ΔA1 .+= ΔQ̃
    return ΔA
end
# Gauge check for the null-space cotangent of `qr_null`.
#
# Forms `N' * ΔN`, projects it with `project_antihermitian!`, and takes the (default)
# norm of the result. A warning is emitted when that norm is not `≤ gauge_atol`.
# Returns `nothing`.
function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN))
    NᴴΔN = N' * ΔN
    Δgauge = norm(project_antihermitian!(NᴴΔN))
    # Written as `!(… ≤ …)` so that a NaN gauge value also triggers the warning.
    if !(Δgauge ≤ gauge_atol)
        @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
"""
    qr_null_pullback!(
        ΔA::AbstractMatrix, A, N, ΔN;
        gauge_atol::Real = default_pullback_gauge_atol(ΔN)
    )

Adds the pullback from the left null space of `A` to `ΔA`, given the null-space basis `N`
and its cotangent `ΔN` of `qr_null(A)`.

`N` has one row per row of `A`, since the update is applied as `ΔA += N * X`. Its
columns therefore span a left null space of `A`, not a right one.

If the anti-Hermitian part of `N' * ΔN` has a norm exceeding `gauge_atol`, a warning is
printed, because that part of the cotangent depends on the gauge choice of the basis `N`.
Nothing is added when `ΔN` is a zero tangent or `N` has no columns.

`ΔA` is updated in place and returned.

See also [`qr_pullback!`](@ref).
"""
function qr_null_pullback!(
        ΔA::AbstractMatrix, A, N, ΔN;
        gauge_atol::Real = default_pullback_gauge_atol(ΔN)
    )
    if !iszerotangent(ΔN) && size(N, 2) > 0
        check_qr_null_cotangents(N, ΔN; gauge_atol)
        # Recompute the compact, positive QR of A. Then ΔA -= N * (ΔN' * Q) / R'.
        Q, R = qr_compact(A; positive = true)
        X = rdiv!(ΔN' * Q, UpperTriangular(R)')
        ΔA = mul!(ΔA, N, X, -1, 1)
    end
    return ΔA
end