Skip to content

Commit e679407

Browse files
committed
some more test updates
1 parent d111cd9 commit e679407

3 files changed

Lines changed: 13 additions & 6 deletions

File tree

test/decompositions/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ if !is_buildkite
4343
for T in GenericFloats, m in (0, 54), n in (0, 37, m, 63)
4444
TestSuite.seed_rng!(123)
4545
TestSuite.test_svd(T, (m, n))
46-
TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),))
46+
TestSuite.test_svd_algs(T, (m, n), (QRIteration(; driver = MatrixAlgebraKit.GLA()),))
4747
end
4848

4949
# Diagonal:

test/testsuite/TestSuite.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ function instantiate_rank_deficient_matrix(::Type{T}, sz; trunc = truncrank(div(
101101
return Diagonal(diag(mul!(A, V, C)))
102102
end
103103

104-
function instantiate_almost_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2)), atol::Real = 0, rtol::Real = precision(T))
104+
function instantiate_almost_rank_deficient_matrix(
105+
T, sz;
106+
trunc = truncrank(div(min(sz...), 2)), atol::Real = 0, rtol::Real = precision(T)
107+
)
105108
A = instantiate_rank_deficient_matrix(T, sz; trunc)
106109
noise = normalize(instantiate_matrix(T, sz))
107110
A .+= max(atol, rtol * norm(A)) * noise

test/testsuite/decompositions/svd.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,19 +371,23 @@ function test_sketched_svd(
371371
return @testset "sketched svd_trunc! algorithm $alg $summary_str" for alg in algs
372372
@assert alg isa SketchedAlgorithm "Invalid sketched algorithm type: $(typeof(alg))"
373373

374-
A = instantiate_rank_deficient_matrix(T, sz; alg.trunc)
375-
A += max(atol, rtol * norm(A)) * instantiate_matrix(T, sz)
374+
A = instantiate_almost_rank_deficient_matrix(T, sz; alg.trunc, atol, rtol)
376375
Ac = deepcopy(A)
377376

378377
alg2 = MatrixAlgebraKit.TruncatedAlgorithm(alg)
379378

380379
U, S, Vᴴ, ϵ = @testinferred svd_trunc(A, alg)
381380
@test Ac == A
381+
ϵ′ = norm(A - U * S * Vᴴ)
382+
@test ϵ′ ϵ atol = sqrt(rtol) * max(one(ϵ′), ϵ′) # comparison to 0 is hard, very imprecise calculation
382383

383-
U′, S′, Vᴴ′, ϵ′ = svd_trunc(A, alg2)
384+
U′, S′, Vᴴ′ = svd_trunc_no_error(A, alg2)
385+
386+
# Need gauge fixing for comparison
387+
U, Vᴴ = MatrixAlgebraKit.gaugefix!(svd_trunc!, U, Vᴴ)
388+
U′, Vᴴ′ = MatrixAlgebraKit.gaugefix!(svd_trunc!, U′, Vᴴ′)
384389
@test U U′ atol = atol rtol = rtol
385390
@test S S′ atol = atol rtol = rtol
386391
@test Vᴴ Vᴴ′ atol = atol rtol = rtol
387-
@test ϵ ϵ′ atol = atol rtol = rtol
388392
end
389393
end

0 commit comments

Comments
 (0)