Skip to content

Commit 2348b5e

Browse files
committed
Fold in enzyme tests
1 parent e407a2d commit 2348b5e

11 files changed

Lines changed: 519 additions & 566 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
@@ -176,4 +176,6 @@ function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
176176
return ROCArray(hX)
177177
end
178178

179+
svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
180+
179181
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, defaul
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
@@ -209,4 +209,6 @@ function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
209209
return CuArray(hX)
210210
end
211211

212+
svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
213+
212214
end

src/pullbacks/eig.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ function eig_pullback!(
7878
end
7979
return ΔA
8080
end
81+
function eig_pullback!(
82+
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
83+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
84+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
85+
)
86+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
87+
ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
88+
diagview(ΔA) .+= diagview(ΔA_full)
89+
return ΔA
90+
end
8191

8292
"""
8393
eig_trunc_pullback!(
@@ -151,6 +161,16 @@ function eig_trunc_pullback!(
151161
end
152162
return ΔA
153163
end
164+
function eig_trunc_pullback!(
165+
ΔA::Diagonal, A, DV, ΔDV;
166+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
167+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
168+
)
169+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
170+
ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
171+
diagview(ΔA) .+= diagview(ΔA_full)
172+
return ΔA
173+
end
154174

155175
"""
156176
eig_vals_pullback!(
@@ -175,25 +195,3 @@ function eig_vals_pullback!(
175195
ΔDV = (diagonal(ΔD), nothing)
176196
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
177197
end
178-
179-
function eig_pullback!(
180-
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
181-
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
182-
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
183-
)
184-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
185-
eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
186-
diagview(ΔA) .+= diagview(ΔA_full)
187-
return ΔA
188-
end
189-
190-
function eig_trunc_pullback!(
191-
ΔA::Diagonal, A, DV, ΔDV;
192-
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
193-
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
194-
)
195-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
196-
eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
197-
diagview(ΔA) .+= diagview(ΔA_full)
198-
return ΔA
199-
end

src/pullbacks/eigh.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ function eigh_pullback!(
5353
length(indV) == pV || throw(DimensionMismatch())
5454
mul!(view(VᴴΔV, :, indV), V', ΔV)
5555
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work
56+
5657
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
58+
5759
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
5860

5961
if !iszerotangent(ΔDmat)

src/pullbacks/lq.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = defaul
3636
return
3737
end
3838

39-
4039
"""
4140
lq_pullback!(
4241
ΔA, A, LQ, ΔLQ;
@@ -80,17 +79,10 @@ function lq_pullback!(
8079
ΔQ̃ = zero!(similar(Q, (p, n)))
8180
if !iszerotangent(ΔQ)
8281
ΔQ1 = view(ΔQ, 1:p, :)
83-
ΔQ̃ .= ΔQ1
82+
copy!(ΔQ̃, ΔQ1)
8483
if p < size(Q, 1)
8584
Q2 = view(Q, (p + 1):size(Q, 1), :)
8685
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
87-
# in the case where A is full rank, but there are more columns in Q than in A
88-
# (the case of `qr_full`), there is gauge-invariant information in the
89-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
90-
# matrix. As the number of Householder reflections is fixed in the full rank
91-
# case, Q is expected to rotate smoothly (we might even be able to predict
92-
# how the full Q2 will change), but this we omit for now, and we consider
93-
# Q2' * ΔQ2 as a gauge dependent quantity.
9486
ΔQ2Q1ᴴ = ΔQ2 * Q1'
9587
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
9688
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
@@ -116,24 +108,12 @@ function lq_pullback!(
116108
Md = diagview(M)
117109
Md .= real.(Md)
118110
end
119-
# not GPU friendly...
120-
L11arr = typeof(L)(L11)
121-
ldiv!(LowerTriangular(L11arr)', M)
122-
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
111+
ldiv!(LowerTriangular(L11)', M)
112+
ldiv!(LowerTriangular(L11)', ΔQ̃)
123113
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
124114
ΔA1 .+= ΔQ̃
125115
return ΔA
126116
end
127-
function lq_pullback!(
128-
ΔA::Diagonal, A, LQ, ΔLQ;
129-
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
130-
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
131-
)
132-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
133-
ΔA_full = lq_pullback!(ΔA_full, A, LQ, ΔLQ; rank_atol, gauge_atol)
134-
diagview(ΔA) .+= diagview(ΔA_full)
135-
return ΔA
136-
end
137117

138118
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
139119
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')

src/pullbacks/polar.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,6 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
2727
end
2828
return ΔA
2929
end
30-
function left_polar_pullback!(ΔA::Diagonal, A, WP, ΔWP; kwargs...)
31-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
32-
ΔA_full = left_polar_pullback!(ΔA_full, A, WP, ΔWP; kwargs...)
33-
diagview(ΔA) .+= diagview(ΔA_full)
34-
return ΔA
35-
end
3630

3731
"""
3832
right_polar_pullback!(ΔA, A, PWᴴ, ΔPWᴴ)
@@ -52,7 +46,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
5246
M = zero(P)
5347
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
5448
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
55-
C = _sylvester(P, P, M' - M)
49+
C = sylvester(P, P, M' - M)
5650
C .+= ΔP
5751
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
5852
if !iszerotangent(ΔWᴴ)
@@ -63,9 +57,3 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
6357
end
6458
return ΔA
6559
end
66-
function right_polar_pullback!(ΔA::Diagonal, A, PWᴴ, ΔPWᴴ; kwargs...)
67-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
68-
ΔA_full = right_polar_pullback!(ΔA_full, A, PWᴴ, ΔPWᴴ; kwargs...)
69-
diagview(ΔA) .+= diagview(ΔA_full)
70-
return ΔA
71-
end

src/pullbacks/qr.jl

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function qr_pullback!(
7676

7777
ΔQ̃ = zero!(similar(Q, (m, p)))
7878
if !iszerotangent(ΔQ)
79-
ΔQ̃ .= view(ΔQ, :, 1:p)
79+
copy!(ΔQ̃, view(ΔQ, :, 1:p))
8080
if p < size(Q, 2)
8181
Q2 = view(Q, :, (p + 1):size(Q, 2))
8282
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
@@ -105,25 +105,12 @@ function qr_pullback!(
105105
Md = diagview(M)
106106
Md .= real.(Md)
107107
end
108-
# not GPU-friendly...
109-
R11arr = typeof(R)(R11)
110-
rdiv!(M, UpperTriangular(R11arr)')
111-
rdiv!(ΔQ̃, UpperTriangular(R11arr)')
108+
rdiv!(M, UpperTriangular(R11)')
109+
rdiv!(ΔQ̃, UpperTriangular(R11)')
112110
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
113111
ΔA1 .+= ΔQ̃
114112
return ΔA
115113
end
116-
function qr_pullback!(
117-
ΔA::Diagonal, A, QR, ΔQR;
118-
rank_atol::Real = default_pullback_rank_atol(QR[2]),
119-
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
120-
)
121-
ΔA_full = zero!(similar(ΔA, size(ΔA)))
122-
ΔA_full = qr_pullback!(ΔA_full, A, QR, ΔQR; rank_atol, gauge_atol)
123-
@assert isdiag(ΔA_full)
124-
diagview(ΔA) .+= diagview(ΔA_full)
125-
return ΔA
126-
end
127114

128115
function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN))
129116
aNᴴΔN = project_antihermitian!(N' * ΔN)

src/pullbacks/svd.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true)
2+
13
function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV))
24
mask = abs.(Sr' .- Sr) .< degeneracy_atol
35
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
@@ -41,7 +43,7 @@ function svd_pullback!(
4143
minmn = min(m, n)
4244
S = diagview(Smat)
4345
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
44-
r = findlast(s -> s ≥ rank_atol, S) # rank
46+
r = svd_rank(S, rank_atol)
4547
Ur = view(U, :, 1:r)
4648
Vᴴr = view(Vᴴ, 1:r, :)
4749
Sr = view(S, 1:r)

0 commit comments

Comments
 (0)