Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@ CollapsedDocStrings = true
Currently, truncations are supported through the following different methods:

```@docs; canonical=false
notrunc
truncrank
trunctol
truncabove
truncerror
```

It is additionally possible to combine truncation strategies by making use of the `&` operator.
For example, truncating to a maximal dimension `10`, and discarding all values below `1e-6` would be achieved by:

```julia
maxdim = 10
tol = 1e-6
combined_trunc = truncrank(maxdim) & trunctol(tol)
```
3 changes: 2 additions & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
DiagonalAlgorithm
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered, truncerror

VERSION >= v"1.11.0-DEV.469" &&
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
Expand All @@ -55,6 +55,7 @@ include("common/gauge.jl")
include("yalapack.jl")
include("algorithms.jl")
include("interface/decompositions.jl")
include("interface/truncation.jl")
include("interface/qr.jl")
include("interface/lq.jl")
include("interface/svd.jl")
Expand Down
68 changes: 68 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,74 @@ If this is not possible, for example when the output size is not known a priori
this function may return `nothing`.
""" initialize_output

# Truncation strategy
# -------------------
"""
abstract type TruncationStrategy end

Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation.

See also [`truncate!`](@ref)
"""
abstract type TruncationStrategy end

@doc """
MatrixAlgebraKit.select_truncation(trunc)

Construct a [`TruncationStrategy`](@ref) from the given `NamedTuple` of keywords or input strategy.
""" select_truncation

function select_truncation(trunc)
if isnothing(trunc)
return NoTruncation()
elseif trunc isa NamedTuple
return TruncationStrategy(; trunc...)
elseif trunc isa TruncationStrategy
return trunc
else
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
end
end

@doc """
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)

Generic interface for finding truncated values of the spectrum of a decomposition
based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted (which is the
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
""" findtruncated

@doc """
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)

Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
They are assumed to be sorted in a way that is consistent with the truncation strategy,
which generally means they are sorted by absolute value but some truncation strategies allow
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
Comment thread
lkdvos marked this conversation as resolved.
Outdated
in the correct way can silently give unexpected results. This is used in the default implementation of
[`svd_trunc!`](@ref).
""" findtruncated_sorted

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Generic wrapper type for algorithms that consist of first using `alg`, followed by a
truncation through `trunc`.
"""
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
alg::A
trunc::T
end

@doc """
truncate!(f, out, strategy::TruncationStrategy)

Generic interface for post-truncating a decomposition, specified in `out`.
""" truncate!

# Utility macros
# --------------

Expand Down
216 changes: 32 additions & 184 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
@@ -1,155 +1,6 @@
"""
abstract type TruncationStrategy end

Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation.

See also [`truncate!`](@ref)
"""
abstract type TruncationStrategy end

function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
elseif isnothing(maxrank)
atol = @something atol 0
rtol = @something rtol 0
return TruncationKeepAbove(atol, rtol)
else
if isnothing(atol) && isnothing(rtol)
return truncrank(maxrank)
else
atol = @something atol 0
rtol = @something rtol 0
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
end
end
end

"""
NoTruncation()

Trivial truncation strategy that keeps all values, mostly for testing purposes.
"""
struct NoTruncation <: TruncationStrategy end

function select_truncation(trunc)
if isnothing(trunc)
return NoTruncation()
elseif trunc isa NamedTuple
return TruncationStrategy(; trunc...)
elseif trunc isa TruncationStrategy
return trunc
else
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
end
end

# TODO: how do we deal with sorting/filters that treat zeros differently
# since these are implicitly discarded by selecting compact/full

"""
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)

Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true).
"""
struct TruncationKeepSorted{F} <: TruncationStrategy
howmany::Int
by::F
rev::Bool
end

"""
TruncationKeepFiltered(filter::Function)

Truncation strategy to keep the values for which `filter` returns true.
"""
struct TruncationKeepFiltered{F} <: TruncationStrategy
filter::F
end

struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy
atol::T
rtol::T
p::Int
by::F
end
function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepAbove(atol, rtol, p, by)
end
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepAbove(promote(atol, rtol)..., p, by)
end

struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy
atol::T
rtol::T
p::Int
by::F
end
function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepBelow(atol, rtol, p, by)
end
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepBelow(promote(atol, rtol)..., p, by)
end

# TODO: better names for these functions of the above types
"""
truncrank(howmany::Int; by=abs, rev=true)

Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true.
"""
truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev)

"""
trunctol(atol::Real; by=abs)

Truncation strategy to discard the values that are smaller than `atol` according to `by`.
"""
trunctol(atol; by=abs) = TruncationKeepFiltered(≥(atol) ∘ by)

"""
truncabove(atol::Real; by=abs)

Truncation strategy to discard the values that are larger than `atol` according to `by`.
"""
truncabove(atol; by=abs) = TruncationKeepFiltered(≤(atol) ∘ by)

"""
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)

Composition of multiple truncation strategies, keeping values common between them.
"""
struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <:
TruncationStrategy
components::T
end
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
return TruncationIntersection((trunc, truncs...))
end

function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1, trunc2))
end
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection)
return TruncationIntersection((trunc1.components..., trunc2.components...))
end
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1.components..., trunc2))
end
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
return TruncationIntersection((trunc1, trunc2.components...))
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
@doc """
truncate!(f, out, strategy::TruncationStrategy)

Generic interface for post-truncating a decomposition, specified in `out`.
""" truncate!
# TODO: should we return a view?
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated_sorted(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
Expand Down Expand Up @@ -178,32 +29,8 @@ end
# findtruncated
# -------------
# specific implementations for finding truncated values
@doc """
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)

Generic interface for finding truncated values of the spectrum of a decomposition
based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted (which is the
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
""" findtruncated

@doc """
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)

Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
They are assumed to be sorted in a way that is consistent with the truncation strategy,
which generally means they are sorted by absolute value but some truncation strategies allow
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
in the correct way can silently give unexpected results. This is used in the default implementation of
[`svd_trunc!`](@ref).
""" findtruncated_sorted

findtruncated(values::AbstractVector, ::NoTruncation) = Colon()

# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
# can be solved by going to simply sorting the resulting `ind`
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
Expand Down Expand Up @@ -243,19 +70,40 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components)
return intersect(inds...)
end

# Generic fallback.
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
function findtruncated(values::AbstractVector, strategy::TruncationError)
I = sortperm(values; by=abs, rev=true)
I′ = _truncerr_impl(values, I, strategy)
return I[I′]
end
Comment thread
lkdvos marked this conversation as resolved.
function findtruncated_sorted(values::AbstractVector, strategy::TruncationError)
I = eachindex(values)
I′ = _truncerr_impl(values, I, strategy)
return I[I′]
end
function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError)
Nᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values)
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
ϵᵖ ≥ Nᵖ && return Base.OneTo(0)

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
truncerrᵖ = zero(real(eltype(values)))
rank = length(values)
for i in reverse(I)
truncerrᵖ += abs(values[i])^strategy.p
if truncerrᵖ ≥ ϵᵖ
break
else
rank -= 1
end
end
return Base.OneTo(rank)
end

Generic wrapper type for algorithms that consist of first using `alg`, followed by a
truncation through `trunc`.
"""
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
alg::A
trunc::T
# Generic fallback
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
end
Loading
Loading