Skip to content

Commit f468648

Browse files
committed
Add some tests
1 parent 699946e commit f468648

4 files changed

Lines changed: 33 additions & 0 deletions

File tree

test/eig.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,18 @@ end
5252
D2, V2 = @constinferred eig_trunc(A; alg, trunc)
5353
@test length(diagview(D2)) == r
5454
@test A * V2 V2 * D2
55+
56+
trunc = truncerror(s * norm(@view(D₀[r+1:end])))
57+
D3, V3 = @constinferred eig_trunc(A; alg, trunc)
58+
@test length(diagview(D3)) == r
59+
@test A * V3 V3 * D3
5560

5661
# trunctol keeps order, truncrank might not
5762
# test for same subspace
5863
@test V1 * ((V1' * V1) \ (V1' * V2)) V2
5964
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
65+
@test V1 * ((V1' * V1) \ (V1' * V3)) V3
66+
@test V3 * ((V3' * V3) \ (V3' * V1)) V1
6067
end
6168
end
6269

@@ -70,6 +77,10 @@ end
7077
D2, V2 = @constinferred eig_trunc(A; alg)
7178
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
7279
@test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2))
80+
81+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(0.2))
82+
D3, V3 = @constinferred eig_trunc(A; alg)
83+
@test diagview(D3) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
7384
end
7485

7586
@testset "eig for Diagonal{$T}" for T in BLASFloats

test/eigh.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,17 @@ end
5757
@test length(diagview(D2)) == r
5858
@test isisometry(V2)
5959
@test A * V2 V2 * D2
60+
61+
trunc = truncerror(s * norm(@view(D₀[r+1:end])))
62+
D3, V3 = @constinferred eigh_trunc(A; alg, trunc)
63+
@test length(diagview(D3)) == r
64+
@test A * V3 V3 * D3
6065

6166
# test for same subspace
6267
@test V1 * (V1' * V2) V2
6368
@test V2 * (V2' * V1) V1
69+
@test V1 * (V1' * V3) V3
70+
@test V3 * (V3' * V1) V1
6471
end
6572
end
6673

@@ -75,6 +82,10 @@ end
7582
D2, V2 = @constinferred eigh_trunc(A; alg)
7683
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
7784
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
85+
86+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(0.2))
87+
D3, V3 = @constinferred eigh_trunc(A; alg)
88+
@test diagview(D3) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
7889
end
7990

8091
@testset "eigh for Diagonal{$T}" for T in BLASFloats

test/svd.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ end
120120
@test U1 U2
121121
@test S1 S2
122122
@test V1ᴴ V2ᴴ
123+
124+
trunc = truncerr(s * norm(@view(S₀[(r + 1):end])))
125+
U3, S3, V3ᴴ = @constinferred svd_trunc(A; alg, trunc)
126+
@test length(S3.diag) == r
127+
@test U1 U3
128+
@test S1 S3
129+
@test V1ᴴ V3ᴴ
123130
end
124131
end
125132
end

test/truncate.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,8 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
6969
TruncationKeepBelow(0.2, 0))
7070
@test @constinferred(findtruncated(values, strategy)) == [1]
7171
end
72+
for strategy in (truncerror(; atol=0.2, rtol=0),)
73+
@test issetequal(@constinferred(findtruncated(values, strategy)), 2:5)
74+
@test @constinferred(findtruncated_sorted(sort(values; by=abs, rev=true), strategy)) == 1:4
75+
end
7276
end

0 commit comments

Comments
 (0)