@@ -4,127 +4,45 @@ using TestExtras
44using StableRNGs
55using LinearAlgebra: Diagonal
66using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
7+ using CUDA, AMDGPU
78
89BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9- GenericFloats = (Float16, BigFloat, Complex{BigFloat})
10-
11- @testset " eig_full! for T = $T " for T in BLASFloats
12- rng = StableRNG ( 123 )
13- m = 54
14- for alg in ( LAPACK_Simple (), LAPACK_Expert (), :LAPACK_Simple , LAPACK_Simple)
15- A = randn (rng, T, m, m)
16- Tc = complex (T)
17-
18- D, V = @constinferred eig_full (A; alg = ( $ alg) )
19- @test eltype (D) == eltype (V) == Tc
20- @test A * V ≈ V * D
21-
22- alg′ = @constinferred MatrixAlgebraKit . select_algorithm (eig_full!, A, $ alg )
23-
24- Ac = similar (A )
25- D2, V2 = @constinferred eig_full! ( copy! (Ac, A), (D, V), alg′)
26- @test D2 === D
27- @test V2 === V
28- @test A * V ≈ V * D
29-
30- Dc = @constinferred eig_vals (A, alg′ )
31- @test eltype (Dc) == Tc
32- @test D ≈ Diagonal (Dc)
10+ GenericFloats = (BigFloat, Complex{BigFloat})
11+
12+ @isdefined (TestSuite) || include ( " testsuite/TestSuite.jl " )
13+ using . TestSuite
14+
15+ is_buildkite = get ( ENV , " BUILDKITE " , " false " ) == " true "
16+
17+ m = 54
18+ for T in (BLASFloats ... , GenericFloats ... )
19+ TestSuite . seed_rng! ( 123 )
20+ if T ∈ BLASFloats
21+ if CUDA . functional ()
22+ TestSuite . test_eig (CuMatrix{T}, (m, m); test_trunc = false )
23+ TestSuite . test_eig_algs (CuMatrix{T}, (m, m), ( CUSOLVER_Simple (),) )
24+ TestSuite . test_eig (Diagonal{T, CuVector{T}}, m)
25+ TestSuite . test_eig_algs (Diagonal{T, CuVector{T}}, m, ( DiagonalAlgorithm (),) )
26+ end
27+ #= not yet supported
28+ if AMDGPU.functional()
29+ TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
30+ TestSuite.test_eig_algs(ROCMatrix{T}, (m, m), (ROCSOLVER_Simple(),))
31+ TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false )
32+ TestSuite.test_eig_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
33+ end =#
3334 end
34- end
35-
36- @testset " eig_trunc! for T = $T " for T in BLASFloats
37- rng = StableRNG (123 )
38- m = 54
39- for alg in (LAPACK_Simple (), LAPACK_Expert ())
40- A = randn (rng, T, m, m)
41- A *= A' # TODO : deal with eigenvalue ordering etc
42- # eigenvalues are sorted by ascending real component...
43- D₀ = sort! (eig_vals (A); by = abs, rev = true )
44- rmin = findfirst (i -> abs (D₀[end - i]) != abs (D₀[end - i - 1 ]), 1 : (m - 2 ))
45- r = length (D₀) - rmin
46- atol = sqrt (eps (real (T)))
47-
48- D1, V1, ϵ1 = @constinferred eig_trunc (A; alg, trunc = truncrank (r))
49- @test length (diagview (D1)) == r
50- @test A * V1 ≈ V1 * D1
51- @test ϵ1 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
52-
53- s = 1 + sqrt (eps (real (T)))
54- trunc = trunctol (; atol = s * abs (D₀[r + 1 ]))
55- D2, V2, ϵ2 = @constinferred eig_trunc (A; alg, trunc)
56- @test length (diagview (D2)) == r
57- @test A * V2 ≈ V2 * D2
58- @test ϵ2 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
59-
60- s = 1 - sqrt (eps (real (T)))
61- trunc = truncerror (; atol = s * norm (@view (D₀[r: end ]), 1 ), p = 1 )
62- D3, V3, ϵ3 = @constinferred eig_trunc (A; alg, trunc)
63- @test length (diagview (D3)) == r
64- @test A * V3 ≈ V3 * D3
65- @test ϵ3 ≈ norm (view (D₀, (r + 1 ): m)) atol = atol
66-
67- s = 1 - sqrt (eps (real (T)))
68- trunc = truncerror (; atol = s * norm (@view (D₀[r: end ]), 1 ), p = 1 )
69- D4, V4 = @constinferred eig_trunc_no_error (A; alg, trunc)
70- @test length (diagview (D4)) == r
71- @test A * V4 ≈ V4 * D4
72- # trunctol keeps order, truncrank might not
73- # test for same subspace
74- @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
75- @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
76- @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3
77- @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1
35+ if ! is_buildkite
36+ TestSuite. test_eig (T, (m, m))
37+ if T ∈ BLASFloats
38+ LAPACK_EIG_ALGS = (LAPACK_Simple (), LAPACK_Expert ())
39+ TestSuite. test_eig_algs (T, (m, m), LAPACK_EIG_ALGS)
40+ elseif T ∈ GenericFloats
41+ GS_EIG_ALGS = (GS_QRIteration (),)
42+ TestSuite. test_eig_algs (T, (m, m), GS_EIG_ALGS)
43+ end
44+ AT = Diagonal{T, Vector{T}}
45+ TestSuite. test_eig (AT, m)
46+ TestSuite. test_eig_algs (AT, m, (DiagonalAlgorithm (),))
7847 end
7948end
80-
81- @testset " eig_trunc! specify truncation algorithm T = $T " for T in BLASFloats
82- rng = StableRNG (123 )
83- m = 4
84- atol = sqrt (eps (real (T)))
85- V = randn (rng, T, m, m)
86- D = Diagonal (real (T)[0.9 , 0.3 , 0.1 , 0.01 ])
87- A = V * D * inv (V)
88- alg = TruncatedAlgorithm (LAPACK_Simple (), truncrank (2 ))
89- D2, V2, ϵ2 = @constinferred eig_trunc (A; alg)
90- @test diagview (D2) ≈ diagview (D)[1 : 2 ]
91- @test ϵ2 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
92- @test_throws ArgumentError eig_trunc (A; alg, trunc = (; maxrank = 2 ))
93-
94- alg = TruncatedAlgorithm (LAPACK_Simple (), truncerror (; atol = 0.2 , p = 1 ))
95- D3, V3, ϵ3 = @constinferred eig_trunc (A; alg)
96- @test diagview (D3) ≈ diagview (D)[1 : 2 ]
97- @test ϵ3 ≈ norm (diagview (D)[3 : 4 ]) atol = atol
98-
99- alg = TruncatedAlgorithm (LAPACK_Simple (), truncerror (; atol = 0.2 , p = 1 ))
100- D4, V4 = @constinferred eig_trunc_no_error (A; alg)
101- @test diagview (D4) ≈ diagview (D)[1 : 2 ]
102- end
103-
104- @testset " eig for Diagonal{$T }" for T in (BLASFloats... , GenericFloats... )
105- rng = StableRNG (123 )
106- m = 54
107- Ad = randn (rng, T, m)
108- A = Diagonal (Ad)
109- atol = sqrt (eps (real (T)))
110-
111- D, V = @constinferred eig_full (A)
112- @test D isa Diagonal{T} && size (D) == size (A)
113- @test V isa Diagonal{T} && size (V) == size (A)
114- @test A * V ≈ V * D
115-
116- D2 = @constinferred eig_vals (A)
117- @test D2 isa AbstractVector{T} && length (D2) == m
118- @test diagview (D) ≈ D2
119-
120- A2 = Diagonal (T[0.9 , 0.3 , 0.1 , 0.01 ])
121- alg = TruncatedAlgorithm (DiagonalAlgorithm (), truncrank (2 ))
122- D2, V2, ϵ2 = @constinferred eig_trunc (A2; alg)
123- @test diagview (D2) ≈ diagview (A2)[1 : 2 ]
124- @test ϵ2 ≈ norm (diagview (A2)[3 : 4 ]) atol = atol
125-
126- A3 = Diagonal (T[0.9 , 0.3 , 0.1 , 0.01 ])
127- alg = TruncatedAlgorithm (DiagonalAlgorithm (), truncrank (2 ))
128- D3, V3 = @constinferred eig_trunc_no_error (A3; alg)
129- @test diagview (D3) ≈ diagview (A3)[1 : 2 ]
130- end
0 commit comments