Skip to content

Commit 3b16fc0

Browse files
committed
Rewrite select and default_algorithm in the type domain
1 parent 7d77d39 commit 3b16fc0

1 file changed

Lines changed: 35 additions & 25 deletions

File tree

src/algorithms.jl

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ end
6161
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))
6262
6363
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
64+
This can be obtained both for values `A` or types `A`.
6465
6566
If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.
6667
@@ -76,39 +77,47 @@ passed as the third positional argument in the form of a `NamedTuple`.
7677
""" select_algorithm
7778

7879
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80+
return select_algorithm(f, typeof(A), alg; kwargs...)
81+
end
82+
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
7983
return _select_algorithm(f, A, alg; kwargs...)
8084
end
8185

82-
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
86+
function _select_algorithm(f::F, ::Type{A}, alg::Nothing; kwargs...) where {F,A}
8387
return default_algorithm(f, A; kwargs...)
8488
end
85-
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
89+
function _select_algorithm(f::F, ::Type{A}, alg::Symbol; kwargs...) where {F,A}
8690
return Algorithm{alg}(; kwargs...)
8791
end
88-
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
92+
function _select_algorithm(f::F, ::Type{A}, ::Type{Alg}; kwargs...) where {F,A,Alg}
8993
return Alg(; kwargs...)
9094
end
91-
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
95+
function _select_algorithm(f::F, ::Type{A}, alg::NamedTuple; kwargs...) where {F,A}
9296
isempty(kwargs) ||
9397
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
9498
return default_algorithm(f, A; alg...)
9599
end
96-
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
100+
function _select_algorithm(f::F, ::Type{A}, alg::AbstractAlgorithm; kwargs...) where {F,A}
97101
isempty(kwargs) ||
98102
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
99103
return alg
100104
end
101-
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
105+
function _select_algorithm(f::F, ::Type{A}, alg; kwargs...) where {F,A}
102106
return throw(ArgumentError("Unknown alg $alg"))
103107
end
104108

105109
@doc """
106110
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
111+
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}
107112
108113
Select the default algorithm for a given factorization function `f` and input `A`.
109114
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
110115
explicitly.
116+
New types should prefer to register their default algorithms in the type domain.
111117
""" default_algorithm
118+
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
119+
# avoid infinite recursion:
120+
default_algorithm(f, T::Type; kwargs...) = throw(MethodError(default_algorithm, (f, T)))
112121

113122
@doc """
114123
copy_input(f, A)
@@ -172,25 +181,26 @@ macro functiondef(f)
172181
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
173182
f! = Symbol(f, :!)
174183

175-
return esc(quote
176-
# out of place to inplace
177-
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
178-
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
179-
180-
# fill in arguments
181-
function $f!(A; alg=nothing, kwargs...)
182-
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
183-
end
184-
function $f!(A, out; alg=nothing, kwargs...)
185-
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
186-
end
187-
function $f!(A, alg::AbstractAlgorithm)
188-
return $f!(A, initialize_output($f!, A, alg), alg)
189-
end
190-
191-
# copy documentation to both functions
192-
Core.@__doc__ $f, $f!
193-
end)
184+
ex = quote
185+
# out of place to inplace
186+
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
187+
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
188+
189+
# fill in arguments
190+
function $f!(A; alg=nothing, kwargs...)
191+
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
192+
end
193+
function $f!(A, out; alg=nothing, kwargs...)
194+
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
195+
end
196+
function $f!(A, alg::AbstractAlgorithm)
197+
return $f!(A, initialize_output($f!, A, alg), alg)
198+
end
199+
200+
# copy documentation to both functions
201+
Core.@__doc__ $f, $f!
202+
end
203+
return esc(ex)
194204
end
195205

196206
"""

0 commit comments

Comments
 (0)