Skip to content

Commit 7caadab

Browse files
authored
Merge branch 'main' into mf/fillarraysext
2 parents ed8d2bd + d4099c5 commit 7caadab

13 files changed

Lines changed: 173 additions & 65 deletions

File tree

src/algorithms.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ passed as the third positional argument in the form of a `NamedTuple`.
7777
""" select_algorithm
7878

7979
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80-
return select_algorithm(f, typeof(A), alg; kwargs...)
81-
end
82-
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
8380
if isnothing(alg)
8481
return default_algorithm(f, A; kwargs...)
8582
elseif alg isa Symbol
@@ -193,10 +190,24 @@ macro functiondef(f)
193190
end
194191

195192
# define fallbacks for algorithm selection
196-
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
197-
kwargs...) where {Alg,A}
193+
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
198194
return select_algorithm($f!, A, alg; kwargs...)
199195
end
196+
# define default algorithm fallbacks for out-of-place functions
197+
# in terms of the corresponding in-place function
198+
@inline function default_algorithm(::typeof($f), A; kwargs...)
199+
return default_algorithm($f!, A; kwargs...)
200+
end
201+
# define default algorithm fallbacks for out-of-place functions
202+
# in terms of the corresponding in-place function for types,
203+
# in principle this is covered by the definition above but
204+
# it is necessary to avoid ambiguity errors with the generic definitions:
205+
# ```julia
206+
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
207+
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
208+
# throw(MethodError(default_algorithm, (f, T)))
209+
# end
210+
# ```
200211
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
201212
return default_algorithm($f!, A; kwargs...)
202213
end

src/implementations/truncation.jl

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,30 @@ struct TruncationKeepFiltered{F} <: TruncationStrategy
6767
filter::F
6868
end
6969

70-
struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
70+
struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy
7171
atol::T
7272
rtol::T
7373
p::Int
74+
by::F
75+
end
76+
function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs)
77+
return TruncationKeepAbove(atol, rtol, p, by)
7478
end
75-
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
76-
return TruncationKeepAbove(promote(atol, rtol)..., p)
79+
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs)
80+
return TruncationKeepAbove(promote(atol, rtol)..., p, by)
7781
end
7882

79-
struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
83+
struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy
8084
atol::T
8185
rtol::T
8286
p::Int
87+
by::F
88+
end
89+
function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs)
90+
return TruncationKeepBelow(atol, rtol, p, by)
8391
end
84-
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
85-
return TruncationKeepBelow(promote(atol, rtol)..., p)
92+
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs)
93+
return TruncationKeepBelow(promote(atol, rtol)..., p, by)
8694
end
8795

8896
# TODO: better names for these functions of the above types
@@ -94,18 +102,18 @@ Truncation strategy to keep the first `howmany` values when sorted according to
94102
truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev)
95103

96104
"""
97-
trunctol(atol::Real)
105+
trunctol(atol::Real; by=abs)
98106
99-
Truncation strategy to discard the values that are smaller than `atol` in absolute value.
107+
Truncation strategy to discard the values that are smaller than `atol` according to `by`.
100108
"""
101-
trunctol(atol) = TruncationKeepFiltered((atol) abs)
109+
trunctol(atol; by=abs) = TruncationKeepFiltered((atol) by)
102110

103111
"""
104-
truncabove(atol::Real)
112+
truncabove(atol::Real; by=abs)
105113
106-
Truncation strategy to discard the values that are larger than `atol` in absolute value.
114+
Truncation strategy to discard the values that are larger than `atol` according to `by`.
107115
"""
108-
truncabove(atol) = TruncationKeepFiltered((atol) abs)
116+
truncabove(atol; by=abs) = TruncationKeepFiltered((atol) by)
109117

110118
"""
111119
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
@@ -177,17 +185,18 @@ Generic interface for finding truncated values of the spectrum of a decompositio
177185
based on the `strategy`. The output should be a collection of indices specifying
178186
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
179187
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
180-
values are sorted. For a version that assumes the values are reverse sorted by
181-
absolute value (which is the standard case for SVD) see
182-
[`MatrixAlgebraKit.findtruncated_sorted`](@ref).
188+
values are sorted. For a version that assumes the values are reverse sorted (which is the
189+
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
183190
""" findtruncated
184191

185192
@doc """
186193
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
187194
188-
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
189-
absolute value. However, note that this assumption is not checked, so passing values that are not sorted
190-
in that way can silently give unexpected results. This is used in the default implementation of
195+
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
196+
They are assumed to be sorted in a way that is consistent with the truncation strategy,
197+
which generally means they are sorted by absolute value but some truncation strategies allow
198+
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
199+
in the correct way can silently give unexpected results. This is used in the default implementation of
191200
[`svd_trunc!`](@ref).
192201
""" findtruncated_sorted
193202

@@ -212,21 +221,21 @@ end
212221

213222
function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
214223
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
215-
return findall((atol), values)
224+
return findall((atol) strategy.by, values)
216225
end
217226
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
218227
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
219-
i = searchsortedfirst(values, atol; by=abs, rev=true)
228+
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
220229
return i:length(values)
221230
end
222231

223232
function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
224233
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
225-
return findall((atol), values)
234+
return findall((atol) strategy.by, values)
226235
end
227236
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
228237
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
229-
i = searchsortedlast(values, atol; by=abs, rev=true)
238+
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
230239
return 1:i
231240
end
232241

src/interface/eig.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ for f in (:eig_full!, :eig_vals!)
9999
end
100100
end
101101

102-
function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
103-
kwargs...) where {A<:YALAPACK.BlasMat}
104-
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
105-
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
102+
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
103+
if alg isa TruncatedAlgorithm
104+
isnothing(trunc) ||
105+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
106+
return alg
107+
else
108+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
109+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
110+
end
106111
end

src/interface/eigh.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,13 @@ for f in (:eigh_full!, :eigh_vals!)
100100
end
101101
end
102102

103-
function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
104-
kwargs...) where {A<:YALAPACK.BlasMat}
105-
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106-
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
103+
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
104+
if alg isa TruncatedAlgorithm
105+
isnothing(trunc) ||
106+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
107+
return alg
108+
else
109+
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
110+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
111+
end
107112
end

src/interface/lq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
7777
end
7878

7979
for f in (:lq_full!, :lq_compact!, :lq_null!)
80-
@eval begin
81-
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
82-
return default_lq_algorithm(A; kwargs...)
83-
end
80+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
81+
return default_lq_algorithm(A; kwargs...)
8482
end
8583
end

src/interface/polar.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@ end
6161
# Algorithm selection
6262
# -------------------
6363
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
64-
function default_polar_algorithm(T::Type; kwargs...)
65-
throw(MethodError(default_polar_algorithm, (T,)))
66-
end
67-
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
64+
function default_polar_algorithm(::Type{T}; kwargs...) where {T}
6865
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
6966
end
7067

src/interface/qr.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
7777
end
7878

7979
for f in (:qr_full!, :qr_compact!, :qr_null!)
80-
@eval begin
81-
function default_algorithm(::typeof($f), ::Type{A};
82-
kwargs...) where {A<:YALAPACK.BlasMat}
83-
return default_qr_algorithm(A; kwargs...)
84-
end
80+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
81+
return default_qr_algorithm(A; kwargs...)
8582
end
8683
end

src/interface/svd.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,13 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
104104
end
105105
end
106106

107-
function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing,
108-
kwargs...) where {A<:YALAPACK.BlasMat}
109-
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
110-
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
107+
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
108+
if alg isa TruncatedAlgorithm
109+
isnothing(trunc) ||
110+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
111+
return alg
112+
else
113+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
114+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
115+
end
111116
end

test/algorithms.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm,
5-
default_algorithm, select_algorithm
5+
TruncationKeepBelow, default_algorithm, select_algorithm
66

77
@testset "default_algorithm" begin
88
A = randn(3, 3)
@@ -50,6 +50,12 @@ end
5050
NoTruncation())
5151
end
5252

53+
alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0))
54+
for f in (eig_trunc!, eigh_trunc!, svd_trunc!)
55+
@test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg
56+
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2))
57+
end
58+
5359
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
5460
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) ===
5561
LAPACK_DivideAndConquer()

test/eig.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: Diagonal
6-
using MatrixAlgebraKit: diagview
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77

88
@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
@@ -57,3 +57,17 @@ end
5757
@test V2 * ((V2' * V2) \ (V2' * V1)) V1
5858
end
5959
end
60+
61+
@testset "eig_trunc! specify truncation algorithm T = $T" for T in
62+
(Float32, Float64, ComplexF32,
63+
ComplexF64)
64+
rng = StableRNG(123)
65+
m = 4
66+
V = randn(rng, T, m, m)
67+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
68+
A = V * D * inv(V)
69+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
70+
D2, V2 = @constinferred eig_trunc(A; alg)
71+
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
72+
@test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2))
73+
end

0 commit comments

Comments
 (0)