Skip to content

Commit 1e86aea

Browse files
mtfishmanJutholkdvos
authored
Support unsorted spectra in TruncationKeepAbove/Below (#26)
Co-authored-by: Jutho <Jutho@users.noreply.github.com> Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 253dc63 commit 1e86aea

4 files changed

Lines changed: 72 additions & 18 deletions

File tree

docs/src/dev_interface.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit
1010
```@docs; canonical=false
1111
MatrixAlgebraKit.default_algorithm
1212
MatrixAlgebraKit.select_algorithm
13+
MatrixAlgebraKit.findtruncated
14+
MatrixAlgebraKit.findtruncated_sorted
1315
```

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3131
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3232

3333
VERSION >= v"1.11.0-DEV.469" &&
34-
eval(Expr(:public, :default_algorithm, :select_algorithm))
34+
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
35+
:select_algorithm))
3536

3637
include("common/defaults.jl")
3738
include("common/initialization.jl")

src/implementations/truncation.jl

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ end
4848
# since these are implicitly discarded by selecting compact/full
4949

5050
"""
51-
TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool)
51+
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)
5252
53-
Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true.
53+
Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true).
5454
"""
5555
struct TruncationKeepSorted{F} <: TruncationStrategy
5656
howmany::Int
57-
sortby::F
57+
by::F
5858
rev::Bool
5959
end
6060

@@ -70,14 +70,20 @@ end
7070
struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
7171
atol::T
7272
rtol::T
73+
p::Int
74+
end
75+
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
76+
return TruncationKeepAbove(promote(atol, rtol)..., p)
7377
end
74-
TruncationKeepAbove(atol::Real, rtol::Real) = TruncationKeepAbove(promote(atol, rtol)...)
7578

7679
struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
7780
atol::T
7881
rtol::T
82+
p::Int
83+
end
84+
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
85+
return TruncationKeepBelow(promote(atol, rtol)..., p)
7986
end
80-
TruncationKeepBelow(atol::Real, rtol::Real) = TruncationKeepBelow(promote(atol, rtol)...)
8187

8288
# TODO: better names for these functions of the above types
8389
"""
@@ -137,7 +143,7 @@ Generic interface for post-truncating a decomposition, specified in `out`.
137143
""" truncate!
138144
# TODO: should we return a view?
139145
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
140-
ind = findtruncated(diagview(S), strategy)
146+
ind = findtruncated_sorted(diagview(S), strategy)
141147
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
142148
end
143149
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
@@ -164,15 +170,38 @@ end
164170
# findtruncated
165171
# -------------
166172
# specific implementations for finding truncated values
173+
@doc """
174+
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)
175+
176+
Generic interface for finding truncated values of the spectrum of a decomposition
177+
based on the `strategy`. The output should be a collection of indices specifying
178+
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
179+
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
180+
values are sorted. For a version that assumes the values are reverse sorted by
181+
absolute value (which is the standard case for SVD) see
182+
[`MatrixAlgebraKit.findtruncated_sorted`](@ref).
183+
""" findtruncated
184+
185+
@doc """
186+
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
187+
188+
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
189+
absolute value. However, note that this assumption is not checked, so passing values that are not sorted
190+
in that way can silently give unexpected results. This is used in the default implementation of
191+
[`svd_trunc!`](@ref).
192+
""" findtruncated_sorted
193+
167194
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
168195

169196
# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
170197
# can be solved by going to simply sorting the resulting `ind`
171198
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
172-
sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev)
173-
howmany = min(strategy.howmany, length(sorted))
174-
ind = sorted[1:howmany]
175-
return ind # TODO: consider sort!(ind)
199+
howmany = min(strategy.howmany, length(values))
200+
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
201+
end
202+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
203+
howmany = min(strategy.howmany, length(values))
204+
return 1:howmany
176205
end
177206

178207
# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
@@ -182,13 +211,22 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
182211
end
183212

184213
function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
185-
atol = max(strategy.atol, strategy.rtol * first(values))
186-
i = @something findfirst((atol), values) length(values) + 1
214+
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
215+
return findall((atol), values)
216+
end
217+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
218+
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
219+
i = searchsortedfirst(values, atol; by=abs, rev=true)
187220
return i:length(values)
188221
end
222+
189223
function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
190-
atol = max(strategy.atol, strategy.rtol * first(values))
191-
i = @something findlast((atol), values) 0
224+
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
225+
return findall((atol), values)
226+
end
227+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
228+
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
229+
i = searchsortedlast(values, atol; by=abs, rev=true)
192230
return 1:i
193231
end
194232

@@ -197,6 +235,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
197235
return intersect(inds...)
198236
end
199237

238+
# Generic fallback.
239+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
240+
return findtruncated(values, strategy)
241+
end
242+
200243
"""
201244
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
202245

test/truncate.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
5-
TruncationStrategy, findtruncated
5+
TruncationKeepBelow, TruncationStrategy, findtruncated
66

77
@testset "truncate" begin
88
trunc = @constinferred TruncationStrategy()
@@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
1818
@test trunc isa TruncationKeepSorted
1919
@test trunc == truncrank(10)
2020
@test trunc.howmany == 10
21-
@test trunc.sortby == abs
21+
@test trunc.by == abs
2222
@test trunc.rev == true
2323

2424
trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
@@ -28,7 +28,15 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
2828
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)
2929

3030
values = [1, 0.9, 0.5, 0.3, 0.01]
31-
@test @constinferred(findtruncated(values, truncrank(2))) == [1, 2]
31+
@test @constinferred(findtruncated(values, truncrank(2))) == 1:2
3232
@test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4]
3333
@test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4]
34+
35+
values = [1, 0.9, 0.5, 0.3, 0.01]
36+
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3
37+
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5
38+
39+
values = [0.01, 1, 0.9, 0.3, 0.5]
40+
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5]
41+
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4]
3442
end

0 commit comments

Comments
 (0)