Skip to content

Commit 72d1c6d

Browse files
author
Katharine Hyatt
committed
Add projections tests for Enzyme
1 parent 3dc2eee commit 72d1c6d

3 files changed

Lines changed: 75 additions & 0 deletions

File tree

test/enzyme/projections.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
atol = rtol = m * m * TestSuite.precision(T)
17+
if !is_buildkite
18+
TestSuite.test_enzyme_projections(T, (m, m); atol, rtol)
19+
TestSuite.test_enzyme_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol)
20+
end
21+
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,6 @@ include("enzyme/lq.jl")
130130
include("enzyme/svd.jl")
131131
include("enzyme/polar.jl")
132132
include("enzyme/orthnull.jl")
133+
include("enzyme/projections.jl")
133134

134135
end
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
test_enzyme_projections(T, sz; kwargs...)
3+
4+
Run all Enzyme AD tests for hermitian and anti-hermitian projections of element type `T`
5+
and size `sz`.
6+
"""
7+
function test_enzyme_projections(T::Type, sz; kwargs...)
8+
summary_str = testargs_summary(T, sz)
9+
return @testset "Enzyme projection $summary_str" begin
10+
test_enzyme_project_hermitian(T, sz; kwargs...)
11+
test_enzyme_project_antihermitian(T, sz; kwargs...)
12+
end
13+
end
14+
15+
"""
16+
test_enzyme_project_hermitian(T, sz; rng, atol, rtol)
17+
18+
Test the Enzyme reverse-mode AD rule for `project_hermitian` and its in-place variant.
19+
"""
20+
function test_enzyme_project_hermitian(
21+
T, sz;
22+
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
23+
fdm = enzyme_fdm(T)
24+
)
25+
return @testset "project_hermitian" begin
26+
A = instantiate_matrix(T, sz)
27+
B = instantiate_matrix(T, sz)
28+
alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)
29+
test_reverse(project_hermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm)
30+
test_reverse(project_hermitian!, RT, (A, TA), (A, TA), (alg, Const); atol, rtol, fdm)
31+
test_reverse(project_hermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm)
32+
end
33+
end
34+
35+
"""
36+
test_enzyme_project_antihermitian(T, sz; rng, atol, rtol)
37+
38+
Test the Enzyme reverse-mode AD rule for `project_antihermitian` and its in-place variant.
39+
"""
40+
function test_enzyme_project_antihermitian(
41+
T, sz;
42+
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
43+
fdm = enzyme_fdm(T)
44+
)
45+
return @testset "project_antihermitian" begin
46+
A = instantiate_matrix(T, sz)
47+
B = instantiate_matrix(T, sz)
48+
alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)
49+
test_reverse(project_antihermitian, RT, (A, TA), (alg, Const); atol, rtol, fdm)
50+
test_reverse(project_antihermitian!, RT, (A, TA), (A, TA), (alg, Const); atol, rtol, fdm)
51+
test_reverse(project_antihermitian!, RT, (A, TA), (B, TA), (alg, Const); atol, rtol, fdm)
52+
end
53+
end

0 commit comments

Comments
 (0)