Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <jutho.haegeman@ugent.be> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -36,4 +36,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
"ChainRulesTestUtils", "StableRNGs", "Zygote"]
109 changes: 60 additions & 49 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))

Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
This can be obtained both for values `A` or types `A`.

If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.

Expand All @@ -73,62 +74,62 @@
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
"""
function select_algorithm end
""" select_algorithm

function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return _select_algorithm(f, A, alg; kwargs...)
return select_algorithm(f, typeof(A), alg; kwargs...)
end
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
elseif alg isa Type
return alg(; kwargs...)
elseif alg isa NamedTuple
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
elseif alg isa AbstractAlgorithm
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return alg
end

function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
return default_algorithm(f, A; kwargs...)
end
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
return Algorithm{alg}(; kwargs...)
end
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
return Alg(; kwargs...)
end
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
end
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
return alg
end
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
return throw(ArgumentError("Unknown alg $alg"))
throw(ArgumentError("Unknown alg $alg"))

Check warning on line 99 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L99

Added line #L99 was not covered by tests
end


@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}

Select the default algorithm for a given factorization function `f` and input `A`.
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
explicitly.
"""
function default_algorithm end
New types should prefer to register their default algorithms in the type domain.
""" default_algorithm
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
# avoid infinite recursion:
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
throw(MethodError(default_algorithm, (f, T)))

Check warning on line 115 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end

@doc """
copy_input(f, A)

Preprocess the input `A` for a given function, such that it may be handled correctly later.
This may include a copy whenever the implementation would destroy the original matrix,
or a change of element type to something that is supported.
"""
function copy_input end
""" copy_input

@doc """
initialize_output(f, A, alg)

Whenever possible, allocate the destination for applying a given algorithm in-place.
If this is not possible, for example when the output size is not known a priori or immutable,
this function may return `nothing`.
"""
function initialize_output end
""" initialize_output

# Utility macros
# --------------
Expand Down Expand Up @@ -176,25 +177,35 @@
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
f! = Symbol(f, :!)

return esc(quote
# out of place to inplace
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
end

# copy documentation to both functions
Core.@__doc__ $f, $f!
end)
ex = quote
# out of place to inplace
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
end

# define fallbacks for algorithm selection
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
kwargs...) where {Alg,A}
return select_algorithm($f!, A, alg; kwargs...)
end
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_algorithm($f!, A; kwargs...)
end

# copy documentation to both functions
Core.@__doc__ $f, $f!
end
return esc(ex)
end

"""
Expand Down
29 changes: 10 additions & 19 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).

# Algorithm selection
# -------------------
for f in (:eig_full, :eig_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
# Default to LAPACK for `YALAPACK.BlasMat`
function default_algorithm(::typeof(eig_full!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_Expert(; kwargs...)
end

function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
return select_algorithm(eig_trunc!, A, alg; kwargs...)
function default_algorithm(::typeof(eig_vals!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_Expert(; kwargs...)
end
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)

function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end

# Default to LAPACK
function default_eig_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_Expert(; kwargs...)
end
29 changes: 10 additions & 19 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,18 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)

# Algorithm selection
# -------------------
for f in (:eigh_full, :eigh_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
end
# Default to LAPACK for `YALAPACK.BlasMat`
function default_algorithm(::typeof(eigh_full!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end

function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
function default_algorithm(::typeof(eigh_vals!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)

function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
end

# Default to LAPACK
function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end
16 changes: 4 additions & 12 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).

# Algorithm selection
# -------------------
for f in (:lq_full, :lq_compact, :lq_null)
f! = Symbol(f, :!)
for f in (:lq_full!, :lq_compact!, :lq_null!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_lq_algorithm(A; kwargs...)
function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_HouseholderLQ(; kwargs...)
end
end
end

# Default to LAPACK
function default_lq_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_HouseholderLQ(; kwargs...)
end
20 changes: 6 additions & 14 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,11 @@ end

# Algorithm selection
# -------------------
for f in (:left_polar, :right_polar)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof(left_polar!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_polar_algorithm(A; kwargs...)
end
end
end

# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`
function default_polar_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return PolarViaSVD(default_svd_algorithm(A; kwargs...))
function default_algorithm(::typeof(right_polar!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...))
end
16 changes: 4 additions & 12 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).

# Algorithm selection
# -------------------
for f in (:qr_full, :qr_compact, :qr_null)
f! = Symbol(f, :!)
for f in (:qr_full!, :qr_compact!, :qr_null!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_qr_algorithm(A; kwargs...)
function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_HouseholderQR(; kwargs...)
end
end
end

# Default to LAPACK
function default_qr_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_HouseholderQR(; kwargs...)
end
17 changes: 7 additions & 10 deletions src/interface/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).

# Algorithm selection
# -------------------
for f in (:schur_full, :schur_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
function default_algorithm(::typeof(schur_full!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return default_algorithm(eig_full!, A; kwargs...)
end
function default_algorithm(::typeof(schur_vals!), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return default_algorithm(eig_vals!, A; kwargs...)
end
25 changes: 7 additions & 18 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,16 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an

# Algorithm selection
# -------------------
for f in (:svd_full, :svd_compact, :svd_vals)
f! = Symbol(f, :!)
@eval begin
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_svd_algorithm(A; kwargs...)
end
for f in (:svd_full!, :svd_compact!, :svd_vals!)
# Default to LAPACK SDD for `YALAPACK.BlasMat`
@eval function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return LAPACK_DivideAndConquer(; kwargs...)
end
end

function select_algorithm(::typeof(svd_trunc), A, alg; kwargs...)
return select_algorithm(svd_trunc!, A, alg; kwargs...)
end
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
end

# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`
function default_svd_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
return LAPACK_DivideAndConquer(; kwargs...)
end
3 changes: 3 additions & 0 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror

# type alias for matrices that are definitely supported by YALAPACK
const BlasMat{T<:BlasFloat} = StridedMatrix{T}

# LU factorisation
for (getrf, getrs, elty) in ((:dgetrf_, :dgetrs_, :Float64),
(:sgetrf_, :sgetrs_, :Float32),
Expand Down
Loading
Loading