Skip to content

Commit 6452b9e

Browse files
committed
Updates for eig and lq/qr diag pullbacks
1 parent 7098f45 commit 6452b9e

7 files changed

Lines changed: 138 additions & 103 deletions

File tree

src/pullbacks/eig.jl

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
function check_eig_cotangents(D, VᴴΔV; degeneracy_atol::Real = default_pullback_rank_atol(D), gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV))
2-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
3-
# not GPU friendly...
4-
Δgauge = norm(view(VᴴΔV, mask))
5-
Δgauge gauge_atol ||
6-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
7-
return
8-
end
9-
101
"""
112
eig_pullback!(
123
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
@@ -49,7 +40,11 @@ function eig_pullback!(
4940
indV = axes(V, 2)[ind]
5041
length(indV) == pV || throw(DimensionMismatch())
5142
mul!(view(VᴴΔV, :, indV), V', ΔV)
52-
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
43+
44+
mask = abs.(transpose(D) .- D) .< degeneracy_atol
45+
Δgauge = norm(view(VᴴΔV, mask), Inf)
46+
Δgauge gauge_atol ||
47+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5348

5449
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5550

@@ -83,16 +78,6 @@ function eig_pullback!(
8378
end
8479
return ΔA
8580
end
86-
function eig_pullback!(
87-
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
88-
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
89-
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
90-
)
91-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
92-
ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
93-
diagview(ΔA) .+= diagview(ΔA_full)
94-
return ΔA
95-
end
9681

9782
"""
9883
eig_trunc_pullback!(
@@ -134,7 +119,10 @@ function eig_trunc_pullback!(
134119
if !iszerotangent(ΔV)
135120
(n, p) == size(ΔV) || throw(DimensionMismatch())
136121
VᴴΔV = V' * ΔV
137-
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
122+
mask = abs.(transpose(D) .- D) .< degeneracy_atol
123+
Δgauge = norm(view(VᴴΔV, mask), Inf)
124+
Δgauge gauge_atol ||
125+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
138126

139127
ΔVperp = ΔV - V * inv(G) * VᴴΔV
140128
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
@@ -163,16 +151,6 @@ function eig_trunc_pullback!(
163151
end
164152
return ΔA
165153
end
166-
function eig_trunc_pullback!(
167-
ΔA::Diagonal, A, DV, ΔDV;
168-
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
169-
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
170-
)
171-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
172-
ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
173-
diagview(ΔA) .+= diagview(ΔA_full)
174-
return ΔA
175-
end
176154

177155
"""
178156
eig_vals_pullback!(
@@ -193,6 +171,29 @@ function eig_vals_pullback!(
193171
ΔA, A, DV, ΔD, ind = Colon();
194172
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
195173
)
174+
196175
ΔDV = (diagonal(ΔD), nothing)
197176
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
198177
end
178+
179+
function eig_pullback!(
180+
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
181+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
182+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
183+
)
184+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
185+
eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
186+
diagview(ΔA) .+= diagview(ΔA_full)
187+
return ΔA
188+
end
189+
190+
function eig_trunc_pullback!(
191+
ΔA::Diagonal, A, DV, ΔDV;
192+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
193+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
194+
)
195+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
196+
eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
197+
diagview(ΔA) .+= diagview(ΔA_full)
198+
return ΔA
199+
end

src/pullbacks/lq.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ function lq_pullback!(
124124
ΔA1 .+= ΔQ̃
125125
return ΔA
126126
end
127+
function lq_pullback!(
128+
ΔA::Diagonal, A, LQ, ΔLQ;
129+
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
130+
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
131+
)
132+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
133+
ΔA_full = lq_pullback!(ΔA_full, A, LQ, ΔLQ; rank_atol, gauge_atol)
134+
diagview(ΔA) .+= diagview(ΔA_full)
135+
return ΔA
136+
end
127137

128138
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
129139
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')

src/pullbacks/polar.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
2727
end
2828
return ΔA
2929
end
30+
function left_polar_pullback!(ΔA::Diagonal, A, WP, ΔWP; kwargs...)
31+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
32+
ΔA_full = left_polar_pullback!(ΔA_full, A, WP, ΔWP; kwargs...)
33+
diagview(ΔA) .+= diagview(ΔA_full)
34+
return ΔA
35+
end
3036

3137
"""
3238
right_polar_pullback!(ΔA, A, PWᴴ, ΔPWᴴ)
@@ -57,3 +63,9 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
5763
end
5864
return ΔA
5965
end
66+
function right_polar_pullback!(ΔA::Diagonal, A, PWᴴ, ΔPWᴴ; kwargs...)
67+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
68+
ΔA_full = right_polar_pullback!(ΔA_full, A, PWᴴ, ΔPWᴴ; kwargs...)
69+
diagview(ΔA) .+= diagview(ΔA_full)
70+
return ΔA
71+
end

src/pullbacks/qr.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
1010
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
1111
end
1212
if !iszerotangent(ΔR)
13-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
13+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
1414
Δgauge = max(Δgauge, norm(ΔR22, Inf))
1515
end
1616
Δgauge gauge_atol ||
@@ -19,7 +19,7 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
1919
return
2020
end
2121

22-
function check_qr_full_cotangents(Q1, ΔQ2, ΔR, Q1dΔQ2, ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
22+
function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
2323
# in the case where A is full rank, but there are more columns in Q than in A
2424
# (the case of `qr_full`), there is gauge-invariant information in the
2525
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
@@ -113,6 +113,17 @@ function qr_pullback!(
113113
ΔA1 .+= ΔQ̃
114114
return ΔA
115115
end
116+
function qr_pullback!(
117+
ΔA::Diagonal, A, QR, ΔQR;
118+
rank_atol::Real = default_pullback_rank_atol(QR[2]),
119+
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
120+
)
121+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
122+
ΔA_full = qr_pullback!(ΔA_full, A, QR, ΔQR; rank_atol, gauge_atol)
123+
@assert isdiag(ΔA_full)
124+
diagview(ΔA) .+= diagview(ΔA_full)
125+
return ΔA
126+
end
116127

117128
function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN))
118129
aNᴴΔN = project_antihermitian!(N' * ΔN)

test/mooncake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1616
TestSuite.seed_rng!(123)
1717
if CUDA.functional()
1818
TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
19-
#n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
19+
n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
2020
end
2121
#=if AMDGPU.functional()
2222
TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))

test/testsuite/ad_utils.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ function make_eig_matrix(T, sz)
5050
D, V = eig_full(A)
5151
stabilize_eigvals!(diagview(D))
5252
Ac = V * D * inv(V)
53-
return (eltype(T) <: Real) ? real(Ac) : Ac
53+
Af = (eltype(T) <: Real) ? real(Ac) : Ac
54+
if T <: Diagonal
55+
copyto!(diagview(A), diagview(Af))
56+
else
57+
copyto!(A, Af)
58+
end
59+
return A
5460
end
5561
function make_eigh_matrix(T, sz)
5662
A = project_hermitian!(instantiate_matrix(T, sz))
@@ -116,7 +122,7 @@ function ad_qr_rank_deficient_compact_setup(A)
116122
Q1 = view(Q, 1:m, 1:r)
117123
Q2 = view(Q, 1:m, (r + 1):minmn)
118124
ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn)
119-
zero!(ΔQ2)
125+
MatrixAlgebraKit.zero!(ΔQ2)
120126
ΔR = randn!(similar(A, T, minmn, n))
121127
view(ΔR, (r + 1):minmn, :) .= 0
122128
return (Q, R), (ΔQ, ΔR)
@@ -128,13 +134,13 @@ function ad_qr_rank_deficient_compact_setup(A::Diagonal)
128134
T = eltype(A)
129135
r = minmn - 5
130136
Ard_ = randn!(similar(A, T, m))
131-
zero!(view(Ard_, (r + 1):m))
137+
MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m))
132138
Ard = Diagonal(Ard_)
133139
Q, R = qr_compact(Ard)
134140
ΔQ = Diagonal(randn!(similar(diagview(A), T, m)))
135141
ΔR = Diagonal(randn!(similar(diagview(A), T, m)))
136-
zero!(view(diagview(ΔQ), (r + 1):m))
137-
zero!(view(diagview(ΔR), (r + 1):m))
142+
MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m))
143+
MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m))
138144
return (Q, R), (ΔQ, ΔR)
139145
end
140146

@@ -206,10 +212,10 @@ end
206212

207213
function ad_eig_full_setup(A::Diagonal)
208214
m, n = size(A)
209-
T = eltype(A)
215+
T = complex(eltype(A))
210216
DV = eig_full(A)
211217
D, V = DV
212-
ΔV = Diagonal(randn!(similar(A.diag, T, m)))
218+
ΔV = randn!(similar(A.diag, T, m, m))
213219
ΔV = remove_eiggauge_dependence!(ΔV, D, V)
214220
ΔD = Diagonal(randn!(similar(A.diag, T, m)))
215221
ΔD2 = Diagonal(randn!(similar(A.diag, T, m)))
@@ -218,17 +224,17 @@ end
218224

219225
function ad_eig_vals_setup(A)
220226
m, n = size(A)
221-
T = eltype(A)
227+
T = complex(eltype(A))
222228
D = eig_vals(A)
223229
ΔD = randn!(similar(A, complex(T), m))
224230
return D, ΔD
225231
end
226232

227233
function ad_eig_vals_setup(A::Diagonal)
228234
m, n = size(A)
229-
T = eltype(A)
235+
T = complex(eltype(A))
230236
D = eig_vals(A)
231-
ΔD = randn!(similar(A.diag))
237+
ΔD = randn!(similar(A.diag, T, m))
232238
return D, ΔD
233239
end
234240

0 commit comments

Comments
 (0)