Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <jutho.haegeman@ugent.be> and contributors"]
version = "0.1.1"
version = "0.1.2"
Comment thread
lkdvos marked this conversation as resolved.

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
39 changes: 36 additions & 3 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
@assert isnothing(rtol) "TODO: rtol"
return trunctol(atol)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)
else
return truncrank(maxrank)
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

Expand Down Expand Up @@ -82,6 +89,27 @@
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)

"""
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
Compose two truncation strategies, keeping values common between the two strategies.
"""
struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <:
TruncationStrategy
components::T
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationComposition((trunc1, trunc2))
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition)
return TruncationComposition((trunc1.components..., trunc2.components...))

Check warning on line 104 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy)
return TruncationComposition((trunc1.components..., trunc2))

Check warning on line 107 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition)
return TruncationComposition((trunc1, trunc2.components...))

Check warning on line 110 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
Expand Down Expand Up @@ -147,6 +175,11 @@
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationComposition)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using SafeTestsets

@safetestset "Truncate" begin
include("truncate.jl")
end
@safetestset "QR / LQ Decomposition" begin
include("qr.jl")
include("lq.jl")
Expand Down
34 changes: 33 additions & 1 deletion test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
using MatrixAlgebraKit: diagview
using MatrixAlgebraKit: TruncationKeepAbove, diagview

@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -115,3 +115,35 @@ end
end
end
end

@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
if LinearAlgebra.LAPACK.version() < v"3.12.0"
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
else
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
LAPACK_Jacobi())
end
m = 4
@testset "algorithm $alg" for alg in algs
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ

for (rtol, maxrank) in ((0.2, 1), (0.2, 3))
for trunc in ((; rtol, maxrank),
truncrank(maxrank) & TruncationKeepAbove(0, rtol))
U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
Comment thread
mtfishman marked this conversation as resolved.
Outdated
@test length(S1.diag) == 1
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))

U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
Comment thread
mtfishman marked this conversation as resolved.
Outdated
@test length(S2.diag) == 2
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
end
end
end
end
29 changes: 29 additions & 0 deletions test/truncate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationComposition, TruncationKeepAbove,
TruncationStrategy

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
@test trunc isa NoTruncation

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3)
@test trunc isa TruncationKeepAbove
@test trunc == TruncationKeepAbove(1e-2, 1e-3)
@test trunc.atol == 1e-2
@test trunc.rtol == 1e-3

trunc = @constinferred TruncationStrategy(; maxrank=10)
@test trunc isa TruncationKeepSorted
@test trunc == truncrank(10)
@test trunc.howmany == 10
@test trunc.sortby == abs
@test trunc.rev == true

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
@test trunc isa TruncationComposition
@test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3)
@test trunc.components[1] == truncrank(10)
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)
end