Skip to content

Commit 6cc7924

Browse files
lkdvosJutho
andauthored
Separate Algorithm and Driver - part III (Schur-Eig) (#194)
* Add eigenvalue algorithm types * Add implementations * docstring updates * also implement extensions * Also implement Schur decomposition * slight reorganization of `supports_f` * format * small improvement * various renaming and cleanup * also remove `supports_svd` * docs update * one more round of eig changes * update docstring * Update decompositions.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * special care for deprecated * fix ambiguity --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent e95d486 commit 6cc7924

19 files changed

Lines changed: 430 additions & 310 deletions

docs/src/user_interface/decompositions.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ The following algorithms are available for the hermitian eigenvalue decompositio
8181

8282
```@autodocs; canonical=false
8383
Modules = [MatrixAlgebraKit]
84-
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EighAlgorithm
84+
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EighAlgorithms
8585
```
8686

8787
### Eigenvalue Decomposition
@@ -103,7 +103,7 @@ The following algorithms are available for the standard eigenvalue decomposition
103103

104104
```@autodocs; canonical=false
105105
Modules = [MatrixAlgebraKit]
106-
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
106+
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EigAlgorithms
107107
```
108108

109109
## Schur Decomposition
@@ -123,7 +123,7 @@ The following algorithms are available for the Schur decomposition:
123123

124124
```@autodocs; canonical=false
125125
Modules = [MatrixAlgebraKit]
126-
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
126+
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SchurAlgorithms
127127
```
128128

129129
## Singular Value Decomposition

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
10+
import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
11+
import MatrixAlgebraKit: _sylvester, svd_rank
1112
using AMDGPU
1213
using LinearAlgebra
1314
using LinearAlgebra: BlasFloat
@@ -20,14 +21,13 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
2021
return QRIteration(; kwargs...)
2122
end
2223
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
23-
return ROCSOLVER_DivideAndConquer(; kwargs...)
24+
return DivideAndConquer(; kwargs...)
2425
end
2526

2627
for f in (:geqrf!, :ungqr!, :unmqr!)
2728
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
2829
end
2930

30-
MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
3131
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
3232

3333
function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
@@ -42,13 +42,13 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid
4242
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
4343
end
4444

45-
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
45+
heevj!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
4646
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
47-
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
47+
heevd!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
4848
YArocSOLVER.heevd!(A, Dd, V; kwargs...)
49-
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
49+
heev!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
5050
YArocSOLVER.heev!(A, Dd, V; kwargs...)
51-
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
51+
heevx!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
5252
YArocSOLVER.heevx!(A, Dd, V; kwargs...)
5353

5454
function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
77
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
10-
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
10+
import MatrixAlgebraKit: heevj!, heevd!, geev!
11+
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
1112
using CUDA, CUDA.CUBLAS
1213
using CUDA: i32
1314
using LinearAlgebra
@@ -21,18 +22,17 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
2122
return QRIteration(; kwargs...)
2223
end
2324
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
24-
return CUSOLVER_Simple(; kwargs...)
25+
return QRIteration(; kwargs...)
2526
end
2627
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
27-
return CUSOLVER_DivideAndConquer(; kwargs...)
28+
return DivideAndConquer(; kwargs...)
2829
end
2930

3031

3132
for f in (:geqrf!, :ungqr!, :unmqr!)
3233
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
3334
end
3435

35-
MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
3636
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
3737

3838
function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
@@ -53,12 +53,12 @@ gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
5353
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
5454
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
5555

56-
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
57-
YACUSOLVER.Xgeev!(A, D, V)
56+
geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
57+
YACUSOLVER.Xgeev!(A, Dd, V)
5858

59-
_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
59+
heevj!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
6060
YACUSOLVER.heevj!(A, Dd, V; kwargs...)
61-
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
61+
heevd!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
6262
YACUSOLVER.heevd!(A, Dd, V; kwargs...)
6363

6464
function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue)

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@ module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
5-
using MatrixAlgebraKit: GLA
6-
import MatrixAlgebraKit: gesvd!
5+
using MatrixAlgebraKit: GLA, Driver
6+
import MatrixAlgebraKit: gesvd!, heev!
77
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
88
using LinearAlgebra: I, Diagonal, lmul!
99

1010
const GlaFloat = Union{BigFloat, Complex{BigFloat}}
1111
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
1212
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()
1313

14-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
15-
return QRIteration(; kwargs...)
14+
MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration
15+
16+
function MatrixAlgebraKit.default_svd_algorithm(
17+
::Type{T};
18+
driver::Driver = GLA(), kwargs...
19+
) where {T <: GlaStridedVecOrMatrix}
20+
return QRIteration(; driver, kwargs...)
1621
end
1722

1823
function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
@@ -32,20 +37,20 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
3237
return S, U, Vᴴ
3338
end
3439

35-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
36-
return GLA_QRIteration(; kwargs...)
37-
end
38-
39-
MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing)
40-
MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
41-
42-
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
43-
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
44-
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
40+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix}
41+
return QRIteration(; driver, kwargs...)
4542
end
4643

47-
function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
48-
return eigvals!(Hermitian(A); sortby = real)
44+
function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
45+
if length(V) > 0
46+
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
47+
copyto!(Dd, eigval)
48+
copyto!(V, eigvec)
49+
else
50+
eigval = eigvals!(Hermitian(A); sortby = real)
51+
copyto!(Dd, eigval)
52+
end
53+
return Dd, V
4954
end
5055

5156
function MatrixAlgebraKit.householder_qr!(
Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
11
module MatrixAlgebraKitGenericSchurExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: check_input
4+
using MatrixAlgebraKit: check_input, GS, Driver
5+
import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals!
56
using LinearAlgebra: Diagonal, sorteig!
67
using GenericSchur
78

8-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
9-
return GS_QRIteration(; kwargs...)
10-
end
11-
12-
MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing)
13-
MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing
14-
15-
function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
16-
D, V = GenericSchur.eigen!(A)
17-
return Diagonal(D), V
18-
end
9+
const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}
1910

20-
function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
21-
return GenericSchur.eigvals!(A)
11+
function MatrixAlgebraKit.default_eig_algorithm(
12+
::Type{T}; driver::Driver = GS(), kwargs...
13+
) where {T <: StridedMatrix{<:GSFloat}}
14+
return QRIteration(; driver, kwargs...)
2215
end
2316

24-
function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration)
25-
check_input(schur_full!, A, TZv, alg)
26-
T, Z, vals = TZv
27-
S = GenericSchur.gschur(A)
28-
copyto!(T, S.T)
29-
copyto!(Z, S.Z)
30-
copyto!(vals, S.values)
31-
return T, Z, vals
17+
function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
18+
D, Vmat = GenericSchur.eigen!(A)
19+
copyto!(Dd, D)
20+
length(V) > 0 && copyto!(V, Vmat)
21+
return Dd, V
3222
end
3323

34-
function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration)
35-
check_input(schur_vals!, A, vals, alg)
24+
function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector)
3625
S = GenericSchur.gschur(A)
26+
copyto!(A, S.T)
27+
length(Z) > 0 && copyto!(Z, S.Z)
3728
copyto!(vals, sorteig!(S.values))
38-
return vals
29+
return A, Z, vals
3930
end
4031

32+
Base.@deprecate(
33+
eig_full!(A, DV, alg::GS_QRIteration),
34+
eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...))
35+
)
36+
Base.@deprecate(
37+
eig_vals!(A, D, alg::GS_QRIteration),
38+
eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...))
39+
)
40+
41+
Base.@deprecate(
42+
schur_full!(A, TZv, alg::GS_QRIteration),
43+
schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...))
44+
)
45+
Base.@deprecate(
46+
schur_vals!(A, vals, alg::GS_QRIteration),
47+
schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...))
48+
)
49+
4150
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
3333

3434
export Householder, Native_HouseholderQR, Native_HouseholderLQ
3535
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
36+
export RobustRepresentations
3637
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3738
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3839
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer

src/algorithms.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,13 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
212212
"""
213213
struct Native <: Driver end
214214

215+
"""
216+
GS <: Driver
217+
218+
Driver to select GenericSchur.jl as the implementation strategy.
219+
"""
220+
struct GS <: Driver end
221+
215222
# In order to avoid amibiguities, this method is implemented in a tiered way
216223
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
217224
# default_driver(Talg, TA) -> default_driver(TA)

0 commit comments

Comments
 (0)