Skip to content

Commit 7dff266

Browse files
committed
Improve type stability
1 parent 3073fa2 commit 7dff266

7 files changed

Lines changed: 88 additions & 81 deletions

File tree

src/algorithms.jl

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -54,45 +54,48 @@ function _show_alg(io::IO, alg::Algorithm)
5454
end
5555

5656
@doc """
57-
MatrixAlgebraKit.select_algorithm(f, A; kwargs...)
57+
MatrixAlgebraKit.select_algorithm(f, A, alg=nothing; kwargs...)
5858
59-
Given some keyword arguments and an input `A`, decide on an algorithm to use for
60-
implementing the function `f` on inputs of type `A`.
59+
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
6160
62-
In general, if an algorithm is specified explicitly through the `alg` keyword argument
63-
(either as an algorithm type, an algorithm name as a Symbol, or as an algorithm object),
64-
that algorithm will be used instead of selecting it automatically. However, that
65-
behavior may be modified for factorization functions and/or matrix types.
61+
If `alg` is `nothing` (the default value), an algorithm will be selected automatically
62+
with [`MatrixAlgebra.default_algorithm`](@ref) and the keyword arguments will be passed
63+
to the algorithm constructor.
6664
67-
When the `alg` keyword argument is not provided, a default algorithm specified by
68-
[`default_algorithm`](@ref) will be used.
65+
If `alg` is a `NamedTuple`, an algorithm will be selected automatically
66+
with [`default_algorithm`](@ref) and `alg` will be passed to the algorithm
67+
as keyword arguments. In that case, keyword arguments can't be passed
68+
to `MatrixAlgebraKit.select_algorithm`
69+
70+
If `alg` is an `AbstractAlgorithm`, it will be returned as-is. In that case, keyword arguments
71+
can't be passed to `MatrixAlgebraKit.select_algorithm`.
6972
"""
7073
function select_algorithm end
7174

72-
function select_algorithm(f, A; alg=nothing, kwargs...)
75+
Base.@constprop :aggressive function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
7376
return _select_algorithm(f, A, alg; kwargs...)
7477
end
7578

76-
function _select_algorithm(f, A, alg::Nothing; kwargs...)
79+
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
7780
return default_algorithm(f, A; kwargs...)
7881
end
79-
function _select_algorithm(f, A, alg::AbstractAlgorithm; kwargs...)
80-
isempty(kwargs) ||
81-
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
82-
return alg
83-
end
84-
function _select_algorithm(f, A, alg::Symbol; kwargs...)
85-
return _select_algorithm(f, A, Algorithm{alg}(; kwargs...))
82+
Base.@constprop :aggressive function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
83+
return Algorithm{alg}(; kwargs...)
8684
end
87-
function _select_algorithm(f, A, alg::Type; kwargs...)
88-
return _select_algorithm(f, A, alg(; kwargs...))
85+
Base.@constprop :aggressive function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
86+
return Alg(; kwargs...)
8987
end
90-
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple; kwargs...)
88+
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
9189
isempty(kwargs) ||
9290
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
93-
return select_algorithm(f, A; alg...)
91+
return default_algorithm(f, A; alg...)
9492
end
95-
function _select_algorithm(f, A, alg; kwargs...)
93+
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
94+
isempty(kwargs) ||
95+
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
96+
return alg
97+
end
98+
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
9699
return throw(ArgumentError("Unknown alg $alg"))
97100
end
98101

@@ -171,13 +174,15 @@ macro functiondef(f)
171174

172175
return esc(quote
173176
# out of place to inplace
174-
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
177+
Base.@constprop :aggressive $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
175178
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
176179

177180
# fill in arguments
178-
$f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...))
179-
function $f!(A, out; kwargs...)
180-
return $f!(A, out, select_algorithm($f!, A; kwargs...))
181+
Base.@constprop :aggressive function $f!(A; alg=nothing, kwargs...)
182+
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
183+
end
184+
Base.@constprop :aggressive function $f!(A, out; alg=nothing, kwargs...)
185+
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
181186
end
182187
function $f!(A, alg::AbstractAlgorithm)
183188
return $f!(A, initialize_output($f!, A, alg), alg)

src/implementations/orthnull.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
8989
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
9090
end
9191
if kind == :qr
92-
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
92+
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
9393
return qr_compact!(A, VC, alg_qr′)
9494
elseif kind == :polar
9595
size(A, 1) >= size(A, 2) ||
9696
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
97-
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
97+
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
9898
return left_polar!(A, VC, alg_polar′)
9999
elseif kind == :svd && isnothing(trunc)
100-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
100+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
101101
V, C = VC
102102
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
103103
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
104104
return U, lmul!(S, Vᴴ)
105105
elseif kind == :svd
106-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
107-
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
106+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
107+
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
108108
V, C = VC
109109
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
110110
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
@@ -122,22 +122,22 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
122122
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
123123
end
124124
if kind == :lq
125-
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
125+
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
126126
return lq_compact!(A, CVᴴ, alg_lq′)
127127
elseif kind == :polar
128128
size(A, 2) >= size(A, 1) ||
129129
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
130-
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
130+
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
131131
return right_polar!(A, CVᴴ, alg_polar′)
132132
elseif kind == :svd && isnothing(trunc)
133-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
133+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
134134
C, Vᴴ = CVᴴ
135135
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
136136
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
137137
return rmul!(U, S), Vᴴ
138138
elseif kind == :svd
139-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
140-
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
139+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
140+
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
141141
C, Vᴴ = CVᴴ
142142
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
143143
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
@@ -167,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing,
167167
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
168168
end
169169
if kind == :qr
170-
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
170+
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
171171
return qr_null!(A, N, alg_qr′)
172172
elseif kind == :svd && isnothing(trunc)
173-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
173+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
174174
U, _, _ = svd_full!(A, alg_svd′)
175175
(m, n) = size(A)
176176
return copy!(N, view(U, 1:m, (n + 1):m))
177177
elseif kind == :svd
178-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
178+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
179179
U, S, _ = svd_full!(A, alg_svd′)
180180
trunc′ = trunc isa TruncationStrategy ? trunc :
181181
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
@@ -194,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
194194
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
195195
end
196196
if kind == :lq
197-
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
197+
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
198198
return lq_null!(A, Nᴴ, alg_lq′)
199199
elseif kind == :svd && isnothing(trunc)
200-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
200+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
201201
_, _, Vᴴ = svd_full!(A, alg_svd′)
202202
(m, n) = size(A)
203203
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
204204
elseif kind == :svd
205-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
205+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
206206
_, S, Vᴴ = svd_full!(A, alg_svd′)
207207
trunc′ = trunc isa TruncationStrategy ? trunc :
208208
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :

src/implementations/truncation.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes.
3232
"""
3333
struct NoTruncation <: TruncationStrategy end
3434

35-
function select_truncation(trunc::TruncationStrategy)
36-
return trunc
37-
end
38-
function select_truncation(trunc::NamedTuple)
39-
return TruncationStrategy(; trunc...)
40-
end
41-
function select_truncation(trunc::Nothing)
42-
return NoTruncation()
43-
end
4435
function select_truncation(trunc)
45-
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
36+
if isnothing(trunc)
37+
return NoTruncation()
38+
elseif trunc isa NamedTuple
39+
return TruncationStrategy(; trunc...)
40+
elseif trunc isa TruncationStrategy
41+
return trunc
42+
else
43+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
44+
end
4645
end
4746

4847
# TODO: how do we deal with sorting/filters that treat zeros differently

src/interface/eig.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ for f in (:eig_full, :eig_vals)
9999
end
100100
end
101101

102-
function select_algorithm(::typeof(eig_trunc), A; kwargs...)
103-
return select_algorithm(eig_trunc!, A; kwargs...)
102+
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
103+
return select_algorithm(eig_trunc!, A, alg; kwargs...)
104104
end
105-
function select_algorithm(::typeof(eig_trunc!), A; trunc=nothing, kwargs...)
106-
alg_eig = select_algorithm(eig_full!, A; kwargs...)
105+
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
106+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
107107
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
108108
end
109109

src/interface/eigh.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ for f in (:eigh_full, :eigh_vals)
9898
end
9999
end
100100

101-
function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
102-
return select_algorithm(eigh_trunc!, A; kwargs...)
101+
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
102+
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
103103
end
104-
function select_algorithm(::typeof(eigh_trunc!), A; trunc=nothing, kwargs...)
105-
alg_eigh = select_algorithm(eigh_full!, A; kwargs...)
104+
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
105+
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106106
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
107107
end
108108

src/interface/svd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ for f in (:svd_full, :svd_compact, :svd_vals)
102102
end
103103
end
104104

105-
function select_algorithm(::typeof(svd_trunc), A; kwargs...)
106-
return select_algorithm(svd_trunc!, A; kwargs...)
105+
function select_algorithm(::typeof(svd_trunc), A, alg; kwargs...)
106+
return select_algorithm(svd_trunc!, A, alg; kwargs...)
107107
end
108-
function select_algorithm(::typeof(svd_trunc!), A; trunc=nothing, kwargs...)
109-
alg_svd = select_algorithm(svd_compact!, A; kwargs...)
108+
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
109+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
110110
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
111111
end
112112

test/eig.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,16 @@ using MatrixAlgebraKit: diagview
88
@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
1010
m = 54
11-
for alg in
12-
(LAPACK_Simple(), LAPACK_Expert(), LAPACK_Simple, LAPACK_Expert, :LAPACK_Simple,
13-
:LAPACK_Expert)
11+
for alg in (LAPACK_Simple(), LAPACK_Expert())
1412
A = randn(rng, T, m, m)
1513
Tc = complex(T)
1614

17-
alg′ = if (alg isa Type) || (alg isa Symbol)
18-
# These cases aren't inferable right now.
19-
MatrixAlgebraKit.select_algorithm(eig_full!, A; alg)
20-
else
21-
@constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A; alg)
22-
end
23-
24-
D, V = if (alg isa Type) || (alg isa Symbol)
25-
# These cases aren't inferable right now.
26-
eig_full(A; alg)
27-
else
28-
@constinferred eig_full(A; alg)
29-
end
15+
D, V = @constinferred eig_full(A; alg)
3016
@test eltype(D) == eltype(V) == Tc
3117
@test A * V V * D
3218

19+
alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, alg)
20+
3321
Ac = similar(A)
3422
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
3523
@test D2 === D
@@ -40,6 +28,21 @@ using MatrixAlgebraKit: diagview
4028
@test eltype(Dc) == Tc
4129
@test D Diagonal(Dc)
4230
end
31+
32+
# Test other alg inputs.
33+
A = randn(rng, T, m, m)
34+
Tc = complex(T)
35+
D, V = @constinferred eig_full(A; alg=:LAPACK_Simple)
36+
@test eltype(D) == eltype(V) == Tc
37+
@test A * V V * D
38+
39+
A = randn(rng, T, m, m)
40+
Tc = complex(T)
41+
## Inference is broken for this case for some reason.
42+
## D, V = @constinferred eig_full(A; alg=LAPACK_Simple)
43+
D, V = eig_full(A; alg=LAPACK_Simple)
44+
@test eltype(D) == eltype(V) == Tc
45+
@test A * V V * D
4346
end
4447

4548
@testset "eig_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)

0 commit comments

Comments
 (0)