Skip to content

Commit ed99312

Browse files
committed
various renaming and cleanup
1 parent b56d7f3 commit ed99312

15 files changed

Lines changed: 146 additions & 221 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ end
3030

3131
MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
3232
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
33-
MatrixAlgebraKit.supports_eigh(::ROCSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer, :qr_iteration, :bisection)
3433

3534
function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
3635
m, n = size(A)

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
2222
return QRIteration(; kwargs...)
2323
end
2424
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
25-
return Simple(; kwargs...)
25+
return QRIteration(; balanced = false, kwargs...)
2626
end
2727
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2828
return DivideAndConquer(; kwargs...)
@@ -33,8 +33,6 @@ for f in (:geqrf!, :ungqr!, :unmqr!)
3333
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
3434
end
3535

36-
MatrixAlgebraKit.supports_eig(::CUSOLVER, f::Symbol) = f === :simple
37-
MatrixAlgebraKit.supports_eigh(::CUSOLVER, f::Symbol) = f in (:jacobi, :divide_and_conquer)
3836
MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
3937
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
4038

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
5-
using MatrixAlgebraKit: GLA
5+
using MatrixAlgebraKit: GLA, Driver
66
import MatrixAlgebraKit: gesvd!, heev!
77
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
88
using LinearAlgebra: I, Diagonal, lmul!
@@ -11,12 +11,14 @@ const GlaFloat = Union{BigFloat, Complex{BigFloat}}
1111
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
1212
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()
1313

14-
MatrixAlgebraKit.supports_eigh(::GLA, f::Symbol) = f === :qr_iteration
1514
MatrixAlgebraKit.supports_svd(::GLA, f::Symbol) = f === :qr_iteration
1615
MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration
1716

18-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
19-
return QRIteration(; kwargs...)
17+
function MatrixAlgebraKit.default_svd_algorithm(
18+
::Type{T};
19+
driver::Driver = GLA(), kwargs...
20+
) where {T <: GlaStridedVecOrMatrix}
21+
return QRIteration(; driver, kwargs...)
2022
end
2123

2224
function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
@@ -36,8 +38,8 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
3638
return S, U, Vᴴ
3739
end
3840

39-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
40-
return QRIteration(; kwargs...)
41+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix}
42+
return QRIteration(; driver, kwargs...)
4143
end
4244

4345
function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,51 @@
11
module MatrixAlgebraKitGenericSchurExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: check_input, GS
5-
import MatrixAlgebraKit: geev!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals!
4+
using MatrixAlgebraKit: check_input, GS, Driver
5+
import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals!
66
using LinearAlgebra: Diagonal, sorteig!
77
using GenericSchur
88

99
const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}
1010

11-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:GSFloat}}
12-
return Simple(; kwargs...)
11+
function MatrixAlgebraKit.default_eig_algorithm(
12+
::Type{T};
13+
balanced::Bool = false, driver::Driver = GS(), kwargs...
14+
) where {T <: StridedMatrix{<:GSFloat}}
15+
return QRIteration(; driver, balanced, kwargs...)
1316
end
1417

15-
MatrixAlgebraKit.default_driver(::Type{<:Simple}, ::Type{TA}) where {TA <: StridedMatrix{<:GSFloat}} = GS()
16-
17-
MatrixAlgebraKit.supports_schur(::GS, f::Symbol) = f === :simple
18-
MatrixAlgebraKit.supports_eig(::GS, f::Symbol) = f === :simple
19-
2018
function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
2119
D, Vmat = GenericSchur.eigen!(A)
2220
copyto!(Dd, D)
2321
length(V) > 0 && copyto!(V, Vmat)
2422
return Dd, V
2523
end
2624

27-
function gees!(::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector)
25+
function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector)
2826
S = GenericSchur.gschur(A)
2927
copyto!(A, S.T)
30-
if length(Z) > 0
31-
copyto!(Z, S.Z)
32-
copyto!(vals, S.values)
33-
else
34-
copyto!(vals, sorteig!(S.values))
35-
end
28+
length(Z) > 0 && copyto!(Z, S.Z)
29+
copyto!(vals, sorteig!(S.values))
3630
return A, Z, vals
3731
end
3832

3933
Base.@deprecate(
4034
eig_full!(A, DV, alg::GS_QRIteration),
41-
eig_full!(A, DV, Simple(; driver = GS(), alg.kwargs...))
35+
eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...))
4236
)
4337
Base.@deprecate(
4438
eig_vals!(A, D, alg::GS_QRIteration),
45-
eig_vals!(A, D, Simple(; driver = GS(), alg.kwargs...))
39+
eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...))
4640
)
4741

4842
Base.@deprecate(
4943
schur_full!(A, TZv, alg::GS_QRIteration),
50-
schur_full!(A, TZv, Simple(; driver = GS(), alg.kwargs...))
44+
schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...))
5145
)
5246
Base.@deprecate(
5347
schur_vals!(A, vals, alg::GS_QRIteration),
54-
schur_vals!(A, vals, Simple(; driver = GS(), alg.kwargs...))
48+
schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...))
5549
)
5650

5751
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
3333

3434
export Householder, Native_HouseholderQR, Native_HouseholderLQ
3535
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
36-
export MultipleRelativelyRobustRepresentations, Simple, Expert
36+
export RobustRepresentations
3737
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3838
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3939
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer

src/implementations/eig.jl

Lines changed: 49 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -90,84 +90,60 @@ end
9090
# ==========================
9191
# IMPLEMENTATIONS
9292
# ==========================
93-
for f! in (:geev!, :geevx!)
94-
@eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!"))
93+
94+
geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide $f!"))
95+
function geevx!(driver::Driver, A, Dd, V; kwargs...)
96+
@warn "$driver does not provide `geevx!`, falling back to `geev!`" maxlog = 1
97+
return geev!(driver, A, Dd, V; kwargs...)
9598
end
99+
_has_geevx!(::Driver) = false
96100

97101
# LAPACK implementations
98102
for f! in (:geev!, :geevx!)
99103
@eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...)
100104
end
105+
_has_geevx!(::LAPACK) = true
101106

102-
supports_eig(::Driver, ::Symbol) = false
103-
supports_eig(::LAPACK, f::Symbol) = f in (:simple, :expert)
104-
105-
for (f, f_lapack!, Alg) in (
106-
(:simple, :geev!, :Simple),
107-
(:expert, :geevx!, :Expert),
108-
)
109-
f_eig_full! = Symbol(f, :_eig_full!)
110-
f_eig_vals! = Symbol(f, :_eig_vals!)
107+
# driver dispatch
108+
@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
109+
qr_iteration_eig_full!(driver, A, Dd, V; kwargs...)
110+
@inline qr_iteration_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
111+
qr_iteration_eig_vals!(driver, A, D, V; kwargs...)
111112

112-
# MatrixAlgebraKit wrappers
113-
@eval begin
114-
function eig_full!(A::AbstractMatrix, DV, alg::$Alg)
115-
check_input(eig_full!, A, DV, alg)
116-
D, V = DV
117-
Dd, V = $f_eig_full!(A, diagview(D), V; alg.kwargs...)
118-
return D, V
119-
end
120-
function eig_vals!(A::AbstractMatrix, D, alg::$Alg)
121-
check_input(eig_vals!, A, D, alg)
122-
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
123-
$f_eig_vals!(A, D, V; alg.kwargs...)
124-
return D
125-
end
126-
end
113+
@inline qr_iteration_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) =
114+
qr_iteration_eig_full!(default_driver(QRIteration, A), A, Dd, V; kwargs...)
115+
@inline qr_iteration_eig_vals!(::DefaultDriver, A, D, V; kwargs...) =
116+
qr_iteration_eig_vals!(default_driver(QRIteration, A), A, D, V; kwargs...)
127117

128-
# driver dispatch
129-
@eval begin
130-
@inline $f_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
131-
$f_eig_full!(driver, A, Dd, V; kwargs...)
132-
@inline $f_eig_vals!(A, D, V; driver::Driver = DefaultDriver(), kwargs...) =
133-
$f_eig_vals!(driver, A, D, V; kwargs...)
134-
135-
@inline $f_eig_full!(::DefaultDriver, A, Dd, V; kwargs...) =
136-
$f_eig_full!(default_driver($Alg, A), A, Dd, V; kwargs...)
137-
@inline $f_eig_vals!(::DefaultDriver, A, D, V; kwargs...) =
138-
$f_eig_vals!(default_driver($Alg, A), A, D, V; kwargs...)
139-
end
118+
# Implementation
119+
function qr_iteration_eig_full!(
120+
driver::Driver, A, Dd, V;
121+
fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs...
122+
)
123+
(balanced ? geevx! : geev!)(driver, A, Dd, V; kwargs...)
124+
fixgauge && gaugefix!(eig_full!, V)
125+
return Dd, V
126+
end
127+
function qr_iteration_eig_vals!(
128+
driver::Driver, A, D, V;
129+
fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs...
130+
)
131+
(balanced ? geevx! : geev!)(driver, A, D, V; kwargs...)
132+
return D
133+
end
140134

141-
# Implementation
142-
@eval begin
143-
function $f_eig_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...)
144-
supports_eig(driver, $(QuoteNode(f))) ||
145-
throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`")))
146-
$(
147-
if f == :simple
148-
:(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))))
149-
else
150-
:nothing
151-
end
152-
)
153-
$f_lapack!(driver, A, Dd, V; kwargs...)
154-
fixgauge && gaugefix!(eig_full!, V)
155-
return Dd, V
156-
end
157-
function $f_eig_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...)
158-
supports_eig(driver, $(QuoteNode(f))) ||
159-
throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`")))
160-
$(
161-
if f == :simple
162-
:(isempty(kwargs) || throw(ArgumentError(LazyString("invalid keyword arguments for ", driver, " simple eig"))))
163-
else
164-
:nothing
165-
end
166-
)
167-
$f_lapack!(driver, A, D, V; kwargs...)
168-
return D
169-
end
170-
end
135+
# Top-level QRIteration dispatch
136+
function eig_full!(A::AbstractMatrix, DV, alg::QRIteration)
137+
check_input(eig_full!, A, DV, alg)
138+
D, V = DV
139+
qr_iteration_eig_full!(A, diagview(D), V; alg.kwargs...)
140+
return D, V
141+
end
142+
function eig_vals!(A::AbstractMatrix, D, alg::QRIteration)
143+
check_input(eig_vals!, A, D, alg)
144+
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
145+
qr_iteration_eig_vals!(A, D, V; alg.kwargs...)
146+
return D
171147
end
172148

173149
function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
@@ -212,24 +188,23 @@ end
212188

213189
# Deprecations
214190
# ------------
215-
for algtype in (:Simple, :Expert)
216-
lapack_algtype = Symbol(:LAPACK_, algtype)
191+
for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true))
217192
@eval begin
218193
Base.@deprecate(
219194
eig_full!(A, DV, alg::$lapack_algtype),
220-
eig_full!(A, DV, $algtype(; driver = LAPACK(), alg.kwargs...))
195+
eig_full!(A, DV, QRIteration(; balanced = $balanced_val, alg.kwargs...))
221196
)
222197
Base.@deprecate(
223198
eig_vals!(A, D, alg::$lapack_algtype),
224-
eig_vals!(A, D, $algtype(; driver = LAPACK(), alg.kwargs...))
199+
eig_vals!(A, D, QRIteration(; balanced = $balanced_val, alg.kwargs...))
225200
)
226201
end
227202
end
228203
Base.@deprecate(
229204
eig_full!(A, DV, alg::CUSOLVER_Simple),
230-
eig_full!(A, DV, Simple(; driver = CUSOLVER(), alg.kwargs...))
205+
eig_full!(A, DV, QRIteration(; driver = CUSOLVER(), alg.kwargs...))
231206
)
232207
Base.@deprecate(
233208
eig_vals!(A, D, alg::CUSOLVER_Simple),
234-
eig_vals!(A, D, Simple(; driver = CUSOLVER(), alg.kwargs...))
209+
eig_vals!(A, D, QRIteration(; driver = CUSOLVER(), alg.kwargs...))
235210
)

src/implementations/eigh.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,8 @@ for f! in (:heevr!, :heevd!, :heev!, :heevx!)
108108
@eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...)
109109
end
110110

111-
supports_eigh(::Driver, ::Symbol) = false
112-
supports_eigh(::LAPACK, f::Symbol) = f in (:mrrr, :divide_and_conquer, :qr_iteration, :bisection)
113-
114111
for (f, f_lapack!, Alg) in (
115-
(:mrrr, :heevr!, :MultipleRelativelyRobustRepresentations),
112+
(:mrrr, :heevr!, :RobustRepresentations),
116113
(:divide_and_conquer, :heevd!, :DivideAndConquer),
117114
(:qr_iteration, :heev!, :QRIteration),
118115
(:bisection, :heevx!, :Bisection),
@@ -153,15 +150,11 @@ for (f, f_lapack!, Alg) in (
153150
# Implementation
154151
@eval begin
155152
function $f_eigh_full!(driver::Driver, A, Dd, V; fixgauge::Bool = default_fixgauge(), kwargs...)
156-
supports_eigh(driver, $(QuoteNode(f))) ||
157-
throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`")))
158153
$f_lapack!(driver, A, Dd, V; kwargs...)
159154
fixgauge && gaugefix!(eigh_full!, V)
160155
return Dd, V
161156
end
162157
function $f_eigh_vals!(driver::Driver, A, D, V; fixgauge::Bool = default_fixgauge(), kwargs...)
163-
supports_eigh(driver, $(QuoteNode(f))) ||
164-
throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`")))
165158
$f_lapack!(driver, A, D, V; kwargs...)
166159
return D
167160
end
@@ -213,7 +206,15 @@ end
213206

214207
# Deprecations
215208
# ------------
216-
for algtype in (:MultipleRelativelyRobustRepresentations, :DivideAndConquer, :QRIteration, :Bisection)
209+
Base.@deprecate(
210+
eigh_full!(A, DV, alg::LAPACK_MultipleRelativelyRobustRepresentations),
211+
eigh_full!(A, DV, RobustRepresentations(; driver = LAPACK(), alg.kwargs...))
212+
)
213+
Base.@deprecate(
214+
eigh_vals!(A, D, alg::LAPACK_MultipleRelativelyRobustRepresentations),
215+
eigh_vals!(A, D, RobustRepresentations(; driver = LAPACK(), alg.kwargs...))
216+
)
217+
for algtype in (:DivideAndConquer, :QRIteration, :Bisection)
217218
lapack_algtype = Symbol(:LAPACK_, algtype)
218219
@eval begin
219220
Base.@deprecate(

0 commit comments

Comments
 (0)