Skip to content

Commit 958dd36

Browse files
committed
Use Testsuite for AD tests
1 parent 3848f2b commit 958dd36

15 files changed

Lines changed: 1557 additions & 1289 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5757

5858
[targets]
5959
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
60+
61+
[sources]
62+
CUDA = {url="https://github.com/JuliaGPU/CUDA.jl", rev="master"}

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt
33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
6-
using MatrixAlgebraKit: diagview, sign_safe
6+
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
77
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
@@ -183,4 +183,23 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
183183
return A, B
184184
end
185185

186+
MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
187+
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
188+
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
189+
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
190+
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
191+
end
192+
193+
function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
194+
#=m = size(A, 1)
195+
n = size(B, 2)
196+
I_n = fill!(similar(A, n), one(eltype(A)))
197+
I_m = fill!(similar(B, m), one(eltype(B)))
198+
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
199+
x_vec = L \ -vec(C)
200+
X = CuMatrix(reshape(x_vec, m, n))=#
201+
hX = sylvester(collect(A), collect(B), collect(C))
202+
return CuArray(hX)
203+
end
204+
186205
end

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2626
return CoDual(Ac, dAc), copy_input_pb
2727
end
2828

29+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
30+
function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual)
31+
output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg))
32+
doutput = Mooncake.zero_tangent(output)
33+
function initialize_output_pb(::NoRData)
34+
return NoRData(), NoRData(), NoRData(), NoRData()
35+
end
36+
return CoDual(output, doutput), initialize_output_pb
37+
end
38+
39+
2940
# two-argument in-place factorizations like LQ, QR, EIG
3041
for (f!, f, pb, adj) in (
3142
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

src/pullbacks/eig.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,12 @@ function eig_pullback!(
4242
mul!(view(VᴴΔV, :, indV), V', ΔV)
4343

4444
mask = abs.(transpose(D) .- D) .< degeneracy_atol
45-
Δgauge = norm(view(VᴴΔV, mask), Inf)
46-
Δgauge ≤ gauge_atol ||
47-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
45+
if isa(ΔA, Array)
46+
# not GPU friendly...
47+
Δgauge = norm(view(VᴴΔV, mask), Inf)
48+
Δgauge ≤ gauge_atol ||
49+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
50+
end
4851

4952
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5053

src/pullbacks/lq.jl

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,24 @@ function lq_pullback!(
3636
ΔA1 = view(ΔA, 1:p, :)
3737
ΔA2 = view(ΔA, (p + 1):m, :)
3838

39-
if minmn > p # case where A is rank-deficient
40-
Δgauge = abs(zero(eltype(Q)))
41-
if !iszerotangent(ΔQ)
42-
# in this case the number Householder reflections will
43-
# change upon small variations, and all of the remaining
44-
# columns of ΔQ should be zero for a gauge-invariant
45-
# cost function
46-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
47-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
48-
end
49-
if !iszerotangent(ΔL)
50-
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
51-
Δgauge = max(Δgauge, norm(ΔL22, Inf))
39+
if isa(ΔA, Array) # not GPU friendly
40+
if minmn > p # case where A is rank-deficient
41+
Δgauge = abs(zero(eltype(Q)))
42+
if !iszerotangent(ΔQ)
43+
# in this case the number Householder reflections will
44+
# change upon small variations, and all of the remaining
45+
# columns of ΔQ should be zero for a gauge-invariant
46+
# cost function
47+
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
48+
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
49+
end
50+
if !iszerotangent(ΔL)
51+
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
52+
Δgauge = max(Δgauge, norm(ΔL22, Inf))
53+
end
54+
Δgauge ≤ gauge_atol ||
55+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5256
end
53-
Δgauge ≤ gauge_atol ||
54-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5557
end
5658

5759
ΔQ̃ = zero!(similar(Q, (p, n)))
@@ -69,9 +71,11 @@ function lq_pullback!(
6971
# how the full Q2 will change, but this we omit for now, and we consider
7072
# Q2' * ΔQ2 as a gauge dependent quantity.
7173
ΔQ2Q1ᴴ = ΔQ2 * Q1'
72-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
73-
Δgauge ≤ gauge_atol ||
74-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # not GPU friendly
75+
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7579
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
7680
end
7781
end
@@ -95,8 +99,10 @@ function lq_pullback!(
9599
Md = diagview(M)
96100
Md .= real.(Md)
97101
end
98-
ldiv!(LowerTriangular(L11)', M)
99-
ldiv!(LowerTriangular(L11)', ΔQ̃)
102+
# not GPU friendly...
103+
L11arr = typeof(L)(L11)
104+
ldiv!(LowerTriangular(L11arr)', M)
105+
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
100106
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
101107
ΔA1 .+= ΔQ̃
102108
return ΔA

src/pullbacks/qr.jl

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,24 @@ function qr_pullback!(
3737
ΔA1 = view(ΔA, :, 1:p)
3838
ΔA2 = view(ΔA, :, (p + 1):n)
3939

40-
if minmn > p # case where A is rank-deficient
41-
Δgauge = abs(zero(eltype(Q)))
42-
if !iszerotangent(ΔQ)
43-
# in this case the number Householder reflections will
44-
# change upon small variations, and all of the remaining
45-
# columns of ΔQ should be zero for a gauge-invariant
46-
# cost function
47-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
48-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
49-
end
50-
if !iszerotangent(ΔR)
51-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
52-
Δgauge = max(Δgauge, norm(ΔR22, Inf))
40+
if isa(ΔA, Array) # not GPU friendly
41+
if minmn > p # case where A is rank-deficient
42+
Δgauge = abs(zero(eltype(Q)))
43+
if !iszerotangent(ΔQ)
44+
# in this case the number Householder reflections will
45+
# change upon small variations, and all of the remaining
46+
# columns of ΔQ should be zero for a gauge-invariant
47+
# cost function
48+
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
49+
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
50+
end
51+
if !iszerotangent(ΔR)
52+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
53+
Δgauge = max(Δgauge, norm(ΔR22, Inf))
54+
end
55+
Δgauge ≤ gauge_atol ||
56+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5357
end
54-
Δgauge ≤ gauge_atol ||
55-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5658
end
5759

5860
ΔQ̃ = zero!(similar(Q, (m, p)))
@@ -69,9 +71,11 @@ function qr_pullback!(
6971
# how the full Q2 will change, but this we omit for now, and we consider
7072
# Q2' * ΔQ2 as a gauge dependent quantity.
7173
Q1dΔQ2 = Q1' * ΔQ2
72-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
73-
Δgauge ≤ gauge_atol ||
74-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # not GPU friendly
75+
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7579
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
7680
end
7781
end
@@ -95,8 +99,10 @@ function qr_pullback!(
9599
Md = diagview(M)
96100
Md .= real.(Md)
97101
end
98-
rdiv!(M, UpperTriangular(R11)')
99-
rdiv!(ΔQ̃, UpperTriangular(R11)')
102+
# not GPU-friendly...
103+
R11arr = typeof(R)(R11)
104+
rdiv!(M, UpperTriangular(R11arr)')
105+
rdiv!(ΔQ̃, UpperTriangular(R11arr)')
100106
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
101107
ΔA1 .+= ΔQ̃
102108
return ΔA

src/pullbacks/svd.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
2222
"""
2323
function svd_pullback!(
2424
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
25-
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
26-
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
25+
rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
26+
degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
2727
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
2828
)
2929
# Extract the SVD components
@@ -33,7 +33,7 @@ function svd_pullback!(
3333
minmn = min(m, n)
3434
S = diagview(Smat)
3535
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
36-
r = searchsortedlast(S, rank_atol; rev = true) # rank
36+
r = findlast(s -> s ≥ rank_atol, S) # rank
3737
Ur = view(U, :, 1:r)
3838
Vᴴr = view(Vᴴ, 1:r, :)
3939
Sr = view(S, 1:r)
@@ -71,9 +71,11 @@ function svd_pullback!(
7171

7272
# check whether cotangents arise from gauge-invariance objective function
7373
mask = abs.(Sr' .- Sr) .< degeneracy_atol
74-
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
75-
Δgauge ≤ gauge_atol ||
76-
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
74+
if isa(ΔA, Array) # norm check not GPU friendly
75+
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
76+
Δgauge ≤ gauge_atol ||
77+
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78+
end
7779

7880
UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+
7981
(aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol)

test/ad_utils.jl

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)