Skip to content

Commit f878a6f

Browse files
committed
Refactor algorithm selection logic
1 parent 2364e25 commit f878a6f

9 files changed

Lines changed: 71 additions & 98 deletions

File tree

src/algorithms.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ implementing the function `f` on inputs of type `A`.
6161
"""
6262
function select_algorithm end
6363

64+
function select_algorithm(f, A; alg=nothing, kwargs...)
65+
return _select_algorithm(f, A, alg; kwargs...)
66+
end
67+
68+
function _select_algorithm(f, A, alg::Nothing; kwargs...)
69+
return default_algorithm(f, A; kwargs...)
70+
end
71+
function _select_algorithm(f, A, alg::AbstractAlgorithm; kwargs...)
72+
isempty(kwargs) || throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
73+
return alg
74+
end
75+
function _select_algorithm(f, A, alg::Symbol; kwargs...)
76+
return _select_algorithm(f, A, Algorithm{alg}; kwargs...)
77+
end
78+
function _select_algorithm(f, A, alg::Type; kwargs...)
79+
return _select_algorithm(f, A, alg(; kwargs...))
80+
end
81+
function _select_algorithm(f, A, alg; kwargs...)
82+
return throw(ArgumentError("Unknown alg $alg"))
83+
end
84+
6485
@doc """
6586
copy_input(f, A)
6687

src/implementations/truncation.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes.
3232
"""
3333
struct NoTruncation <: TruncationStrategy end
3434

35+
function to_truncationstrategy(trunc::TruncationStrategy)
36+
return trunc
37+
end
38+
function to_truncationstrategy(trunc::NamedTuple)
39+
return TruncationStrategy(; trunc...)
40+
end
41+
function to_truncationstrategy(trunc::Nothing)
42+
return NoTruncation()
43+
end
44+
function to_truncationstrategy(trunc)
45+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
46+
end
47+
3548
# TODO: how do we deal with sorting/filters that treat zeros differently
3649
# since these are implicitly discarded by selecting compact/full
3750

src/interface/eig.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
9090
for f in (:eig_full, :eig_vals)
9191
f! = Symbol(f, :!)
9292
@eval begin
93-
function select_algorithm(::typeof($f), A; kwargs...)
94-
return select_algorithm($f!, A; kwargs...)
93+
function default_algorithm(::typeof($f), A; kwargs...)
94+
return default_algorithm($f!, A; kwargs...)
9595
end
96-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
97-
if alg isa AbstractAlgorithm
98-
return alg
99-
elseif alg isa Symbol
100-
return Algorithm{alg}(; kwargs...)
101-
else
102-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
103-
return default_eig_algorithm(A; kwargs...)
104-
end
96+
function default_algorithm(::typeof($f!), A; kwargs...)
97+
return default_eig_algorithm(A; kwargs...)
10598
end
10699
end
107100
end
108101

109102
function select_algorithm(::typeof(eig_trunc), A; kwargs...)
110103
return select_algorithm(eig_trunc!, A; kwargs...)
111104
end
112-
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
113-
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
114-
alg_trunc = trunc isa TruncationStrategy ? trunc :
115-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
116-
isnothing(trunc) ? NoTruncation() :
117-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
118-
return TruncatedAlgorithm(alg_eig, alg_trunc)
105+
function select_algorithm(::typeof(eig_trunc!), A; trunc=nothing, kwargs...)
106+
alg_eig = select_algorithm(eig_full!, A; kwargs...)
107+
return TruncatedAlgorithm(alg_eig, to_truncationstrategy(trunc))
119108
end
120109

121110
# Default to LAPACK

src/interface/eigh.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
8989
for f in (:eigh_full, :eigh_vals)
9090
f! = Symbol(f, :!)
9191
@eval begin
92-
function select_algorithm(::typeof($f), A; kwargs...)
93-
return select_algorithm($f!, A; kwargs...)
92+
function default_algorithm(::typeof($f), A; kwargs...)
93+
return default_algorithm($f!, A; kwargs...)
9494
end
95-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
96-
if alg isa AbstractAlgorithm
97-
return alg
98-
elseif alg isa Symbol
99-
return Algorithm{alg}(; kwargs...)
100-
else
101-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
102-
return default_eigh_algorithm(A; kwargs...)
103-
end
95+
function default_algorithm(::typeof($f!), A; kwargs...)
96+
return default_eig_algorithm(A; kwargs...)
10497
end
10598
end
10699
end
107100

108101
function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
109102
return select_algorithm(eigh_trunc!, A; kwargs...)
110103
end
111-
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
112-
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
113-
alg_trunc = trunc isa TruncationStrategy ? trunc :
114-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
115-
isnothing(trunc) ? NoTruncation() :
116-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
117-
return TruncatedAlgorithm(alg_eigh, alg_trunc)
104+
function select_algorithm(::typeof(eigh_trunc!), A; trunc=nothing, kwargs...)
105+
alg_eigh = select_algorithm(eigh_full!, A; kwargs...)
106+
return TruncatedAlgorithm(alg_eigh, to_truncationstrategy(trunc))
118107
end
119108

120109
# Default to LAPACK

src/interface/lq.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
7171
for f in (:lq_full, :lq_compact, :lq_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_lq_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_lq_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

src/interface/polar.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,11 @@ end
6363
for f in (:left_polar, :right_polar)
6464
f! = Symbol(f, :!)
6565
@eval begin
66-
function select_algorithm(::typeof($f), A; kwargs...)
67-
return select_algorithm($f!, A; kwargs...)
66+
function default_algorithm(::typeof($f), A; kwargs...)
67+
return default_algorithm($f!, A; kwargs...)
6868
end
69-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
70-
if alg isa AbstractAlgorithm
71-
return alg
72-
elseif alg isa Symbol
73-
return Algorithm{alg}(; kwargs...)
74-
else
75-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
76-
return default_polar_algorithm(A; kwargs...)
77-
end
69+
function default_algorithm(::typeof($f!), A; kwargs...)
70+
return default_polar_algorithm(A; kwargs...)
7871
end
7972
end
8073
end

src/interface/qr.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
7171
for f in (:qr_full, :qr_compact, :qr_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_qr_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_qr_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

src/interface/schur.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,11 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
5454
for f in (:schur_full, :schur_vals)
5555
f! = Symbol(f, :!)
5656
@eval begin
57-
function select_algorithm(::typeof($f), A; kwargs...)
58-
return select_algorithm($f!, A; kwargs...)
57+
function default_algorithm(::typeof($f), A; kwargs...)
58+
return default_algorithm($f!, A; kwargs...)
5959
end
60-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
61-
if alg isa AbstractAlgorithm
62-
return alg
63-
elseif alg isa Symbol
64-
return Algorithm{alg}(; kwargs...)
65-
else
66-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
67-
return default_eig_algorithm(A; kwargs...)
68-
end
60+
function default_algorithm(::typeof($f!), A; kwargs...)
61+
return default_eig_algorithm(A; kwargs...)
6962
end
7063
end
7164
end

src/interface/svd.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,32 +93,21 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an
9393
for f in (:svd_full, :svd_compact, :svd_vals)
9494
f! = Symbol(f, :!)
9595
@eval begin
96-
function select_algorithm(::typeof($f), A; kwargs...)
97-
return select_algorithm($f!, A; kwargs...)
96+
function default_algorithm(::typeof($f), A; kwargs...)
97+
return default_algorithm($f!, A; kwargs...)
9898
end
99-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
100-
if alg isa AbstractAlgorithm
101-
return alg
102-
elseif alg isa Symbol
103-
return Algorithm{alg}(; kwargs...)
104-
else
105-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
106-
return default_svd_algorithm(A; kwargs...)
107-
end
99+
function default_algorithm(::typeof($f!), A; kwargs...)
100+
return default_svd_algorithm(A; kwargs...)
108101
end
109102
end
110103
end
111104

112105
function select_algorithm(::typeof(svd_trunc), A; kwargs...)
113106
return select_algorithm(svd_trunc!, A; kwargs...)
114107
end
115-
function select_algorithm(::typeof(svd_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
116-
alg_svd = select_algorithm(svd_compact!, A; alg, kwargs...)
117-
alg_trunc = trunc isa TruncationStrategy ? trunc :
118-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
119-
isnothing(trunc) ? NoTruncation() :
120-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
121-
return TruncatedAlgorithm(alg_svd, alg_trunc)
108+
function select_algorithm(::typeof(svd_trunc!), A; trunc=nothing, kwargs...)
109+
alg_svd = select_algorithm(svd_compact!, A; kwargs...)
110+
return TruncatedAlgorithm(alg_svd, to_truncationstrategy(trunc))
122111
end
123112

124113
# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`

0 commit comments

Comments
 (0)