Skip to content

Commit b710816

Browse files
authored
Merge branch 'main' into mf/orthnull_customization
2 parents 4448b07 + c46119e commit b710816

15 files changed

Lines changed: 207 additions & 122 deletions

File tree

docs/src/dev_interface.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
```@meta
2+
CurrentModule = MatrixAlgebraKit
3+
CollapsedDocStrings = true
4+
```
5+
6+
# Developer Interface
7+
8+
MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms.
9+
10+
```@docs; canonical=false
11+
MatrixAlgebraKit.default_algorithm
12+
MatrixAlgebraKit.select_algorithm
13+
```

src/MatrixAlgebraKit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3030
LAPACK_DivideAndConquer, LAPACK_Jacobi
3131
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3232

33+
VERSION >= v"1.11.0-DEV.469" &&
34+
eval(Expr(:public, :default_algorithm, :select_algorithm))
35+
3336
include("common/defaults.jl")
3437
include("common/initialization.jl")
3538
include("common/pullbacks.jl")

src/algorithms.jl

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,64 @@ function _show_alg(io::IO, alg::Algorithm)
5454
end
5555

5656
@doc """
57-
select_algorithm(f, A; kwargs...)
57+
MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm)
58+
MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...)
59+
MatrixAlgebraKit.select_algorithm(f, A, alg::Type; kwargs...)
60+
MatrixAlgebraKit.select_algorithm(f, A; kwargs...)
61+
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))
5862
59-
Given some keyword arguments and an input `A`, decide on an algrithm to use for
60-
implementing the function `f` on inputs of type `A`.
63+
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
64+
65+
If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.
66+
67+
If `alg` is a `Symbol` or a `Type` of algorithm, the return value is obtained
68+
by calling the corresponding algorithm constructor;
69+
keyword arguments in `kwargs` are passed along to this constructor.
70+
71+
If `alg` is not specified (or `nothing`), an algorithm will be selected
72+
automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and
73+
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
74+
Finally, the same behavior is obtained when the keyword arguments are
75+
passed as the third positional argument in the form of a `NamedTuple`.
6176
"""
6277
function select_algorithm end
6378

64-
function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
79+
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80+
return _select_algorithm(f, A, alg; kwargs...)
81+
end
82+
83+
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
84+
return default_algorithm(f, A; kwargs...)
85+
end
86+
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
87+
return Algorithm{alg}(; kwargs...)
88+
end
89+
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
90+
return Alg(; kwargs...)
91+
end
92+
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
93+
isempty(kwargs) ||
94+
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
95+
return default_algorithm(f, A; alg...)
96+
end
97+
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
98+
isempty(kwargs) ||
99+
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
65100
return alg
66101
end
67-
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
68-
return select_algorithm(f, A; alg...)
102+
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
103+
return throw(ArgumentError("Unknown alg $alg"))
69104
end
70105

106+
@doc """
107+
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
108+
109+
Select the default algorithm for a given factorization function `f` and input `A`.
110+
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
111+
explicitly.
112+
"""
113+
function default_algorithm end
114+
71115
@doc """
72116
copy_input(f, A)
73117
@@ -138,9 +182,11 @@ macro functiondef(f)
138182
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
139183

140184
# fill in arguments
141-
$f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...))
142-
function $f!(A, out; kwargs...)
143-
return $f!(A, out, select_algorithm($f!, A; kwargs...))
185+
function $f!(A; alg=nothing, kwargs...)
186+
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
187+
end
188+
function $f!(A, out; alg=nothing, kwargs...)
189+
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
144190
end
145191
function $f!(A, alg::AbstractAlgorithm)
146192
return $f!(A, initialize_output($f!, A, alg), alg)

src/implementations/truncation.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes.
3232
"""
3333
struct NoTruncation <: TruncationStrategy end
3434

35+
function select_truncation(trunc)
36+
if isnothing(trunc)
37+
return NoTruncation()
38+
elseif trunc isa NamedTuple
39+
return TruncationStrategy(; trunc...)
40+
elseif trunc isa TruncationStrategy
41+
return trunc
42+
else
43+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
44+
end
45+
end
46+
3547
# TODO: how do we deal with sorting/filters that treat zeros differently
3648
# since these are implicitly discarded by selecting compact/full
3749

src/interface/eig.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
9090
for f in (:eig_full, :eig_vals)
9191
f! = Symbol(f, :!)
9292
@eval begin
93-
function select_algorithm(::typeof($f), A; kwargs...)
94-
return select_algorithm($f!, A; kwargs...)
93+
function default_algorithm(::typeof($f), A; kwargs...)
94+
return default_algorithm($f!, A; kwargs...)
9595
end
96-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
97-
if alg isa AbstractAlgorithm
98-
return alg
99-
elseif alg isa Symbol
100-
return Algorithm{alg}(; kwargs...)
101-
else
102-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
103-
return default_eig_algorithm(A; kwargs...)
104-
end
96+
function default_algorithm(::typeof($f!), A; kwargs...)
97+
return default_eig_algorithm(A; kwargs...)
10598
end
10699
end
107100
end
108101

109-
function select_algorithm(::typeof(eig_trunc), A; kwargs...)
110-
return select_algorithm(eig_trunc!, A; kwargs...)
102+
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
103+
return select_algorithm(eig_trunc!, A, alg; kwargs...)
111104
end
112-
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
113-
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
114-
alg_trunc = trunc isa TruncationStrategy ? trunc :
115-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
116-
isnothing(trunc) ? NoTruncation() :
117-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
118-
return TruncatedAlgorithm(alg_eig, alg_trunc)
105+
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
106+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
107+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
119108
end
120109

121110
# Default to LAPACK

src/interface/eigh.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
8989
for f in (:eigh_full, :eigh_vals)
9090
f! = Symbol(f, :!)
9191
@eval begin
92-
function select_algorithm(::typeof($f), A; kwargs...)
93-
return select_algorithm($f!, A; kwargs...)
92+
function default_algorithm(::typeof($f), A; kwargs...)
93+
return default_algorithm($f!, A; kwargs...)
9494
end
95-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
96-
if alg isa AbstractAlgorithm
97-
return alg
98-
elseif alg isa Symbol
99-
return Algorithm{alg}(; kwargs...)
100-
else
101-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
102-
return default_eigh_algorithm(A; kwargs...)
103-
end
95+
function default_algorithm(::typeof($f!), A; kwargs...)
96+
return default_eigh_algorithm(A; kwargs...)
10497
end
10598
end
10699
end
107100

108-
function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
109-
return select_algorithm(eigh_trunc!, A; kwargs...)
101+
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
102+
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
110103
end
111-
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
112-
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
113-
alg_trunc = trunc isa TruncationStrategy ? trunc :
114-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
115-
isnothing(trunc) ? NoTruncation() :
116-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
117-
return TruncatedAlgorithm(alg_eigh, alg_trunc)
104+
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
105+
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106+
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
118107
end
119108

120109
# Default to LAPACK

src/interface/lq.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
7171
for f in (:lq_full, :lq_compact, :lq_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_lq_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_lq_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

src/interface/polar.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,11 @@ end
6363
for f in (:left_polar, :right_polar)
6464
f! = Symbol(f, :!)
6565
@eval begin
66-
function select_algorithm(::typeof($f), A; kwargs...)
67-
return select_algorithm($f!, A; kwargs...)
66+
function default_algorithm(::typeof($f), A; kwargs...)
67+
return default_algorithm($f!, A; kwargs...)
6868
end
69-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
70-
if alg isa AbstractAlgorithm
71-
return alg
72-
elseif alg isa Symbol
73-
return Algorithm{alg}(; kwargs...)
74-
else
75-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
76-
return default_polar_algorithm(A; kwargs...)
77-
end
69+
function default_algorithm(::typeof($f!), A; kwargs...)
70+
return default_polar_algorithm(A; kwargs...)
7871
end
7972
end
8073
end

src/interface/qr.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
7171
for f in (:qr_full, :qr_compact, :qr_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_qr_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_qr_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

src/interface/schur.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,11 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
5454
for f in (:schur_full, :schur_vals)
5555
f! = Symbol(f, :!)
5656
@eval begin
57-
function select_algorithm(::typeof($f), A; kwargs...)
58-
return select_algorithm($f!, A; kwargs...)
57+
function default_algorithm(::typeof($f), A; kwargs...)
58+
return default_algorithm($f!, A; kwargs...)
5959
end
60-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
61-
if alg isa AbstractAlgorithm
62-
return alg
63-
elseif alg isa Symbol
64-
return Algorithm{alg}(; kwargs...)
65-
else
66-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
67-
return default_eig_algorithm(A; kwargs...)
68-
end
60+
function default_algorithm(::typeof($f!), A; kwargs...)
61+
return default_eig_algorithm(A; kwargs...)
6962
end
7063
end
7164
end

0 commit comments

Comments
 (0)