4848# since these are implicitly discarded by selecting compact/full
4949
5050"""
51- TruncationKeepSorted(howmany::Int, sortby ::Function, rev::Bool)
51+ TruncationKeepSorted(howmany::Int, by ::Function, rev::Bool)
5252
53- Truncation strategy to keep the first `howmany` values when sorted according to `sortby ` or the last `howmany` if `rev` is true.
53+ Truncation strategy to keep the first `howmany` values when sorted according to `by ` or the last `howmany` if `rev` is true.
5454"""
5555struct TruncationKeepSorted{F} <: TruncationStrategy
5656 howmany:: Int
57- sortby :: F
57+ by :: F
5858 rev:: Bool
5959end
6060
@@ -137,7 +137,7 @@ Generic interface for post-truncating a decomposition, specified in `out`.
137137""" truncate!
138138# TODO : should we return a view?
139139function truncate! (:: typeof (svd_trunc!), (U, S, Vᴴ), strategy:: TruncationStrategy )
140- ind = findtruncated (diagview (S), strategy)
140+ ind = findtruncated_sorted (diagview (S), strategy)
141141 return U[:, ind], Diagonal (diagview (S)[ind]), Vᴴ[ind, :]
142142end
143143function truncate! (:: typeof (eig_trunc!), (D, V), strategy:: TruncationStrategy )
@@ -166,25 +166,16 @@ end
166166# specific implementations for finding truncated values
167167findtruncated (values:: AbstractVector , :: NoTruncation ) = Colon ()
168168
169+ # TODO : this may also permute the eigenvalues, decide if we want to allow this or not
170+ # can be solved by going to simply sorting the resulting `ind`
169171function findtruncated (values:: AbstractVector , strategy:: TruncationKeepSorted )
170- if issorted (values; by= strategy. sortby, rev= strategy. rev)
171- return convert (Vector{Int}, findtruncated_sorted (values, strategy))
172- else
173- return findtruncated_unsorted (values, strategy)
174- end
172+ howmany = min (strategy. howmany, length (values))
173+ return partialsortperm (values, 1 : howmany; by= strategy. by, rev= strategy. rev)
175174end
176175function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepSorted )
177176 howmany = min (strategy. howmany, length (values))
178177 return 1 : howmany
179178end
180- # TODO : this may also permute the eigenvalues, decide if we want to allow this or not
181- # can be solved by going to simply sorting the resulting `ind`
182- function findtruncated_unsorted (values:: AbstractVector , strategy:: TruncationKeepSorted )
183- sorted = sortperm (values; by= strategy. sortby, rev= strategy. rev)
184- howmany = min (strategy. howmany, length (sorted))
185- ind = sorted[1 : howmany]
186- return ind # TODO : consider sort!(ind)
187- end
188179
189180# TODO : consider if worth using that values are sorted when filter is `<` or `>`.
190181function findtruncated (values:: AbstractVector , strategy:: TruncationKeepFiltered )
@@ -193,44 +184,35 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
193184end
194185
195186function findtruncated (values:: AbstractVector , strategy:: TruncationKeepBelow )
196- if issorted (values; by= abs, rev= true )
197- return convert (Vector{Int}, findtruncated_sorted (values, strategy))
198- else
199- return findtruncated_unsorted (values, strategy)
200- end
187+ atol = max (strategy. atol, strategy. rtol * maximum (values))
188+ return findall (≤ (atol), values)
201189end
202190function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepBelow )
203191 atol = max (strategy. atol, strategy. rtol * first (values))
204- i = @something findfirst ( ≤ (atol), values) length (values) + 1
192+ i = searchsortedfirst (values, atol; by = abs, rev = true )
205193 return i: length (values)
206194end
207- function findtruncated_unsorted (values:: AbstractVector , strategy:: TruncationKeepBelow )
208- atol = max (strategy. atol, strategy. rtol * maximum (values))
209- return findall (≤ (atol), values)
210- end
211195
212196function findtruncated (values:: AbstractVector , strategy:: TruncationKeepAbove )
213- if issorted (values; by= abs, rev= true )
214- return convert (Vector{Int}, findtruncated_sorted (values, strategy))
215- else
216- return findtruncated_unsorted (values, strategy)
217- end
197+ atol = max (strategy. atol, strategy. rtol * maximum (values))
198+ return findall (≥ (atol), values)
218199end
219200function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepAbove )
220201 atol = max (strategy. atol, strategy. rtol * first (values))
221- i = @something findlast ( ≥ ( atol), values) 0
202+ i = searchsortedlast (values, atol; by = abs, rev = true )
222203 return 1 : i
223204end
224- function findtruncated_unsorted (values:: AbstractVector , strategy:: TruncationKeepAbove )
225- atol = max (strategy. atol, strategy. rtol * maximum (values))
226- return findall (≥ (atol), values)
227- end
228205
229206function findtruncated (values:: AbstractVector , strategy:: TruncationIntersection )
230207 inds = map (Base. Fix1 (findtruncated, values), strategy. components)
231208 return intersect (inds... )
232209end
233210
211+ # Generic fallback.
212+ function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationStrategy )
213+ return findtruncated (values, strategy)
214+ end
215+
234216"""
235217 TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
236218
0 commit comments