Skip to content

Commit 6e6ed94

Browse files
committed
Fixes for enzyme
1 parent ed67efa commit 6e6ed94

5 files changed

Lines changed: 41 additions & 38 deletions

File tree

src/pullbacks/polar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
4646
M = zero(P)
4747
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
4848
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
49-
C = sylvester(P, P, M' - M)
49+
C = _sylvester(P, P, M' - M)
5050
C .+= ΔP
5151
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
5252
if !iszerotangent(ΔWᴴ)

test/enzyme.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (ComplexF64,) # full suite is too expensive on CI
6+
BLASFloats = (Float64,) # full suite is too expensive on CI
77
GenericFloats = (BigFloat,)
88
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
99
using .TestSuite
@@ -13,7 +13,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1313
m = 19
1414
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(123)
16-
if T <: BLASFloats
16+
if T BLASFloats
1717
if CUDA.functional()
1818
TestSuite.test_enzyme(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
1919
#n == m && TestSuite.test_enzyme(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ParallelTestRunner
1+
#=using ParallelTestRunner
22
using MatrixAlgebraKit
33
44
# Start with autodiscovered tests
@@ -22,8 +22,6 @@ if filter_tests!(testsuite, args)
2222
delete!(testsuite, "algorithms")
2323
delete!(testsuite, "truncate")
2424
delete!(testsuite, "gen_eig")
25-
delete!(testsuite, "mooncake")
26-
delete!(testsuite, "enzyme")
2725
delete!(testsuite, "chainrules")
2826
delete!(testsuite, "codequality")
2927
else
@@ -32,4 +30,6 @@ if filter_tests!(testsuite, args)
3230
end
3331
end
3432
35-
runtests(MatrixAlgebraKit, args; testsuite)
33+
runtests(MatrixAlgebraKit, args; testsuite)=#
34+
include("enzyme.jl")
35+
include("mooncake.jl")

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ include("eigh.jl")
9696
include("orthnull.jl")
9797
include("svd.jl")
9898
include("mooncake.jl")
99+
include("enzyme.jl")
99100
include("chainrules.jl")
100101

101102
end

test/testsuite/enzyme.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function enz_copy_eigh_trunc_no_error!(A, DV, alg)
4646
end
4747

4848
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated)
49-
ΔA = randn(rng, eltype(A), size(A)...)
49+
ΔA = randn!(similar(A))
5050
A_ΔA() = Duplicated(copy(A), copy(ΔA))
5151
function args_Δargs()
5252
if isnothing(args)
@@ -143,8 +143,8 @@ function test_enzyme_qr(
143143
r = min(m, n) - 5
144144
Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n))
145145
QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard)
146-
eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR), fdm)
147-
test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg)
146+
eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm)
147+
test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg)
148148
end
149149
end
150150
end
@@ -163,8 +163,8 @@ function test_enzyme_lq(
163163
@testset "lq_compact" begin
164164
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
165165
LQ, ΔLQ = ad_lq_compact_setup(A)
166-
eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm)
167-
test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg)
166+
eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm)
167+
test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ, alg)
168168
end
169169
end
170170
@testset "lq_null" begin
@@ -177,8 +177,8 @@ function test_enzyme_lq(
177177
@testset "lq_full" begin
178178
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
179179
LQ, ΔLQ = ad_lq_full_setup(A)
180-
eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm)
181-
test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg)
180+
eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm)
181+
test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ, alg)
182182
end
183183
end
184184
@testset "lq_compact -- rank-deficient A" begin
@@ -187,8 +187,8 @@ function test_enzyme_lq(
187187
r = min(m, n) - 5
188188
Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n))
189189
LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard)
190-
eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm)
191-
test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg)
190+
eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm)
191+
test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg)
192192
end
193193
end
194194
end
@@ -209,8 +209,8 @@ function test_enzyme_eig(
209209
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
210210
DV, ΔDV, ΔD2V = ad_eig_full_setup(A)
211211
if eltype(T) <: BlasFloat
212-
test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm)
213-
test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg)
212+
test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD2V, fdm)
213+
test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V, alg)
214214
else
215215
test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV))
216216
end
@@ -221,9 +221,9 @@ function test_enzyme_eig(
221221
D, ΔD = ad_eig_vals_setup(A)
222222
if eltype(T) <: BlasFloat
223223
test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm)
224-
test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg)
224+
test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD.diag, alg)
225225
else
226-
test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD2.diag)
226+
test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD.diag)
227227
end
228228
end
229229
end
@@ -233,19 +233,19 @@ function test_enzyme_eig(
233233
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs))
234234
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
235235
if eltype(T) <: BlasFloat
236-
test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm)
237-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc))
236+
test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm)
237+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc)
238238
else
239-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc))
239+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc)
240240
end
241241
end
242242
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real))
243243
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
244244
if eltype(T) <: BlasFloat
245-
test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm)
246-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc))
245+
test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm)
246+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc)
247247
else
248-
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc))
248+
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc)
249249
end
250250
end
251251
end
@@ -265,17 +265,19 @@ function test_enzyme_eigh(
265265
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
266266
@testset "eigh_full" begin
267267
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
268+
DV, ΔDV, ΔD2V = ad_eigh_full_setup(A)
268269
if eltype(T) <: BlasFloat
269-
test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm)
270-
test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm)
270+
test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm)
271+
test_reverse(copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm)
271272
end
272-
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
273+
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg)
273274
end
274275
end
275276
@testset "eigh_vals" begin
276277
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
277-
eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm)
278-
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
278+
D, ΔD = ad_eigh_vals_setup(A)
279+
eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm)
280+
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg)
279281
end
280282
end
281283
@testset "eigh_trunc" begin
@@ -284,14 +286,14 @@ function test_enzyme_eigh(
284286
Ddiag = diagview(D)
285287
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
286288
DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
287-
eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm)
288-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
289+
eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm)
290+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT)
289291
end
290292
D = eigh_vals(A / 2)
291293
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2))
292294
DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
293-
eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm)
294-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
295+
eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm)
296+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT)
295297
end
296298
end
297299
end
@@ -312,11 +314,11 @@ function test_enzyme_svd(
312314
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
313315
USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A)
314316
if eltype(T) <: BlasFloat
315-
test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm)
316-
test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg)
317+
test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm)
318+
test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ, alg)
317319
else
318320
USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg)
319-
test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔU, ΔS, ΔVᴴ))
321+
test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ)
320322
end
321323
end
322324
end

0 commit comments

Comments
 (0)