4848# since these are implicitly discarded by selecting compact/full
4949
5050"""
51- TruncationKeepSorted(howmany::Int, sortby ::Function, rev::Bool)
51+ TruncationKeepSorted(howmany::Int, by ::Function, rev::Bool)
5252
53- Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true.
53+ Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false ( true) .
5454"""
5555struct TruncationKeepSorted{F} <: TruncationStrategy
5656 howmany:: Int
57- sortby :: F
57+ by :: F
5858 rev:: Bool
5959end
6060
7070struct TruncationKeepAbove{T<: Real } <: TruncationStrategy
7171 atol:: T
7272 rtol:: T
73+ p:: Int
74+ end
75+ function TruncationKeepAbove (atol:: Real , rtol:: Real , p:: Int = 2 )
76+ return TruncationKeepAbove (promote (atol, rtol)... , p)
7377end
74- TruncationKeepAbove (atol:: Real , rtol:: Real ) = TruncationKeepAbove (promote (atol, rtol)... )
7578
7679struct TruncationKeepBelow{T<: Real } <: TruncationStrategy
7780 atol:: T
7881 rtol:: T
82+ p:: Int
83+ end
84+ function TruncationKeepBelow (atol:: Real , rtol:: Real , p:: Int = 2 )
85+ return TruncationKeepBelow (promote (atol, rtol)... , p)
7986end
80- TruncationKeepBelow (atol:: Real , rtol:: Real ) = TruncationKeepBelow (promote (atol, rtol)... )
8187
8288# TODO : better names for these functions of the above types
8389"""
@@ -137,7 +143,7 @@ Generic interface for post-truncating a decomposition, specified in `out`.
137143""" truncate!
138144# TODO : should we return a view?
139145function truncate! (:: typeof (svd_trunc!), (U, S, Vᴴ), strategy:: TruncationStrategy )
140- ind = findtruncated (diagview (S), strategy)
146+ ind = findtruncated_sorted (diagview (S), strategy)
141147 return U[:, ind], Diagonal (diagview (S)[ind]), Vᴴ[ind, :]
142148end
143149function truncate! (:: typeof (eig_trunc!), (D, V), strategy:: TruncationStrategy )
@@ -164,15 +170,38 @@ end
164170# findtruncated
165171# -------------
166172# specific implementations for finding truncated values
173+ @doc """
174+ MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)
175+
176+ Generic interface for finding truncated values of the spectrum of a decomposition
177+ based on the `strategy`. The output should be a collection of indices specifying
178+ which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
179+ implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
180+ values are sorted. For a version that assumes the values are reverse sorted by
181+ absolute value (which is the standard case for SVD) see
182+ [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
183+ """ findtruncated
184+
185+ @doc """
186+ MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
187+
188+ Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
189+ absolute value. However, note that this assumption is not checked, so passing values that are not sorted
190+ in that way can silently give unexpected results. This is used in the default implementation of
191+ [`svd_trunc!`](@ref).
192+ """ findtruncated_sorted
193+
167194findtruncated (values:: AbstractVector , :: NoTruncation ) = Colon ()
168195
169196# TODO : this may also permute the eigenvalues, decide if we want to allow this or not
170197# can be solved by going to simply sorting the resulting `ind`
171198function findtruncated (values:: AbstractVector , strategy:: TruncationKeepSorted )
172- sorted = sortperm (values; by= strategy. sortby, rev= strategy. rev)
173- howmany = min (strategy. howmany, length (sorted))
174- ind = sorted[1 : howmany]
175- return ind # TODO : consider sort!(ind)
199+ howmany = min (strategy. howmany, length (values))
200+ return partialsortperm (values, 1 : howmany; by= strategy. by, rev= strategy. rev)
201+ end
202+ function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepSorted )
203+ howmany = min (strategy. howmany, length (values))
204+ return 1 : howmany
176205end
177206
178207# TODO : consider if worth using that values are sorted when filter is `<` or `>`.
@@ -182,13 +211,22 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
182211end
183212
184213function findtruncated (values:: AbstractVector , strategy:: TruncationKeepBelow )
185- atol = max (strategy. atol, strategy. rtol * first (values))
186- i = @something findfirst (≤ (atol), values) length (values) + 1
214+ atol = max (strategy. atol, strategy. rtol * norm (values, strategy. p))
215+ return findall (≤ (atol), values)
216+ end
217+ function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepBelow )
218+ atol = max (strategy. atol, strategy. rtol * norm (values, strategy. p))
219+ i = searchsortedfirst (values, atol; by= abs, rev= true )
187220 return i: length (values)
188221end
222+
189223function findtruncated (values:: AbstractVector , strategy:: TruncationKeepAbove )
190- atol = max (strategy. atol, strategy. rtol * first (values))
191- i = @something findlast (≥ (atol), values) 0
224+ atol = max (strategy. atol, strategy. rtol * norm (values, strategy. p))
225+ return findall (≥ (atol), values)
226+ end
227+ function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationKeepAbove )
228+ atol = max (strategy. atol, strategy. rtol * norm (values, strategy. p))
229+ i = searchsortedlast (values, atol; by= abs, rev= true )
192230 return 1 : i
193231end
194232
@@ -197,6 +235,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
197235 return intersect (inds... )
198236end
199237
238+ # Generic fallback.
239+ function findtruncated_sorted (values:: AbstractVector , strategy:: TruncationStrategy )
240+ return findtruncated (values, strategy)
241+ end
242+
200243"""
201244 TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
202245
0 commit comments