-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathlq.jl
More file actions
151 lines (137 loc) · 5.45 KB
/
lq.jl
File metadata and controls
151 lines (137 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# LQ counterpart of `qr_rank`: the factor `L` and every keyword argument are passed on
# unchanged, and whatever `qr_rank` returns is returned.
function lq_rank(L; kwargs...)
    return qr_rank(L; kwargs...)
end
# Warn when the cotangents `ΔL` and `ΔQ` carry weight in blocks beyond rank `p`.
#
# The check only runs when `p` is smaller than `min(size(L, 1), size(Q, 2))`. It then takes
# the largest `norm(⋅, Inf)` over rows `p+1:end` of `ΔQ` and over the trailing block of `ΔL`
# (rows `p+1:end`, columns `p+1:min(size(L, 1), size(Q, 2))`), skipping any cotangent for
# which `iszerotangent` holds, and emits a warning unless that value is `≤ gauge_atol`.
# Always returns `nothing`.
function check_lq_cotangents(
        L, Q, ΔL, ΔQ, p::Int;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
    )
    kmax = min(size(L, 1), size(Q, 2))
    # only the rank-deficient case (p < kmax) has trailing blocks to inspect
    kmax > p || return nothing

    Δgauge = abs(zero(eltype(Q)))
    if !iszerotangent(ΔQ)
        # With rank deficiency the number of Householder reflections changes under
        # small variations, so for a gauge-invariant cost function all of the
        # remaining rows of ΔQ should vanish.
        Δgauge = max(Δgauge, norm(view(ΔQ, (p + 1):size(Q, 1), :), Inf))
    end
    if !iszerotangent(ΔL)
        Δgauge = max(Δgauge, norm(view(ΔL, (p + 1):size(L, 1), (p + 1):kmax), Inf))
    end
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
# Gauge check for the extra rows of `Q` that appear in `lq_full`.
#
# When A has full rank but Q has more rows than A, unitarity of Q makes the part of ΔQ2
# that is captured by `ΔQ2Q1ᴴ * Q1` gauge-invariant (`lq_pullback!` passes
# `ΔQ2Q1ᴴ = ΔQ2 * Q1'`). Since the number of Householder reflections is fixed in the
# full-rank case, Q is expected to rotate smoothly, and it might even be possible to predict
# how the full Q2 changes; this is omitted for now. The remainder `ΔQ2 - ΔQ2Q1ᴴ * Q1` is
# therefore treated as a gauge-dependent quantity, and a warning is emitted unless its
# `norm(⋅, Inf)` is `≤ gauge_atol`. Always returns `nothing`.
function check_lq_full_cotangents(
        Q1, ΔQ2, ΔQ2Q1ᴴ;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)
    )
    # work on a copy so that the caller's ΔQ2 is left untouched
    residual = mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1)
    Δgauge = norm(residual, Inf)
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)
Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and
cotangent `ΔLQ` of `lq_compact(A; positive = true)` or `lq_full(A; positive = true)`.
In the case where the rank `r` of the original matrix `A ≈ L * Q` (as determined by
`rank_atol`) is less then the minimum of the number of rows and columns of the cotangents
`ΔL` and `ΔQ`, only the first `r` columns of `L` and the first `r` rows of `Q` are
well-defined, and also the adjoint variables `ΔL` and `ΔQ` should have nonzero values only
in the first `r` columns and rows respectively. If nonzero values in the remaining columns
or rows exceed `gauge_atol`, a warning will be printed.
"""
function lq_pullback!(
ΔA::AbstractMatrix, A, LQ, ΔLQ;
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
)
# process
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
p = lq_rank(L; rank_atol)
ΔL, ΔQ = ΔLQ
Q1 = view(Q, 1:p, :)
Q2 = view(Q, (p + 1):size(Q, 1), :)
L11 = view(L, 1:p, 1:p)
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)
check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol)
ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
ΔQ1 = view(ΔQ, 1:p, :)
copy!(ΔQ̃, ΔQ1)
if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
ΔQ2Q1ᴴ = ΔQ2 * Q1'
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
if !iszerotangent(ΔL) && m > p
L21 = view(L, (p + 1):m, 1:p)
ΔL21 = view(ΔL, (p + 1):m, 1:p)
ΔQ̃ = mul!(ΔQ̃, L21' * ΔL21, Q1, -1, 1)
# Adding ΔA2 contribution
ΔA2 = mul!(ΔA2, ΔL21, Q1, 1, 1)
end
# construct M
M = zero!(similar(L, (p, p)))
if !iszerotangent(ΔL)
ΔL11 = view(ΔL, 1:p, 1:p)
M = mul!(M, L11', ΔL11, 1, 1)
end
M = mul!(M, ΔQ̃, Q1', -1, 1)
view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
if eltype(M) <: Complex
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(LowerTriangular(L11)', M)
ldiv!(LowerTriangular(L11)', ΔQ̃)
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
end
# Gauge check for the cotangent of `lq_null`: applies `project_antihermitian!` to
# `Nᴴ * ΔNᴴ'` and emits a warning unless the `norm` of the result is `≤ gauge_atol`.
# Always returns `nothing`.
function check_lq_null_cotangents(
        Nᴴ, ΔNᴴ;
        gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
    )
    overlap = Nᴴ * ΔNᴴ'
    Δgauge = norm(project_antihermitian!(overlap))
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return nothing
end
"""
lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
See also [`lq_pullback!`](@ref).
"""
function lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
end
return ΔA
end