Skip to content

Commit 86d7f0c

Browse files
committed
Truncation composition
1 parent ec40b48 commit 86d7f0c

3 files changed

Lines changed: 64 additions & 4 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <jutho.haegeman@ugent.be> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/implementations/truncation.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
1111
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
1212
return NoTruncation()
1313
elseif isnothing(maxrank)
14-
@assert isnothing(rtol) "TODO: rtol"
15-
return trunctol(atol)
14+
atol = @something atol 0
15+
rtol = @something rtol 0
16+
return TruncationKeepAbove(atol, rtol)
1617
else
17-
return truncrank(maxrank)
18+
if isnothing(atol) && isnothing(rtol)
19+
return truncrank(maxrank)
20+
else
21+
atol = @something atol 0
22+
rtol = @something rtol 0
23+
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
24+
end
1825
end
1926
end
2027

@@ -82,6 +89,27 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
8289
"""
8390
truncabove(atol) = TruncationKeepFiltered((atol) abs)
8491

92+
"""
93+
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
94+
Compose two truncation strategies, keeping values common between the two strategies.
95+
"""
96+
struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <:
97+
TruncationStrategy
98+
components::T
99+
end
100+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
101+
return TruncationComposition((trunc1, trunc2))
102+
end
103+
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition)
104+
return TruncationComposition((trunc1.components..., trunc2.components...))
105+
end
106+
function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy)
107+
return TruncationComposition((trunc1.components..., trunc2))
108+
end
109+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition)
110+
return TruncationComposition((trunc1, trunc2.components...))
111+
end
112+
85113
# truncate!
86114
# ---------
87115
# Generic implementation: `findtruncated` followed by indexing
@@ -147,6 +175,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
147175
return 1:i
148176
end
149177

178+
function findtruncated(values::AbstractVector, strategy::TruncationComposition)
179+
inds = map(Base.Fix1(findtruncated, values), strategy.components)
180+
return intersect(inds...)
181+
end
182+
150183
"""
151184
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
152185

test/svd.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,30 @@ end
115115
end
116116
end
117117
end
118+
119+
@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
120+
(Float32, Float64, ComplexF32,
121+
ComplexF64)
122+
rng = StableRNG(123)
123+
if LinearAlgebra.LAPACK.version() < v"3.12.0"
124+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
125+
else
126+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
127+
LAPACK_Jacobi())
128+
end
129+
m = 4
130+
@testset "algorithm $alg" for alg in algs
131+
U = qr_compact(randn(rng, T, m, m))[1]
132+
S = Diagonal([0.9, 0.3, 0.1, 0.01])
133+
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
134+
A = U * S * Vᴴ
135+
136+
U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
137+
@test length(S1.diag) == 1
138+
@test S1.diag S.diag[1:1] rtol = sqrt(eps(real(T)))
139+
140+
U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
141+
@test length(S2.diag) == 2
142+
@test S2.diag S.diag[1:2] rtol = sqrt(eps(real(T)))
143+
end
144+
end

0 commit comments

Comments
 (0)