Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,54 @@
@doc """
select_algorithm(f, A; kwargs...)

Given some keyword arguments and an input `A`, decide on an algrithm to use for
Given some keyword arguments and an input `A`, decide on an algorithm to use for
implementing the function `f` on inputs of type `A`.

In general, if an algorithm is specified explicitly through the `alg` keyword argument
(either as an algorithm type, an algorithm name as a Symbol, or as an algorithm object),
that algorithm will be used instead of selecting it automatically. However, that
behavior may be modified for factorization functions and/or matrix types.

In general, if the algorithm is not specified, a default algorithm specified by
Comment thread
mtfishman marked this conversation as resolved.
Outdated
Comment thread
mtfishman marked this conversation as resolved.
Outdated
[`default_algorithm`](@ref) will be used.
"""
function select_algorithm end

function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
function select_algorithm(f, A; alg=nothing, kwargs...)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure using multiple dispatch via _select_algorithm to deal with the different cases of alg keywords is the right approach, as it might put more strain on the compiler. I don't think there is need to be able to extend this functionality, so something like

function select_algorithm(f, A; alg=nothing, kwargs...)
	if isnothing(alg)
		return default_algorithm(f, A; kwargs...)
	elseif alg isa AbstractAlgorithm
		isempty(kwargs) ||
			throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
		return alg
	elseif
	...
	end
end

should work. Or does the _select_algorithm lowering actually help with the lowering?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm totally fine with that, using dispatch was a bit arbitrary. I generally like using dispatch rather than if-statements purely as a style thing, but if you and the compiler prefer if-statements I can switch to that.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying this out locally, an unfortunate side affect of this is that functions like left_orth/right_orth become type unstable, since before they were calling _select_algorithm which is type stable and select_algorithm isn't with the new design (I think what happens is that certain branches of select_algorithm aren't type stable, so then the entire thing becomes type unstable even for cases that used to be type stable).

I'll investigate that more. But maybe it would be better to just change select_algorithm to accept alg as a positional argument and go back to using dispatch. That should solve all type stability problems being discussed above as well.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the _select_algorithm lowering is better for type stability, I am certainly fine with that, and there is no need to waste time trying to trick the compiler to make the if elseif else construction inferable as well.

return _select_algorithm(f, A, alg; kwargs...)
end

function _select_algorithm(f, A, alg::Nothing; kwargs...)
return default_algorithm(f, A; kwargs...)
end
function _select_algorithm(f, A, alg::AbstractAlgorithm; kwargs...)
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
return alg
end
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
function _select_algorithm(f, A, alg::Symbol; kwargs...)
return _select_algorithm(f, A, Algorithm{alg}; kwargs...)

Check warning on line 85 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
Comment thread
mtfishman marked this conversation as resolved.
Outdated
end
function _select_algorithm(f, A, alg::Type; kwargs...)
return _select_algorithm(f, A, alg(; kwargs...))

Check warning on line 88 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L87-L88

Added lines #L87 - L88 were not covered by tests
end
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple; kwargs...)
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return select_algorithm(f, A; alg...)
end
function _select_algorithm(f, A, alg; kwargs...)
return throw(ArgumentError("Unknown alg $alg"))

Check warning on line 96 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L95-L96

Added lines #L95 - L96 were not covered by tests
end

@doc """
default_algorithm(f, A; kwargs...)

Select the default algorithm for a given factorization function `f` and input `A`.
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
explicitly.
"""
function default_algorithm end

@doc """
copy_input(f, A)
Expand Down
18 changes: 16 additions & 2 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
"""
struct NoTruncation <: TruncationStrategy end

function to_truncationstrategy(trunc::TruncationStrategy)
Comment thread
mtfishman marked this conversation as resolved.
Outdated
return trunc
end
function to_truncationstrategy(trunc::NamedTuple)
return TruncationStrategy(; trunc...)
end
function to_truncationstrategy(trunc::Nothing)
return NoTruncation()

Check warning on line 42 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end
function to_truncationstrategy(trunc)
return throw(ArgumentError("Unknown truncation strategy: $trunc"))

Check warning on line 45 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
end

# TODO: how do we deal with sorting/filters that treat zeros differently
# since these are implicitly discarded by selecting compact/full

Expand Down Expand Up @@ -98,8 +111,9 @@
TruncationStrategy
components::T
end
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) =
TruncationIntersection((trunc, truncs...))
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
return TruncationIntersection((trunc, truncs...))

Check warning on line 115 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end

function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1, trunc2))
Expand Down
25 changes: 7 additions & 18 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,32 +90,21 @@
for f in (:eig_full, :eig_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 94 in src/interface/eig.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/eig.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eig_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eig_trunc), A; kwargs...)
return select_algorithm(eig_trunc!, A; kwargs...)
end
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eig, alg_trunc)
function select_algorithm(::typeof(eig_trunc!), A; trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A; kwargs...)
return TruncatedAlgorithm(alg_eig, to_truncationstrategy(trunc))
end

# Default to LAPACK
Expand Down
25 changes: 7 additions & 18 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,21 @@
for f in (:eigh_full, :eigh_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 93 in src/interface/eigh.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/eigh.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eigh_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
return select_algorithm(eigh_trunc!, A; kwargs...)
end
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eigh, alg_trunc)
function select_algorithm(::typeof(eigh_trunc!), A; trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A; kwargs...)
return TruncatedAlgorithm(alg_eigh, to_truncationstrategy(trunc))
end

# Default to LAPACK
Expand Down
15 changes: 4 additions & 11 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@
for f in (:lq_full, :lq_compact, :lq_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 75 in src/interface/lq.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/lq.jl#L74-L75

Added lines #L74 - L75 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_lq_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_lq_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,11 @@
for f in (:left_polar, :right_polar)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 67 in src/interface/polar.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/polar.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_polar_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_polar_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@
for f in (:qr_full, :qr_compact, :qr_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 75 in src/interface/qr.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/qr.jl#L74-L75

Added lines #L74 - L75 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_qr_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_qr_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,11 @@
for f in (:schur_full, :schur_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 58 in src/interface/schur.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/schur.jl#L57-L58

Added lines #L57 - L58 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eig_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)

Check warning on line 61 in src/interface/schur.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/schur.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end
end
end
25 changes: 7 additions & 18 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,21 @@
for f in (:svd_full, :svd_compact, :svd_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)

Check warning on line 97 in src/interface/svd.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/svd.jl#L96-L97

Added lines #L96 - L97 were not covered by tests
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_svd_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_svd_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(svd_trunc), A; kwargs...)
return select_algorithm(svd_trunc!, A; kwargs...)
end
function select_algorithm(::typeof(svd_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_svd = select_algorithm(svd_compact!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_svd, alg_trunc)
function select_algorithm(::typeof(svd_trunc!), A; trunc=nothing, kwargs...)
alg_svd = select_algorithm(svd_compact!, A; kwargs...)
return TruncatedAlgorithm(alg_svd, to_truncationstrategy(trunc))
end

# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`
Expand Down
Loading