-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathqr.jl
More file actions
104 lines (97 loc) · 3.9 KB
/
qr.jl
File metadata and controls
104 lines (97 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
    qr_pullback!(
        ΔA, A, QR, ΔQR;
        tol::Real = default_pullback_gaugetol(QR[2]),
        rank_atol::Real = tol,
        gauge_atol::Real = tol
    )

Add the pullback of the QR decomposition of `A` to `ΔA` (in place) and return `ΔA`, given
the output `QR = (Q, R)` and the cotangent `ΔQR = (ΔQ, ΔR)` of
`qr_compact(A; positive = true)` or `qr_full(A; positive = true)`. Either cotangent may be a
zero tangent (as detected by `iszerotangent`), in which case its contribution is skipped.

The numerical rank `p` of `A ≈ Q * R` is the index of the last diagonal entry of `R` whose
absolute value is at least `rank_atol` (`p = 0` if there is none). When `p` is less than
`min(m, n)`, only the first `p` columns of `Q` and the first `p` rows of `R` are well
defined, so the cotangents `ΔQ` and `ΔR` should only have nonzero entries in the first `p`
columns and rows, respectively. If the largest absolute entry of `ΔQ[:, (p+1):end]` or of
`ΔR[(p+1):min(m, n), (p+1):n]` exceeds `gauge_atol`, a warning is printed. Likewise,
whenever `Q` has more than `p` columns (e.g. for `qr_full`), a warning is printed if the
component of `ΔQ[:, (p+1):end]` orthogonal to the first `p` columns of `Q` exceeds
`gauge_atol`.

# Keywords
- `tol`: default value for both `rank_atol` and `gauge_atol`.
- `rank_atol`: absolute tolerance on `abs.(diag(R))` used to determine the rank `p`.
- `gauge_atol`: absolute tolerance on gauge-dependent cotangent components; larger values
  trigger a warning.
"""
function qr_pullback!(
    ΔA::AbstractMatrix, A, QR, ΔQR;
    tol::Real = default_pullback_gaugetol(QR[2]),
    rank_atol::Real = tol,
    gauge_atol::Real = tol
)
    Q, R = QR
    ΔQ, ΔR = ΔQR
    m = size(Q, 1)
    n = size(R, 2)
    minmn = min(m, n)
    # Numerical rank: index of the last non-negligible diagonal entry of R.
    # `findlast` returns `nothing` when no entry qualifies (e.g. A ≈ 0 or an empty R),
    # in which case the rank is 0.
    p = something(findlast(>=(rank_atol) ∘ abs, diagview(R)), 0)

    Q1 = view(Q, :, 1:p)
    R11 = view(R, 1:p, 1:p)
    ΔA1 = view(ΔA, :, 1:p)
    ΔA2 = view(ΔA, :, (p + 1):n)

    if minmn > p # case where A is rank-deficient
        Δgauge = abs(zero(eltype(Q)))
        if !iszerotangent(ΔQ)
            # In this case the number of Householder reflections will change upon small
            # variations, and all of the remaining columns of ΔQ should be zero for a
            # gauge-invariant cost function.
            ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
            Δgauge = max(Δgauge, norm(ΔQ2, Inf))
        end
        if !iszerotangent(ΔR)
            ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
            Δgauge = max(Δgauge, norm(ΔR22, Inf))
        end
        Δgauge ≤ gauge_atol ||
            @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end

    # Effective cotangent for Q1, accumulating all contributions that flow through Q1.
    ΔQ̃ = zero!(similar(Q, (m, p)))
    if !iszerotangent(ΔQ)
        copy!(ΔQ̃, view(ΔQ, :, 1:p))
        if p < size(Q, 2)
            Q2 = view(Q, :, (p + 1):size(Q, 2))
            ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
            # When A is full rank but Q has more columns than A (the `qr_full` case), the
            # projection of ΔQ2 onto the column space of Q1 carries gauge-invariant
            # information, because Q is unitary. Since the number of Householder
            # reflections is fixed in the full-rank case, Q is expected to rotate smoothly,
            # and one might even predict how the full Q2 changes. That is omitted here:
            # the part of ΔQ2 orthogonal to Q1 (i.e. Q2' * ΔQ2) is treated as gauge
            # dependent.
            Q1dΔQ2 = Q1' * ΔQ2
            Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
            Δgauge ≤ gauge_atol ||
                @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
            # Unitarity of Q: Q1' * dQ2 = -(Q2' * dQ1)', which maps the Q1-component of
            # ΔQ2 onto an extra cotangent for Q1.
            ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
        end
    end
    if !iszerotangent(ΔR) && n > p
        # Trailing columns: A2 = Q1 * R12, so ΔA2 receives Q1 * ΔR12 directly, and the
        # dependence of A2 on Q1 feeds back into the cotangent of Q1.
        R12 = view(R, 1:p, (p + 1):n)
        ΔR12 = view(ΔR, 1:p, (p + 1):n)
        ΔQ̃ = mul!(ΔQ̃, Q1, ΔR12 * R12', -1, 1)
        ΔA2 = mul!(ΔA2, Q1, ΔR12, 1, 1)
    end

    # M = ΔR11 * R11' - Q1' * ΔQ̃, then made Hermitian by mirroring its strict upper
    # triangle (conjugated) into the lower triangle. This assumes `lowertriangularind` and
    # `uppertriangularind` enumerate mirrored positions in matching order (TODO confirm).
    M = zero!(similar(R, (p, p)))
    if !iszerotangent(ΔR)
        ΔR11 = view(ΔR, 1:p, 1:p)
        M = mul!(M, ΔR11, R11', 1, 1)
    end
    M = mul!(M, Q1', ΔQ̃, -1, 1)
    view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
    if eltype(M) <: Complex
        # The diagonal of a Hermitian matrix is real.
        Md = diagview(M)
        Md .= real.(Md)
    end

    # ΔA1 += (Q1 * M + ΔQ̃) * inv(R11)'
    rdiv!(M, UpperTriangular(R11)')
    rdiv!(ΔQ̃, UpperTriangular(R11)')
    ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
    ΔA1 .+= ΔQ̃
    return ΔA
end