Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <jutho.haegeman@ugent.be> and contributors"]
version = "0.3.2"
version = "0.4.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/dev_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit
MatrixAlgebraKit.default_algorithm
MatrixAlgebraKit.select_algorithm
MatrixAlgebraKit.findtruncated
MatrixAlgebraKit.findtruncated_sorted
MatrixAlgebraKit.findtruncated_svd
```
6 changes: 3 additions & 3 deletions docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Currently, truncations are supported through the following different methods:
notrunc
truncrank
trunctol
truncabove
truncfilter
truncerror
```

Expand All @@ -20,6 +20,6 @@ For example, truncating to a maximal dimension `10`, and discarding all values b

```julia
maxdim = 10
tol = 1e-6
combined_trunc = truncrank(maxdim) & trunctol(tol)
atol = 1e-6
combined_trunc = truncrank(maxdim) & trunctol(; atol)
```
7 changes: 6 additions & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
Expand Down Expand Up @@ -40,4 +40,9 @@ _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwar
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...)
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...)
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)
return MatrixAlgebraKit.findtruncated(values, strategy)
end

end
6 changes: 5 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
Expand Down Expand Up @@ -44,4 +44,8 @@ _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::S
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue)
return MatrixAlgebraKit.findtruncated(values, strategy)
end

end
25 changes: 13 additions & 12 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,21 @@ export left_polar!, right_polar!
export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!

export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi,
LQViaTransposedQR,
CUSOLVER_Simple,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
DiagonalAlgorithm
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered, truncerror
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi
export LQViaTransposedQR
export DiagonalAlgorithm
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer
export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi,
ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection

VERSION >= v"1.11.0-DEV.469" &&
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
export notrunc, truncrank, trunctol, truncerror, truncfilter

@static if VERSION >= v"1.11.0-DEV.469"
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_svd,
:select_algorithm))
eval(Expr(:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
:TruncationByError, :TruncationIntersection))
end

include("common/defaults.jl")
include("common/initialization.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,16 @@ based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted (which is the
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_svd`](@ref).
""" findtruncated

@doc """
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
MatrixAlgebraKit.findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)

Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are real and
sorted in descending order, as typically obtained by the SVD. This assumption is not
checked, and this is used in the default implementation of [`svd_trunc!`](@ref).
""" findtruncated_sorted
""" findtruncated_svd

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
Expand Down
4 changes: 2 additions & 2 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ end
# --------------------------------
function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing)
if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol)
return NoTruncation()
return notrunc()
end
atol = @something atol 0
rtol = @something rtol 0
trunc = TruncationKeepBelow(atol, rtol)
trunc = trunctol(; atol, rtol, keep_below=true)
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
end

Expand Down
105 changes: 56 additions & 49 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# ---------
# Generic implementation: `findtruncated` followed by indexing
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated_sorted(diagview(S), strategy)
ind = findtruncated_svd(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
Expand All @@ -28,82 +28,89 @@ end

# findtruncated
# -------------
# Generic fallback
function findtruncated_svd(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
end

# specific implementations for finding truncated values
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()

function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
strategy.by === abs || return findtruncated(values, strategy)
howmany = min(strategy.howmany, length(values))
return 1:howmany
return strategy.rev ? (1:howmany) : ((length(values) - howmany + 1):length(values))
end
Comment thread
lkdvos marked this conversation as resolved.

# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered)
ind = findall(strategy.filter, values)
return ind
function findtruncated(values::AbstractVector, strategy::TruncationByFilter)
return strategy.filter.(values)::AbstractVector{Bool}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really worth returning a bool vector here, instead of just indices? Is findall(strategy.filter, values) so much worse? Or is that for GPU compatibility?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is the type annotation really useful? Is this a JET thing again?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't have any opinion here. @mtfishman, what do you think? I am fine with either way.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine either way, I think the main considerations were:

  1. It expresses the output better, since it encodes that the output indices don't repeat and don't have a specified ordering (i.e. they preserve the ordering).
  2. We've found it helpful to use bool vectors for slicing block sparse arrays since they make it easier to preserve the block structure, but since some strategies output bool vectors and some don't, we have to convert them all anyway.

end

function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
function findtruncated(values::AbstractVector, strategy::TruncationByValue)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≤(atol) ∘ strategy.by, values)
filter = (strategy.keep_below ? ≤(atol) : ≥(atol)) ∘ strategy.by
return findtruncated(values, truncfilter(filter))
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
function findtruncated_svd(values::AbstractVector, strategy::TruncationByValue)
strategy.by === abs || return findtruncated(values, strategy)
Comment thread
Jutho marked this conversation as resolved.
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
return i:length(values)
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≥(atol) ∘ strategy.by, values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
return 1:i
end

function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components)
return intersect(inds...)
if strategy.keep_below
i = searchsortedfirst(values, atol; by=abs, rev=true)
return i:length(values)
else
i = searchsortedlast(values, atol; by=abs, rev=true)
return 1:i
end
end

function findtruncated(values::AbstractVector, strategy::TruncationError)
function findtruncated(values::AbstractVector, strategy::TruncationByError)
I = sortperm(values; by=abs, rev=true)
I′ = _truncerr_impl(values, I, strategy)
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
return I[I′]
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationError)
function findtruncated_svd(values::AbstractVector, strategy::TruncationByError)
I = eachindex(values)
I′ = _truncerr_impl(values, I, strategy)
I′ = _truncerr_impl(values, I; strategy.atol, strategy.rtol, strategy.p)
return I[I′]
end
function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError)
Nᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values)
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
function _truncerr_impl(values::AbstractVector, I; atol::Real=0, rtol::Real=0, p::Real=2)
by = Base.Fix2(^, p) ∘ abs
Nᵖ = sum(by, values)
ϵᵖ = max(atol^p, rtol^p * Nᵖ)

# fast path to avoid checking all values
ϵᵖ ≥ Nᵖ && return Base.OneTo(0)

truncerrᵖ = zero(real(eltype(values)))
rank = length(values)
for i in reverse(I)
truncerrᵖ += abs(values[i])^strategy.p
if truncerrᵖ ≥ ϵᵖ
break
else
rank -= 1
end
truncerrᵖ += by(values[i])
truncerrᵖ ≥ ϵᵖ && break
rank -= 1
end

return Base.OneTo(rank)
end

# Generic fallback
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
return mapreduce(Base.Fix1(findtruncated, values), _ind_intersect, strategy.components;
init=trues(length(values)))
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection)
return mapreduce(Base.Fix1(findtruncated_svd, values), _ind_intersect,
strategy.components; init=trues(length(values)))
end

# when one of the ind selections is a bitvector, have to handle differently
function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector)
result = falses(length(A))
result[B] .= @view A[B]
return result
end
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
_ind_intersect(A, B) = intersect(A, B)
Loading
Loading