@@ -2,21 +2,21 @@ using TestExtras
22using GenericLinearAlgebra
33using LinearAlgebra: opnorm
44
5- function test_svd (T:: Type , sz; test_trunc = true , kwargs... )
5+ function test_svd (T:: Type , sz; kwargs... )
66 summary_str = testargs_summary (T, sz)
77 return @testset " svd $summary_str " begin
88 test_svd_compact (T, sz; kwargs... )
99 test_svd_full (T, sz; kwargs... )
10- test_trunc && test_svd_trunc (T, sz; kwargs... )
10+ test_svd_trunc (T, sz; kwargs... )
1111 end
1212end
1313
14- function test_svd_algs (T:: Type , sz, algs; test_trunc = true , kwargs... )
14+ function test_svd_algs (T:: Type , sz, algs; kwargs... )
1515 summary_str = testargs_summary (T, sz)
1616 return @testset " svd algorithms $summary_str " begin
1717 test_svd_compact_algs (T, sz, algs; kwargs... )
1818 test_svd_full_algs (T, sz, algs; kwargs... )
19- test_trunc && test_svd_trunc_algs (T, sz, algs; kwargs... )
19+ test_svd_trunc_algs (T, sz, algs; kwargs... )
2020 end
2121end
2222
@@ -160,14 +160,15 @@ function test_svd_trunc(
160160 Ac = deepcopy (A)
161161 m, n = size (A)
162162 minmn = min (m, n)
163- S₀ = svd_vals (A)
163+ S₀ = collect ( svd_vals (A) )
164164 r = minmn - 2
165165
166166 if m > 0 && n > 0
167167 U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc (A; trunc = truncrank (r))
168168 @test length (diagview (S1)) == r
169- @test diagview (S1) ≈ S₀[1 : r]
170- @test opnorm (A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1 ]
169+ @test collect (diagview (S1)) ≈ S₀[1 : r]
170+ AUSV_vals = svd_vals (A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
171+ @test mapreduce (sv -> opnorm (sv, 2 ), max, AUSV_vals) ≈ S₀[r + 1 ]
171172 # Test truncation error
172173 @test ϵ1 ≈ norm (view (S₀, (r + 1 ): minmn)) atol = atol
173174
@@ -241,14 +242,15 @@ function test_svd_trunc_algs(
241242 Ac = deepcopy (A)
242243 m, n = size (A)
243244 minmn = min (m, n)
244- S₀ = svd_vals (A)
245+ S₀ = collect ( svd_vals (A) )
245246 r = minmn - 2
246247
247248 if m > 0 && n > 0
248249 U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc (A; trunc = truncrank (r), alg)
249250 @test length (diagview (S1)) == r
250- @test diagview (S1) ≈ S₀[1 : r]
251- @test opnorm (A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1 ]
251+ @test collect (diagview (S1)) ≈ S₀[1 : r]
252+ AUSV_vals = svd_vals (A - U1 * S1 * V1ᴴ) # bypass broken svdvals on AMDGPU
253+ @test mapreduce (sv -> opnorm (sv, 2 ), max, AUSV_vals) ≈ S₀[r + 1 ]
252254 # Test truncation error
253255 @test ϵ1 ≈ norm (view (S₀, (r + 1 ): minmn)) atol = atol
254256
@@ -285,11 +287,11 @@ function test_svd_trunc_algs(
285287 )
286288 U1, S1, V1ᴴ, ϵ1 = svd_trunc (A; trunc = trunc_fun (0.2 , 1 ), alg)
287289 @test length (diagview (S1)) == 1
288- @test diagview (S1) ≈ diagview (S)[1 : 1 ]
290+ @test collect ( diagview (S1)) ≈ collect ( diagview (S)[1 : 1 ])
289291
290292 U2, S2, V2ᴴ, ϵ2 = svd_trunc (A; trunc = trunc_fun (0.2 , 3 ), alg)
291293 @test length (diagview (S2)) == 2
292- @test diagview (S2) ≈ diagview (S)[1 : 2 ]
294+ @test collect ( diagview (S2)) ≈ collect ( diagview (S)[1 : 2 ])
293295 end
294296 end
295297 @testset " specify truncation algorithm" begin
@@ -303,7 +305,7 @@ function test_svd_trunc_algs(
303305 A = U * S * Vᴴ
304306 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = 0.2 ))
305307 U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc (A; alg = truncalg)
306- @test diagview (S2) ≈ diagview (S)[1 : 2 ]
308+ @test collect ( diagview (S2)) ≈ collect ( diagview (S)[1 : 2 ])
307309 @test ϵ2 ≈ norm (diagview (S)[3 : 4 ]) atol = atol
308310 @test_throws ArgumentError svd_trunc (A; alg = truncalg, trunc = (; maxrank = 2 ))
309311 @test_throws ArgumentError svd_trunc_no_error (A; alg = truncalg, trunc = (; maxrank = 2 ))
0 commit comments