Skip to content

Commit 8a63647

Browse files
committed
Use Testsuite for AD tests
1 parent 579d792 commit 8a63647

22 files changed

Lines changed: 2231 additions & 1823 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Enzyme = "0.13.118"
3636
EnzymeTestUtils = "0.2.5"
3737
JET = "0.9, 0.10"
3838
LinearAlgebra = "1"
39-
Mooncake = "0.4.183"
39+
Mooncake = "0.4.195"
4040
ParallelTestRunner = "2"
4141
Random = "1"
4242
SafeTestsets = "0.1"

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
@@ -171,4 +171,11 @@ end
171171
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
172172
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
173173

174+
function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
175+
hX = sylvester(collect(A), collect(B), collect(C))
176+
return ROCArray(hX)
177+
end
178+
179+
svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
180+
174181
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module MatrixAlgebraKitCUDAExt
33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
6-
using MatrixAlgebraKit: diagview, sign_safe
6+
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
10+
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
1313
using LinearAlgebra
@@ -195,4 +195,20 @@ end
195195
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
196196
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
197197

198+
MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
199+
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
200+
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
201+
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
202+
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
203+
end
204+
205+
function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
206+
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
207+
# to add native sylvester to CUDA
208+
hX = sylvester(collect(A), collect(B), collect(C))
209+
return CuArray(hX)
210+
end
211+
212+
svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
213+
198214
end

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
9595
eig_t! = Symbol(eig, "_trunc!")
9696
eig_t_pb = Symbol(eig, "_trunc_pullback")
9797
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
98+
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
99+
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
100+
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
98101
eig_v = Symbol(eig, "_vals")
99102
eig_v! = Symbol(eig_v, "!")
100103
eig_v_pb = Symbol(eig_v, "_pullback")
@@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
136139
end
137140
return $eig_t_pb
138141
end
142+
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
143+
Ac = copy_input($eig_f, A)
144+
DV = $(eig_f!)(Ac, DV, alg.alg)
145+
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
146+
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
147+
end
148+
function $(_make_eig_t_ne_pb)(A, DV, ind)
149+
function $eig_t_ne_pb(ΔDV)
150+
ΔA = zero(A)
151+
ΔD, ΔV = ΔDV
152+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
153+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
154+
end
155+
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
156+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
157+
end
158+
return $eig_t_ne_pb
159+
end
139160
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
140161
DV = $eig_f(A, alg)
141162
function $eig_v_pb(ΔD)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
1818
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
1919
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2020
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
21-
dAc = Mooncake.zero_tangent(Ac)
21+
Ac_dAc = Mooncake.zero_fcodual(Ac)
22+
dAc = Mooncake.tangent(Ac_dAc)
2223
function copy_input_pb(::NoRData)
2324
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
2425
return NoRData(), NoRData(), NoRData()
2526
end
26-
return CoDual(Ac, dAc), copy_input_pb
27+
return Ac_dAc, copy_input_pb
2728
end
2829

30+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
2931
# two-argument in-place factorizations like LQ, QR, EIG
3032
for (f!, f, pb, adj) in (
3133
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

src/common/defaults.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
3434
Default tolerance for deciding what values should be considered equal to 0.
3535
"""
3636
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
37+
default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A))
3738

3839
"""
3940
default_hermitian_tol(A)

src/common/pullbacks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ function iszerotangent end
1010

1111
iszerotangent(::Any) = false
1212
iszerotangent(::Nothing) = true
13+
14+
# fallback
15+
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)

src/pullbacks/eig.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
function check_eig_cotangents(
2+
D, VᴴΔV;
3+
degeneracy_atol::Real = default_pullback_rank_atol(D),
4+
gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)
5+
)
6+
mask = abs.(transpose(D) .- D) .< degeneracy_atol
7+
Δgauge = norm(view(VᴴΔV, mask))
8+
Δgauge ≤ gauge_atol ||
9+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
10+
return
11+
end
12+
113
"""
214
eig_pullback!(
315
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
@@ -41,10 +53,7 @@ function eig_pullback!(
4153
length(indV) == pV || throw(DimensionMismatch())
4254
mul!(view(VᴴΔV, :, indV), V', ΔV)
4355

44-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
45-
Δgauge = norm(view(VᴴΔV, mask), Inf)
46-
Δgauge ≤ gauge_atol ||
47-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
56+
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
4857

4958
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5059

@@ -129,10 +138,7 @@ function eig_trunc_pullback!(
129138
if !iszerotangent(ΔV)
130139
(n, p) == size(ΔV) || throw(DimensionMismatch())
131140
VᴴΔV = V' * ΔV
132-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
133-
Δgauge = norm(view(VᴴΔV, mask), Inf)
134-
Δgauge gauge_atol ||
135-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
141+
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
136142

137143
ΔVperp = ΔV - V * inv(G) * VᴴΔV
138144
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
@@ -150,7 +156,7 @@ function eig_trunc_pullback!(
150156
# add contribution from orthogonal complement
151157
PA = A - (A * V) / V
152158
Y = mul!(ΔVperp, PA', Z, 1, 1)
153-
X = sylvester(PA', -Dmat', Y)
159+
X = _sylvester(PA', -Dmat', Y)
154160
Z .+= X
155161

156162
if eltype(ΔA) <: Real

src/pullbacks/eigh.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
function check_eigh_cotangents(
2+
D, aVᴴΔV;
3+
degeneracy_atol::Real = default_pullback_rank_atol(D),
4+
gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV)
5+
)
6+
mask = abs.(D' .- D) .< degeneracy_atol
7+
Δgauge = norm(view(aVᴴΔV, mask))
8+
Δgauge ≤ gauge_atol ||
9+
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
10+
return
11+
end
12+
113
"""
214
eigh_pullback!(
315
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
@@ -42,10 +54,7 @@ function eigh_pullback!(
4254
mul!(view(VᴴΔV, :, indV), V', ΔV)
4355
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work
4456

45-
mask = abs.(D' .- D) .< degeneracy_atol
46-
Δgauge = norm(view(aVᴴΔV, mask))
47-
Δgauge ≤ gauge_atol ||
48-
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
57+
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
4958

5059
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
5160

@@ -120,10 +129,7 @@ function eigh_trunc_pullback!(
120129
VᴴΔV = V' * ΔV
121130
aVᴴΔV = project_antihermitian!(VᴴΔV)
122131

123-
mask = abs.(D' .- D) .< degeneracy_atol
124-
Δgauge = norm(view(aVᴴΔV, mask))
125-
Δgauge gauge_atol ||
126-
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
132+
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
127133

128134
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
129135

@@ -138,7 +144,7 @@ function eigh_trunc_pullback!(
138144
# add contribution from orthogonal complement
139145
W = qr_null(V)
140146
WᴴΔV = W' * ΔV
141-
X = sylvester(W' * A * W, -Dmat, WᴴΔV)
147+
X = _sylvester(W' * A * W, -Dmat, WᴴΔV)
142148
Z = mul!(Z, W, X, 1, 1)
143149

144150
# put everything together: symmetrize for hermitian case

src/pullbacks/lq.jl

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,41 @@
1+
function check_lq_cotangents(
2+
L, Q, ΔL, ΔQ, minmn::Int, p::Int;
3+
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
4+
)
5+
if minmn > p # case where A is rank-deficient
6+
Δgauge = abs(zero(eltype(Q)))
7+
if !iszerotangent(ΔQ)
8+
# in this case the number of Householder reflections will
9+
# change upon small variations, and all of the remaining
10+
# columns of ΔQ should be zero for a gauge-invariant
11+
# cost function
12+
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
13+
Δgauge = max(Δgauge, norm(ΔQ2))
14+
end
15+
if !iszerotangent(ΔL)
16+
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
17+
Δgauge = max(Δgauge, norm(ΔL22))
18+
end
19+
Δgauge ≤ gauge_atol ||
20+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
21+
end
22+
return
23+
end
24+
25+
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
26+
# in the case where A is full rank, but there are more columns in Q than in A
27+
# (the case of `lq_full`), there is gauge-invariant information in the
28+
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
29+
# matrix. As the number of Householder reflections is fixed in the full rank
30+
# case, Q is expected to rotate smoothly (we might even be able to predict) also
31+
# how the full Q2 will change, but this we omit for now, and we consider
32+
# Q2' * ΔQ2 as a gauge dependent quantity.
33+
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
34+
Δgauge ≤ gauge_atol ||
35+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
36+
return
37+
end
38+
139
"""
240
lq_pullback!(
341
ΔA, A, LQ, ΔLQ;
@@ -36,23 +74,7 @@ function lq_pullback!(
3674
ΔA1 = view(ΔA, 1:p, :)
3775
ΔA2 = view(ΔA, (p + 1):m, :)
3876

39-
if minmn > p # case where A is rank-deficient
40-
Δgauge = abs(zero(eltype(Q)))
41-
if !iszerotangent(ΔQ)
42-
# in this case the number of Householder reflections will
43-
# change upon small variations, and all of the remaining
44-
# columns of ΔQ should be zero for a gauge-invariant
45-
# cost function
46-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
47-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
48-
end
49-
if !iszerotangent(ΔL)
50-
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
51-
Δgauge = max(Δgauge, norm(ΔL22, Inf))
52-
end
53-
Δgauge ≤ gauge_atol ||
54-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
55-
end
77+
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)
5678

5779
ΔQ̃ = zero!(similar(Q, (p, n)))
5880
if !iszerotangent(ΔQ)
@@ -61,17 +83,8 @@ function lq_pullback!(
6183
if p < size(Q, 1)
6284
Q2 = view(Q, (p + 1):size(Q, 1), :)
6385
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
64-
# in the case where A is full rank, but there are more columns in Q than in A
65-
# (the case of `qr_full`), there is gauge-invariant information in the
66-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
67-
# matrix. As the number of Householder reflections is fixed in the full rank
68-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
69-
# how the full Q2 will change, but this we omit for now, and we consider
70-
# Q2' * ΔQ2 as a gauge dependent quantity.
7186
ΔQ2Q1ᴴ = ΔQ2 * Q1'
72-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
73-
Δgauge ≤ gauge_atol ||
74-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
87+
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
7588
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
7689
end
7790
end
@@ -102,6 +115,14 @@ function lq_pullback!(
102115
return ΔA
103116
end
104117

118+
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
119+
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
120+
Δgauge = norm(aNᴴΔN)
121+
Δgauge ≤ gauge_atol ||
122+
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
123+
return
124+
end
125+
105126
"""
106127
lq_null_pullback!(
107128
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
@@ -118,10 +139,7 @@ function lq_null_pullback!(
118139
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
119140
)
120141
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
121-
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
122-
Δgauge = norm(aNᴴΔN)
123-
Δgauge ≤ gauge_atol ||
124-
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
142+
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
125143
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
126144
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
127145
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)

0 commit comments

Comments
 (0)