Skip to content

Commit ca99981

Browse files
authored
Use TestSuite for eig (#130)
* Use TestSuite for eig * Don't test trunc for CUDA algs * No trunc for any CUDA * Comments
1 parent 3848f2b commit ca99981

5 files changed

Lines changed: 221 additions & 234 deletions

File tree

test/cuda/eig.jl

Lines changed: 0 additions & 108 deletions
This file was deleted.

test/eig.jl

Lines changed: 37 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -4,127 +4,45 @@ using TestExtras
44
using StableRNGs
55
using LinearAlgebra: Diagonal
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
7+
using CUDA, AMDGPU
78

89
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9-
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
10-
11-
@testset "eig_full! for T = $T" for T in BLASFloats
12-
rng = StableRNG(123)
13-
m = 54
14-
for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple)
15-
A = randn(rng, T, m, m)
16-
Tc = complex(T)
17-
18-
D, V = @constinferred eig_full(A; alg = ($alg))
19-
@test eltype(D) == eltype(V) == Tc
20-
@test A * V V * D
21-
22-
alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg)
23-
24-
Ac = similar(A)
25-
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
26-
@test D2 === D
27-
@test V2 === V
28-
@test A * V V * D
29-
30-
Dc = @constinferred eig_vals(A, alg′)
31-
@test eltype(Dc) == Tc
32-
@test D Diagonal(Dc)
10+
GenericFloats = (BigFloat, Complex{BigFloat})
11+
12+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
13+
using .TestSuite
14+
15+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
16+
17+
m = 54
18+
for T in (BLASFloats..., GenericFloats...)
19+
TestSuite.seed_rng!(123)
20+
if T BLASFloats
21+
if CUDA.functional()
22+
TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false)
23+
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false)
24+
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false)
25+
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
26+
end
27+
#= not yet supported
28+
if AMDGPU.functional()
29+
TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
30+
TestSuite.test_eig_algs(ROCMatrix{T}, (m, m), (ROCSOLVER_Simple(),))
31+
TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
32+
TestSuite.test_eig_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
33+
end=#
3334
end
34-
end
35-
36-
@testset "eig_trunc! for T = $T" for T in BLASFloats
37-
rng = StableRNG(123)
38-
m = 54
39-
for alg in (LAPACK_Simple(), LAPACK_Expert())
40-
A = randn(rng, T, m, m)
41-
A *= A' # TODO: deal with eigenvalue ordering etc
42-
# eigenvalues are sorted by ascending real component...
43-
D₀ = sort!(eig_vals(A); by = abs, rev = true)
44-
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
45-
r = length(D₀) - rmin
46-
atol = sqrt(eps(real(T)))
47-
48-
D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
49-
@test length(diagview(D1)) == r
50-
@test A * V1 V1 * D1
51-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
52-
53-
s = 1 + sqrt(eps(real(T)))
54-
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
55-
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
56-
@test length(diagview(D2)) == r
57-
@test A * V2 V2 * D2
58-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
59-
60-
s = 1 - sqrt(eps(real(T)))
61-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
62-
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
63-
@test length(diagview(D3)) == r
64-
@test A * V3 V3 * D3
65-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
66-
67-
s = 1 - sqrt(eps(real(T)))
68-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
69-
D4, V4 = @constinferred eig_trunc_no_error(A; alg, trunc)
70-
@test length(diagview(D4)) == r
71-
@test A * V4 V4 * D4
72-
# trunctol keeps order, truncrank might not
73-
# test for same subspace
74-
@test V1 * ((V1' * V1) \ (V1' * V2)) V2
75-
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
76-
@test V1 * ((V1' * V1) \ (V1' * V3)) V3
77-
@test V3 * ((V3' * V3) \ (V3' * V1)) V1
35+
if !is_buildkite
36+
TestSuite.test_eig(T, (m, m))
37+
if T BLASFloats
38+
LAPACK_EIG_ALGS = (LAPACK_Simple(), LAPACK_Expert())
39+
TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS)
40+
elseif T GenericFloats
41+
GS_EIG_ALGS = (GS_QRIteration(),)
42+
TestSuite.test_eig_algs(T, (m, m), GS_EIG_ALGS)
43+
end
44+
AT = Diagonal{T, Vector{T}}
45+
TestSuite.test_eig(AT, m)
46+
TestSuite.test_eig_algs(AT, m, (DiagonalAlgorithm(),))
7847
end
7948
end
80-
81-
@testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats
82-
rng = StableRNG(123)
83-
m = 4
84-
atol = sqrt(eps(real(T)))
85-
V = randn(rng, T, m, m)
86-
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
87-
A = V * D * inv(V)
88-
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
89-
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
90-
@test diagview(D2) diagview(D)[1:2]
91-
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
92-
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))
93-
94-
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
95-
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
96-
@test diagview(D3) diagview(D)[1:2]
97-
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
98-
99-
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
100-
D4, V4 = @constinferred eig_trunc_no_error(A; alg)
101-
@test diagview(D4) diagview(D)[1:2]
102-
end
103-
104-
@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
105-
rng = StableRNG(123)
106-
m = 54
107-
Ad = randn(rng, T, m)
108-
A = Diagonal(Ad)
109-
atol = sqrt(eps(real(T)))
110-
111-
D, V = @constinferred eig_full(A)
112-
@test D isa Diagonal{T} && size(D) == size(A)
113-
@test V isa Diagonal{T} && size(V) == size(A)
114-
@test A * V V * D
115-
116-
D2 = @constinferred eig_vals(A)
117-
@test D2 isa AbstractVector{T} && length(D2) == m
118-
@test diagview(D) D2
119-
120-
A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
121-
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
122-
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
123-
@test diagview(D2) diagview(A2)[1:2]
124-
@test ϵ2 norm(diagview(A2)[3:4]) atol = atol
125-
126-
A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
127-
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
128-
D3, V3 = @constinferred eig_trunc_no_error(A3; alg)
129-
@test diagview(D3) diagview(A3)[1:2]
130-
end

test/runtests.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ if !is_buildkite
1616
@safetestset "Hermitian Eigenvalue Decomposition" begin
1717
include("eigh.jl")
1818
end
19-
@safetestset "General Eigenvalue Decomposition" begin
20-
include("eig.jl")
21-
end
2219
@safetestset "Generalized Eigenvalue Decomposition" begin
2320
include("gen_eig.jl")
2421
end
@@ -51,7 +48,6 @@ if !is_buildkite
5148
@safetestset "Hermitian Eigenvalue Decomposition" begin
5249
include("genericlinearalgebra/eigh.jl")
5350
end
54-
5551
end
5652

5753
@safetestset "QR / LQ Decomposition" begin
@@ -67,15 +63,15 @@ end
6763
@safetestset "Schur Decomposition" begin
6864
include("schur.jl")
6965
end
66+
@safetestset "General Eigenvalue Decomposition" begin
67+
include("eig.jl")
68+
end
7069

7170
using CUDA
7271
if CUDA.functional()
7372
@safetestset "CUDA SVD" begin
7473
include("cuda/svd.jl")
7574
end
76-
@safetestset "CUDA General Eigenvalue Decomposition" begin
77-
include("cuda/eig.jl")
78-
end
7975
@safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
8076
include("cuda/eigh.jl")
8177
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,6 @@ include("lq.jl")
7474
include("polar.jl")
7575
include("projections.jl")
7676
include("schur.jl")
77+
include("eig.jl")
7778

7879
end

0 commit comments

Comments
 (0)