Skip to content

Commit 09d7f69

Browse files
kshyatt, Katharine Hyatt
authored and committed
Use Testsuite for AD tests
1 parent 0d78d08 commit 09d7f69

16 files changed

Lines changed: 1686 additions & 1302 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ GenericLinearAlgebra = "0.3.19"
3232
GenericSchur = "0.5.6"
3333
JET = "0.9, 0.10"
3434
LinearAlgebra = "1"
35-
Mooncake = "0.4.183"
35+
Mooncake = "0.4.195"
3636
ParallelTestRunner = "2"
3737
Random = "1"
3838
SafeTestsets = "0.1"

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt
33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
6-
using MatrixAlgebraKit: diagview, sign_safe
6+
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
@@ -195,4 +195,23 @@ end
195195
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
196196
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
197197

198+
MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
199+
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
200+
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
201+
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
202+
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
203+
end
204+
205+
function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
206+
#=m = size(A, 1)
207+
n = size(B, 2)
208+
I_n = fill!(similar(A, n), one(eltype(A)))
209+
I_m = fill!(similar(B, m), one(eltype(B)))
210+
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
211+
x_vec = L \ -vec(C)
212+
X = CuMatrix(reshape(x_vec, m, n))=#
213+
hX = sylvester(collect(A), collect(B), collect(C))
214+
return CuArray(hX)
215+
end
216+
198217
end

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
9595
eig_t! = Symbol(eig, "_trunc!")
9696
eig_t_pb = Symbol(eig, "_trunc_pullback")
9797
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
98+
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
99+
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
100+
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
98101
eig_v = Symbol(eig, "_vals")
99102
eig_v! = Symbol(eig_v, "!")
100103
eig_v_pb = Symbol(eig_v, "_pullback")
@@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
136139
end
137140
return $eig_t_pb
138141
end
142+
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
143+
Ac = copy_input($eig_f, A)
144+
DV = $(eig_f!)(Ac, DV, alg.alg)
145+
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
146+
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
147+
end
148+
function $(_make_eig_t_ne_pb)(A, DV, ind)
149+
function $eig_t_ne_pb(ΔDV)
150+
ΔA = zero(A)
151+
ΔD, ΔV = ΔDV
152+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
153+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
154+
end
155+
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
156+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
157+
end
158+
return $eig_t_ne_pb
159+
end
139160
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
140161
DV = $eig_f(A, alg)
141162
function $eig_v_pb(ΔD)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -18,14 +18,24 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
1818
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
1919
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2020
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
21-
dAc = Mooncake.zero_tangent(Ac)
21+
Ac_dAc = Mooncake.zero_fcodual(Ac)
22+
dAc = Mooncake.tangent(Ac_dAc)
2223
function copy_input_pb(::NoRData)
2324
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
2425
return NoRData(), NoRData(), NoRData()
2526
end
26-
return CoDual(Ac, dAc), copy_input_pb
27+
return Ac_dAc, copy_input_pb
2728
end
2829

30+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
31+
function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual)
32+
output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg))
33+
output_doutput = Mooncake.zero_fcodual(output)
34+
initialize_output_pb(::NoRData) = (NoRData(), NoRData(), NoRData(), NoRData())
35+
return output_doutput, initialize_output_pb
36+
end
37+
38+
2939
# two-argument in-place factorizations like LQ, QR, EIG
3040
for (f!, f, pb, adj) in (
3141
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

src/pullbacks/eig.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,16 @@ function eig_pullback!(
3939
VᴴΔV = fill!(similar(V), 0)
4040
indV = axes(V, 2)[ind]
4141
length(indV) == pV || throw(DimensionMismatch())
42-
mul!(view(VᴴΔV, :, indV), V', ΔV)
42+
VᴴΔV[:, indV] .= V' * ΔV
43+
#mul!(view(VᴴΔV, :, indV), V', ΔV)
4344

4445
mask = abs.(transpose(D) .- D) .< degeneracy_atol
45-
Δgauge = norm(view(VᴴΔV, mask), Inf)
46-
Δgauge ≤ gauge_atol ||
47-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
46+
if isa(ΔA, Array)
47+
# not GPU friendly...
48+
Δgauge = norm(view(VᴴΔV, mask), Inf)
49+
Δgauge ≤ gauge_atol ||
50+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
51+
end
4852

4953
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5054

src/pullbacks/lq.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,30 @@ function lq_pullback!(
3636
ΔA1 = view(ΔA, 1:p, :)
3737
ΔA2 = view(ΔA, (p + 1):m, :)
3838

39-
if minmn > p # case where A is rank-deficient
40-
Δgauge = abs(zero(eltype(Q)))
41-
if !iszerotangent(ΔQ)
42-
# in this case the number Householder reflections will
43-
# change upon small variations, and all of the remaining
44-
# columns of ΔQ should be zero for a gauge-invariant
45-
# cost function
46-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
47-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
48-
end
49-
if !iszerotangent(ΔL)
50-
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
51-
Δgauge = max(Δgauge, norm(ΔL22, Inf))
39+
if isa(ΔA, Array) # not GPU friendly
40+
if minmn > p # case where A is rank-deficient
41+
Δgauge = abs(zero(eltype(Q)))
42+
if !iszerotangent(ΔQ)
43+
# in this case the number Householder reflections will
44+
# change upon small variations, and all of the remaining
45+
# columns of ΔQ should be zero for a gauge-invariant
46+
# cost function
47+
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
48+
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
49+
end
50+
if !iszerotangent(ΔL)
51+
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
52+
Δgauge = max(Δgauge, norm(ΔL22, Inf))
53+
end
54+
Δgauge ≤ gauge_atol ||
55+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5256
end
53-
Δgauge ≤ gauge_atol ||
54-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5557
end
5658

5759
ΔQ̃ = zero!(similar(Q, (p, n)))
5860
if !iszerotangent(ΔQ)
5961
ΔQ1 = view(ΔQ, 1:p, :)
60-
copy!(ΔQ̃, ΔQ1)
62+
ΔQ̃ .= ΔQ1
6163
if p < size(Q, 1)
6264
Q2 = view(Q, (p + 1):size(Q, 1), :)
6365
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
@@ -69,9 +71,11 @@ function lq_pullback!(
6971
# how the full Q2 will change, but this we omit for now, and we consider
7072
# Q2' * ΔQ2 as a gauge dependent quantity.
7173
ΔQ2Q1ᴴ = ΔQ2 * Q1'
72-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
73-
Δgauge ≤ gauge_atol ||
74-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # not GPU friendly
75+
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7579
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
7680
end
7781
end
@@ -95,8 +99,10 @@ function lq_pullback!(
9599
Md = diagview(M)
96100
Md .= real.(Md)
97101
end
98-
ldiv!(LowerTriangular(L11)', M)
99-
ldiv!(LowerTriangular(L11)', ΔQ̃)
102+
# not GPU friendly...
103+
L11arr = typeof(L)(L11)
104+
ldiv!(LowerTriangular(L11arr)', M)
105+
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
100106
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
101107
ΔA1 .+= ΔQ̃
102108
return ΔA

src/pullbacks/polar.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
2222
if !iszerotangent(ΔW)
2323
ΔWP = ΔW / P
2424
WdΔWP = W' * ΔWP
25-
ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1)
25+
ΔWP .-= W * WdΔWP
2626
ΔA .+= ΔWP
2727
end
2828
return ΔA
@@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
4848
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
4949
C = sylvester(P, P, M' - M)
5050
C .+= ΔP
51-
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
51+
ΔA .+= C * Wᴴ
5252
if !iszerotangent(ΔWᴴ)
5353
PΔWᴴ = P \ ΔWᴴ
5454
PΔWᴴW = PΔWᴴ * Wᴴ'
55-
PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1)
55+
PΔWᴴ .-= PΔWᴴW * Wᴴ
5656
ΔA .+= PΔWᴴ
5757
end
5858
return ΔA

src/pullbacks/qr.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,29 @@ function qr_pullback!(
3737
ΔA1 = view(ΔA, :, 1:p)
3838
ΔA2 = view(ΔA, :, (p + 1):n)
3939

40-
if minmn > p # case where A is rank-deficient
41-
Δgauge = abs(zero(eltype(Q)))
42-
if !iszerotangent(ΔQ)
43-
# in this case the number Householder reflections will
44-
# change upon small variations, and all of the remaining
45-
# columns of ΔQ should be zero for a gauge-invariant
46-
# cost function
47-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
48-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
49-
end
50-
if !iszerotangent(ΔR)
51-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
52-
Δgauge = max(Δgauge, norm(ΔR22, Inf))
40+
if isa(ΔA, Array) # not GPU friendly
41+
if minmn > p # case where A is rank-deficient
42+
Δgauge = abs(zero(eltype(Q)))
43+
if !iszerotangent(ΔQ)
44+
# in this case the number Householder reflections will
45+
# change upon small variations, and all of the remaining
46+
# columns of ΔQ should be zero for a gauge-invariant
47+
# cost function
48+
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
49+
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
50+
end
51+
if !iszerotangent(ΔR)
52+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
53+
Δgauge = max(Δgauge, norm(ΔR22, Inf))
54+
end
55+
Δgauge ≤ gauge_atol ||
56+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5357
end
54-
Δgauge ≤ gauge_atol ||
55-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5658
end
5759

5860
ΔQ̃ = zero!(similar(Q, (m, p)))
5961
if !iszerotangent(ΔQ)
60-
copy!(ΔQ̃, view(ΔQ, :, 1:p))
62+
ΔQ̃ .= view(ΔQ, :, 1:p)
6163
if p < size(Q, 2)
6264
Q2 = view(Q, :, (p + 1):size(Q, 2))
6365
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
@@ -69,9 +71,11 @@ function qr_pullback!(
6971
# how the full Q2 will change, but this we omit for now, and we consider
7072
# Q2' * ΔQ2 as a gauge dependent quantity.
7173
Q1dΔQ2 = Q1' * ΔQ2
72-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
73-
Δgauge ≤ gauge_atol ||
74-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # not GPU friendly
75+
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7579
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
7680
end
7781
end
@@ -87,16 +91,18 @@ function qr_pullback!(
8791
M = zero!(similar(R, (p, p)))
8892
if !iszerotangent(ΔR)
8993
ΔR11 = view(ΔR, 1:p, 1:p)
90-
M = mul!(M, ΔR11, R11', 1, 1)
94+
M += ΔR11 * R11'
9195
end
92-
M = mul!(M, Q1', ΔQ̃, -1, 1)
96+
M -= Q1' * ΔQ̃
9397
view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
9498
if eltype(M) <: Complex
9599
Md = diagview(M)
96100
Md .= real.(Md)
97101
end
98-
rdiv!(M, UpperTriangular(R11)')
99-
rdiv!(ΔQ̃, UpperTriangular(R11)')
102+
# not GPU-friendly...
103+
R11arr = typeof(R)(R11)
104+
rdiv!(M, UpperTriangular(R11arr)')
105+
rdiv!(ΔQ̃, UpperTriangular(R11arr)')
100106
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
101107
ΔA1 .+= ΔQ̃
102108
return ΔA

src/pullbacks/svd.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
2222
"""
2323
function svd_pullback!(
2424
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
25-
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
26-
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
25+
rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
26+
degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
2727
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
2828
)
2929
# Extract the SVD components
@@ -33,7 +33,7 @@ function svd_pullback!(
3333
minmn = min(m, n)
3434
S = diagview(Smat)
3535
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
36-
r = searchsortedlast(S, rank_atol; rev = true) # rank
36+
r = findlast(s -> s ≥ rank_atol, S) # rank
3737
Ur = view(U, :, 1:r)
3838
Vᴴr = view(Vᴴ, 1:r, :)
3939
Sr = view(S, 1:r)
@@ -71,9 +71,11 @@ function svd_pullback!(
7171

7272
# check whether cotangents arise from gauge-invariance objective function
7373
mask = abs.(Sr' .- Sr) .< degeneracy_atol
74-
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
75-
Δgauge ≤ gauge_atol ||
76-
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # norm check not GPU friendly
75+
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7779

7880
UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+
7981
(aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol)
@@ -84,18 +86,18 @@ function svd_pullback!(
8486
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
8587
view(diagview(UdΔAV), indS) .+= real.(ΔS)
8688
end
87-
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA
89+
ΔA .+= Ur * UdΔAV * Vᴴr # add the contribution to ΔA
8890

8991
# Add the remaining contributions
9092
if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur
9193
Sp = view(S, indU)
9294
Vᴴp = view(Vᴴ, indU, :)
93-
ΔA = mul!(ΔA, ΔU ./ Sp', Vᴴp, 1, 1)
95+
ΔA .+= (ΔU ./ Sp') * Vᴴp
9496
end
9597
if n > r && !iszerotangent(ΔVᴴ) # remaining ΔV is already orthogonal to Vᴴr
9698
Sp = view(S, indV)
9799
Up = view(U, :, indV)
98-
ΔA = mul!(ΔA, Up, Sp .\ ΔVᴴ, 1, 1)
100+
ΔA .+= Up * (Sp .\ ΔVᴴ)
99101
end
100102
return ΔA
101103
end

0 commit comments

Comments
 (0)