Skip to content

Commit e50d4bc

Browse files
committed
Refactor Enzyme testsuite
1 parent 295a354 commit e50d4bc

22 files changed

Lines changed: 786 additions & 505 deletions

test/enzyme.jl

Lines changed: 0 additions & 30 deletions
This file was deleted.

test/enzyme/eig.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/eigh.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/lq.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/orthnull.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/polar.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/qr.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/enzyme/svd.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
15+
TestSuite.seed_rng!(1234)
16+
if !is_buildkite
17+
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
end
19+
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ if filter_tests!(testsuite, args)
2626
else
2727
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
2828
if is_apple_ci
29-
delete!(testsuite, "enzyme")
3029
filter!(p -> !startswith(first(p), "mooncake/"), testsuite)
3130
delete!(testsuite, "chainrules")
3231
end
33-
Sys.iswindows() && delete!(testsuite, "enzyme")
32+
(Sys.iswindows() || is_apple_ci) && filter!(p -> !startswith(first(p), "enzyme/"), testsuite)
3433
end
3534
end
3635

test/testsuite/TestSuite.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using LinearAlgebra: Diagonal, norm, istriu, istril, I
1515
using Random, StableRNGs
1616
using Mooncake
1717
using AMDGPU, CUDA
18+
using Enzyme, EnzymeTestUtils
1819

1920
const tests = Dict()
2021

@@ -118,7 +119,16 @@ include("mooncake/polar.jl")
118119
include("mooncake/orthnull.jl")
119120
include("mooncake/projections.jl")
120121

121-
include("enzyme.jl")
122122
include("chainrules.jl")
123123

124+
# Enzyme
125+
# ------
126+
include("enzyme/eig.jl")
127+
include("enzyme/eigh.jl")
128+
include("enzyme/qr.jl")
129+
include("enzyme/lq.jl")
130+
include("enzyme/svd.jl")
131+
include("enzyme/polar.jl")
132+
include("enzyme/orthnull.jl")
133+
124134
end

0 commit comments

Comments
 (0)