Skip to content

Commit d9c4a3f

Browse files
committed
some more changes / lq tests are failing for unknown reasons
1 parent 1abbd3b commit d9c4a3f

4 files changed

Lines changed: 86 additions & 111 deletions

File tree

src/common/view.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ diagonal(v::AbstractVector) = Diagonal(v)
2424
function lowertriangularind(A::AbstractMatrix)
2525
Base.require_one_based_indexing(A)
2626
m, n = size(A)
27-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
27+
minmn = min(m, n)
28+
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (m - minmn))
2829
offset = 0
2930
for j in 1:n
3031
r = (j + 1):m
@@ -37,7 +38,8 @@ end
3738
function uppertriangularind(A::AbstractMatrix)
3839
Base.require_one_based_indexing(A)
3940
m, n = size(A)
40-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
41+
minmn = min(m, n)
42+
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (n - minmn))
4143
offset = 0
4244
for i in 1:m
4345
r = (i + 1):n

src/pullbacks/lq.jl

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,25 @@ function check_lq_cotangents(
55
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
66
)
77
minmn = min(size(L, 1), size(Q, 2))
8-
if minmn > p # case where A is rank-deficient
9-
Δgauge = abs(zero(eltype(Q)))
10-
if !iszerotangent(ΔQ)
11-
# in this case the number of Householder reflections will
12-
# change upon small variations, and all of the remaining
13-
# rows of ΔQ should be zero for a gauge-invariant
14-
# cost function
15-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
16-
Δgauge_Q = norm(ΔQ2, Inf)
17-
Δgauge = max(Δgauge, Δgauge_Q)
18-
end
19-
if !iszerotangent(ΔL)
20-
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
21-
Δgauge_L = norm(ΔL22, Inf)
22-
Δgauge = max(Δgauge, Δgauge_L)
23-
end
24-
Δgauge ≤ gauge_atol ||
25-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
8+
Δgauge = abs(zero(eltype(Q)))
9+
if !iszerotangent(ΔQ)
10+
ΔQ₂ = view(ΔQ, (p + 1):minmn, :)
11+
ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :]
12+
Δgauge_Q = norm(ΔQ₂, Inf)
13+
Q₁ = view(Q, 1:p, :)
14+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
15+
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
16+
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
17+
Δgauge = max(Δgauge, Δgauge_Q)
18+
end
19+
if !iszerotangent(ΔL)
20+
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
21+
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
22+
Δgauge = max(Δgauge, Δgauge_L)
2623
end
27-
return
28-
end
29-
30-
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
31-
# in the case where A is full rank, but there are more columns in Q than in A
32-
# (the case of `lq_full`), there is gauge-invariant information in the
33-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
34-
# matrix. As the number of Householder reflections is fixed in the full rank
35-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
36-
# how the full Q2 will change, but this we omit for now, and we consider
37-
# Q2' * ΔQ2 as a gauge dependent quantity.
38-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
3924
Δgauge ≤ gauge_atol ||
40-
@warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
41-
return
25+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
26+
return nothing
4227
end
4328

4429
"""
@@ -67,13 +52,13 @@ function lq_pullback!(
6752
L, Q = LQ
6853
m = size(L, 1)
6954
n = size(Q, 2)
55+
minmn = min(m, n)
7056
p = lq_rank(L; rank_atol)
7157

7258
ΔL, ΔQ = ΔLQ
7359

7460
Q1 = view(Q, 1:p, :)
75-
Q2 = view(Q, (p + 1):size(Q, 1), :)
76-
L11 = view(L, 1:p, 1:p)
61+
L11 = LowerTriangular(view(L, 1:p, 1:p))
7762
ΔA1 = view(ΔA, 1:p, :)
7863
ΔA2 = view(ΔA, (p + 1):m, :)
7964

@@ -83,12 +68,11 @@ function lq_pullback!(
8368
if !iszerotangent(ΔQ)
8469
ΔQ1 = view(ΔQ, 1:p, :)
8570
copy!(ΔQ̃, ΔQ1)
86-
if p < size(Q, 1)
87-
Q2 = view(Q, (p + 1):size(Q, 1), :)
88-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
89-
ΔQ2Q1ᴴ = ΔQ2 * Q1'
90-
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
91-
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
71+
if minmn < size(Q, 1)
72+
ΔQ3 = view(ΔQ, (minmn + 1):size(ΔQ, 1), :)
73+
Q3 = view(Q, (minmn + 1):size(Q, 1), :)
74+
ΔQ3Q1ᴴ = ΔQ3 * Q1'
75+
ΔQ̃ = mul!(ΔQ̃, ΔQ3Q1ᴴ', Q3, -1, 1)
9276
end
9377
end
9478
if !iszerotangent(ΔL) && m > p
@@ -102,7 +86,7 @@ function lq_pullback!(
10286
# construct M
10387
M = zero!(similar(L, (p, p)))
10488
if !iszerotangent(ΔL)
105-
ΔL11 = view(ΔL, 1:p, 1:p)
89+
ΔL11 = LowerTriangular(view(ΔL, 1:p, 1:p))
10690
M = mul!(M, L11', ΔL11, 1, 1)
10791
end
10892
M = mul!(M, ΔQ̃, Q1', -1, 1)
@@ -111,8 +95,8 @@ function lq_pullback!(
11195
Md = diagview(M)
11296
Md .= real.(Md)
11397
end
114-
ldiv!(LowerTriangular(L11)', M)
115-
ldiv!(LowerTriangular(L11)', ΔQ̃)
98+
ldiv!(L11', M)
99+
ldiv!(L11', ΔQ̃)
116100
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
117101
ΔA1 .+= ΔQ̃
118102
return ΔA

src/pullbacks/qr.jl

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,25 @@ function check_qr_cotangents(
66
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
77
)
88
minmn = min(size(Q, 1), size(R, 2))
9-
if minmn > p # case where A is rank-deficient
10-
Δgauge = abs(zero(eltype(Q)))
11-
if !iszerotangent(ΔQ)
12-
# in this case the number of Householder reflections will
13-
# change upon small variations, and all of the remaining
14-
# columns of ΔQ should be zero for a gauge-invariant
15-
# cost function
16-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
17-
Δgauge_Q = norm(ΔQ2, Inf)
18-
Δgauge = max(Δgauge, Δgauge_Q)
19-
end
20-
if !iszerotangent(ΔR)
21-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
22-
Δgauge_R = norm(ΔR22, Inf)
23-
Δgauge = max(Δgauge, Δgauge_R)
24-
end
25-
Δgauge ≤ gauge_atol ||
26-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
9+
Δgauge = abs(zero(eltype(Q)))
10+
if !iszerotangent(ΔQ)
11+
ΔQ₂ = view(ΔQ, :, (p + 1):minmn)
12+
ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full
13+
Δgauge_Q = norm(ΔQ₂, Inf)
14+
Q₁ = view(Q, :, 1:p)
15+
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
16+
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1)
17+
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
18+
Δgauge = max(Δgauge, Δgauge_Q)
19+
end
20+
if !iszerotangent(ΔR)
21+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
22+
Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf)
23+
Δgauge = max(Δgauge, Δgauge_R)
2724
end
28-
return
29-
end
30-
31-
function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
32-
# in the case where A is full rank, but there are more columns in Q than in A
33-
# (the case of `qr_full`), there is gauge-invariant information in the
34-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
35-
# matrix. As the number of Householder reflections is fixed in the full rank
36-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
37-
# how the full Q2 will change, but this we omit for now, and we consider
38-
# Q2' * ΔQ2 as a gauge dependent quantity.
39-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
4025
Δgauge ≤ gauge_atol ||
41-
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
42-
return
26+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
27+
return nothing
4328
end
4429

4530
"""
@@ -69,27 +54,28 @@ function qr_pullback!(
6954
Q, R = QR
7055
m = size(Q, 1)
7156
n = size(R, 2)
57+
minmn = min(m, n)
7258
Rd = diagview(R)
7359
p = qr_rank(R; rank_atol)
7460

7561
ΔQ, ΔR = ΔQR
7662

7763
Q1 = view(Q, :, 1:p)
78-
R11 = view(R, 1:p, 1:p)
64+
R11 = UpperTriangular(view(R, 1:p, 1:p))
7965
ΔA1 = view(ΔA, :, 1:p)
8066
ΔA2 = view(ΔA, :, (p + 1):n)
8167

8268
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)
8369

8470
ΔQ̃ = zero!(similar(Q, (m, p)))
8571
if !iszerotangent(ΔQ)
86-
copy!(ΔQ̃, view(ΔQ, :, 1:p))
87-
if p < size(Q, 2)
88-
Q2 = view(Q, :, (p + 1):size(Q, 2))
89-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
90-
Q1dΔQ2 = Q1' * ΔQ2
91-
check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol)
92-
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
72+
ΔQ₁ = view(ΔQ, :, 1:p)
73+
copy!(ΔQ̃, ΔQ₁)
74+
if minmn < size(Q, 2)
75+
ΔQ3 = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full
76+
Q3 = view(Q, :, (minmn + 1):size(Q, 2))
77+
Q1ᴴΔQ3 = Q1' * ΔQ3
78+
ΔQ̃ = mul!(ΔQ̃, Q3, Q1ᴴΔQ3', -1, 1)
9379
end
9480
end
9581
if !iszerotangent(ΔR) && n > p
@@ -103,7 +89,7 @@ function qr_pullback!(
10389
# construct M
10490
M = zero!(similar(R, (p, p)))
10591
if !iszerotangent(ΔR)
106-
ΔR11 = view(ΔR, 1:p, 1:p)
92+
ΔR11 = UpperTriangular(view(ΔR, 1:p, 1:p))
10793
M = mul!(M, ΔR11, R11', 1, 1)
10894
end
10995
M = mul!(M, Q1', ΔQ̃, -1, 1)
@@ -112,8 +98,8 @@ function qr_pullback!(
11298
Md = diagview(M)
11399
Md .= real.(Md)
114100
end
115-
rdiv!(M, UpperTriangular(R11)')
116-
rdiv!(ΔQ̃, UpperTriangular(R11)')
101+
rdiv!(M, R11') # R11 is upper triangular
102+
rdiv!(ΔQ̃, R11')
117103
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
118104
ΔA1 .+= ΔQ̃
119105
return ΔA

test/testsuite/ad_utils.jl

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,15 @@ ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out.
7878
"""
7979
function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R))
8080
r = MatrixAlgebraKit.qr_rank(R; rank_atol)
81-
Q₁ = @view Q[:, 1:r]
82-
ΔQ₂ = @view ΔQ[:, (r + 1):end]
81+
minmn = min(size(A)...)
82+
Q₁ = view(Q, :, 1:r)
83+
ΔQ₂ = view(ΔQ, :, (r + 1):minmn)
8384
ΔQ₂ .= 0
84-
# TODO: refine this by differentiating between rank deficiency and qr_full cases
85-
# Q₁ᴴΔQ₂ = Q₁' * ΔQ₂
86-
# mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂)
87-
view(ΔR, (r + 1):size(ΔR, 1), :) .= 0
85+
ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full
86+
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
87+
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃)
88+
ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
89+
view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0
8890
return ΔQ, ΔR
8991
end
9092

@@ -110,13 +112,15 @@ Additionally, columns of `ΔL` beyond the rank are zeroed out.
110112
"""
111113
function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L))
112114
r = MatrixAlgebraKit.lq_rank(L; rank_atol)
113-
Q₁ = @view Q[1:r, :]
114-
ΔQ₂ = @view ΔQ[(r + 1):end, :]
115+
minmn = min(size(A)...)
116+
Q₁ = view(Q, 1:r, :)
117+
ΔQ₂ = view(ΔQ, (r + 1):minmn, :)
115118
ΔQ₂ .= 0
116-
# TODO: refine this by differentiating between rank deficiency and lq_full cases
117-
# ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁'
118-
# mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁)
119-
view(ΔL, :, (r + 1):size(ΔL, 2)) .= 0
119+
ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full
120+
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
121+
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁)
122+
ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn)
123+
view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0
120124
return ΔL, ΔQ
121125
end
122126

@@ -220,22 +224,22 @@ end
220224

221225
function ad_qr_compact_setup(A)
222226
QR = qr_compact(A)
223-
ΔQR = structured_randn!.(copy.(QR))
224-
A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...)
227+
ΔQR = structured_randn!.(similar.(QR))
228+
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
225229
return QR, ΔQR
226230
end
227231

228232
function ad_qr_null_setup(A)
229233
N = qr_null(A)
230-
ΔN = randn!(copy(N))
234+
ΔN = structured_randn!(similar(N))
231235
remove_qr_null_gauge_dependence!(ΔN, A, N)
232236
return N, ΔN
233237
end
234238

235239
function ad_qr_full_setup(A)
236240
QR = qr_full(A)
237-
ΔQR = structured_randn!.(copy.(QR))
238-
A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...)
241+
ΔQR = structured_randn!.(similar.(QR))
242+
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
239243
return QR, ΔQR
240244
end
241245

@@ -275,23 +279,22 @@ end
275279

276280
function ad_lq_compact_setup(A)
277281
LQ = lq_compact(A)
278-
ΔLQ = structured_randn!.(copy.(LQ))
279-
A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
282+
ΔLQ = structured_randn!.(similar.(LQ))
283+
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
280284
return LQ, ΔLQ
281285
end
282286

283287
function ad_lq_null_setup(A)
284-
T = eltype(A)
285288
Nᴴ = lq_null(A)
286-
ΔNᴴ = randn!(similar(A, T, size(Nᴴ)...))
289+
ΔNᴴ = structured_randn!(similar(Nᴴ))
287290
remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ)
288291
return Nᴴ, ΔNᴴ
289292
end
290293

291294
function ad_lq_full_setup(A)
292295
LQ = lq_full(A)
293-
ΔLQ = structured_randn!.(copy.(LQ))
294-
A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
296+
ΔLQ = structured_randn!.(similar.(LQ))
297+
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
295298
return LQ, ΔLQ
296299
end
297300

0 commit comments

Comments
 (0)