11using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
3- using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank
3+ using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
44using LinearAlgebra: LinearAlgebra
55using Random: Random
66using Test: @inferred , @testset , @test
8888# ----------
8989
9090@testset " svd_trunc ($m , $n ) BlockSparseMatri{$T }" for ((m, n), T) in test_params
91- (m, n), T = first (test_params)
9291 a = BlockSparseArray {T} (undef, m, n)
9392
9493 # test blockdiagonal
9998
10099 minmn = min (size (a)... )
101100 r = max (1 , minmn - 2 )
101+ trunc = truncrank (r)
102102
103- U1, S1, V1ᴴ = svd_trunc (a; trunc= truncrank (r) )
104- U2, S2, V2ᴴ = svd_trunc (Matrix (a); trunc= truncrank (r) )
103+ U1, S1, V1ᴴ = svd_trunc (a; trunc)
104+ U2, S2, V2ᴴ = svd_trunc (Matrix (a); trunc)
105105 @test size (U1) == size (U2)
106106 @test size (S1) == size (S2)
107107 @test size (V1ᴴ) == size (V2ᴴ)
@@ -110,11 +110,11 @@ end
110110 @test (U1' * U1 ≈ LinearAlgebra. I)
111111 @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra. I)
112112
113- # test permuted blockdiagonal
114- perm = Random . randperm ( length (m) )
115- b = a[ Block .(perm), Block .( 1 : length (n))]
116- U1, S1, V1ᴴ = svd_trunc (b ; trunc= truncrank (r) )
117- U2, S2, V2ᴴ = svd_trunc (Matrix (b ); trunc= truncrank (r) )
113+ atol = minimum (LinearAlgebra . diag (S1)) + 10 * eps ( real (T))
114+ trunc = trunctol (atol )
115+
116+ U1, S1, V1ᴴ = svd_trunc (a ; trunc)
117+ U2, S2, V2ᴴ = svd_trunc (Matrix (a ); trunc)
118118 @test size (U1) == size (U2)
119119 @test size (S1) == size (S2)
120120 @test size (V1ᴴ) == size (V2ᴴ)
@@ -123,17 +123,34 @@ end
123123 @test (U1' * U1 ≈ LinearAlgebra. I)
124124 @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra. I)
125125
126+ # test permuted blockdiagonal
127+ perm = Random. randperm (length (m))
128+ b = a[Block .(perm), Block .(1 : length (n))]
129+ for trunc in (truncrank (r), trunctol (atol))
130+ U1, S1, V1ᴴ = svd_trunc (b; trunc)
131+ U2, S2, V2ᴴ = svd_trunc (Matrix (b); trunc)
132+ @test size (U1) == size (U2)
133+ @test size (S1) == size (S2)
134+ @test size (V1ᴴ) == size (V2ᴴ)
135+ @test Matrix (U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ
136+
137+ @test (U1' * U1 ≈ LinearAlgebra. I)
138+ @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra. I)
139+ end
140+
126141 # test permuted blockdiagonal with missing row/col
127142 I_removed = rand (eachblockstoredindex (b))
128143 c = copy (b)
129144 delete! (blocks (c). storage, CartesianIndex (Int .(Tuple (I_removed))))
130- U1, S1, V1ᴴ = svd_trunc (c; trunc= truncrank (r))
131- U2, S2, V2ᴴ = svd_trunc (Matrix (c); trunc= truncrank (r))
132- @test size (U1) == size (U2)
133- @test size (S1) == size (S2)
134- @test size (V1ᴴ) == size (V2ᴴ)
135- @test Matrix (U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ
136-
137- @test (U1' * U1 ≈ LinearAlgebra. I)
138- @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra. I)
145+ for trunc in (truncrank (r), trunctol (atol))
146+ U1, S1, V1ᴴ = svd_trunc (c; trunc)
147+ U2, S2, V2ᴴ = svd_trunc (Matrix (c); trunc)
148+ @test size (U1) == size (U2)
149+ @test size (S1) == size (S2)
150+ @test size (V1ᴴ) == size (V2ᴴ)
151+ @test Matrix (U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ
152+
153+ @test (U1' * U1 ≈ LinearAlgebra. I)
154+ @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra. I)
155+ end
139156end
0 commit comments