Skip to content

Commit 0dfad35

Browse files
committed
Simplify implementation of findtruncated
1 parent 55c0a07 commit 0dfad35

2 files changed

Lines changed: 20 additions & 38 deletions

File tree

src/implementations/truncation.jl

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ end
4848
# since these are implicitly discarded by selecting compact/full
4949

5050
"""
51-
TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool)
51+
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)
5252
53-
Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true.
53+
Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true.
5454
"""
5555
struct TruncationKeepSorted{F} <: TruncationStrategy
5656
howmany::Int
57-
sortby::F
57+
by::F
5858
rev::Bool
5959
end
6060

@@ -137,7 +137,7 @@ Generic interface for post-truncating a decomposition, specified in `out`.
137137
""" truncate!
138138
# TODO: should we return a view?
139139
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
140-
ind = findtruncated(diagview(S), strategy)
140+
ind = findtruncated_sorted(diagview(S), strategy)
141141
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
142142
end
143143
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
@@ -166,25 +166,16 @@ end
166166
# specific implementations for finding truncated values
167167
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
168168

169+
# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
170+
# can be solved by going to simply sorting the resulting `ind`
169171
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
170-
if issorted(values; by=strategy.sortby, rev=strategy.rev)
171-
return convert(Vector{Int}, findtruncated_sorted(values, strategy))
172-
else
173-
return findtruncated_unsorted(values, strategy)
174-
end
172+
howmany = min(strategy.howmany, length(values))
173+
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
175174
end
176175
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
177176
howmany = min(strategy.howmany, length(values))
178177
return 1:howmany
179178
end
180-
# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
181-
# can be solved by going to simply sorting the resulting `ind`
182-
function findtruncated_unsorted(values::AbstractVector, strategy::TruncationKeepSorted)
183-
sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev)
184-
howmany = min(strategy.howmany, length(sorted))
185-
ind = sorted[1:howmany]
186-
return ind # TODO: consider sort!(ind)
187-
end
188179

189180
# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
190181
function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
@@ -193,44 +184,35 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
193184
end
194185

195186
function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
196-
if issorted(values; by=abs, rev=true)
197-
return convert(Vector{Int}, findtruncated_sorted(values, strategy))
198-
else
199-
return findtruncated_unsorted(values, strategy)
200-
end
187+
atol = max(strategy.atol, strategy.rtol * maximum(values))
188+
return findall((atol), values)
201189
end
202190
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
203191
atol = max(strategy.atol, strategy.rtol * first(values))
204-
i = @something findfirst((atol), values) length(values) + 1
192+
i = searchsortedfirst(values, atol; by=abs, rev=true)
205193
return i:length(values)
206194
end
207-
function findtruncated_unsorted(values::AbstractVector, strategy::TruncationKeepBelow)
208-
atol = max(strategy.atol, strategy.rtol * maximum(values))
209-
return findall((atol), values)
210-
end
211195

212196
function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
213-
if issorted(values; by=abs, rev=true)
214-
return convert(Vector{Int}, findtruncated_sorted(values, strategy))
215-
else
216-
return findtruncated_unsorted(values, strategy)
217-
end
197+
atol = max(strategy.atol, strategy.rtol * maximum(values))
198+
return findall((atol), values)
218199
end
219200
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
220201
atol = max(strategy.atol, strategy.rtol * first(values))
221-
i = @something findlast((atol), values) 0
202+
i = searchsortedlast(values, atol; by=abs, rev=true)
222203
return 1:i
223204
end
224-
function findtruncated_unsorted(values::AbstractVector, strategy::TruncationKeepAbove)
225-
atol = max(strategy.atol, strategy.rtol * maximum(values))
226-
return findall((atol), values)
227-
end
228205

229206
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
230207
inds = map(Base.Fix1(findtruncated, values), strategy.components)
231208
return intersect(inds...)
232209
end
233210

211+
# Generic fallback.
212+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
213+
return findtruncated(values, strategy)
214+
end
215+
234216
"""
235217
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
236218

test/truncate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
1818
@test trunc isa TruncationKeepSorted
1919
@test trunc == truncrank(10)
2020
@test trunc.howmany == 10
21-
@test trunc.sortby == abs
21+
@test trunc.by == abs
2222
@test trunc.rev == true
2323

2424
trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)

0 commit comments

Comments
 (0)