Skip to content

Commit 2d40186

Browse files
authored
Separate Algorithm and Driver for LQ/QR decompositions (#178)
* add Driver types * add `Householder` algorithm * refactor `Housholder` QR implementations * rename HouseholderReflection to avoid name clash * reorganize QR * make QR select Householder by default * merge QR implementations * also implement LQ * update algorithm list * update GLA QR implementation * some cleanup * remove reference to undefined docstring * update algorithm tests * fix JET errors * slight reimplementation * change driver defaults for AMD/CUDA * fix oopsie * make some strong assumptions on type stability * rework default_householder_driver * fix rebase conflicts * rework type stability
1 parent 95dd693 commit 2d40186

11 files changed

Lines changed: 527 additions & 502 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,28 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
7+
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
1111
using AMDGPU
1212
using LinearAlgebra
1313
using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
18-
return ROCSOLVER_HouseholderQR(; kwargs...)
19-
end
20-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
21-
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
22-
return LQViaTransposedQR(qr_alg)
23-
end
17+
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
2418
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2519
return ROCSOLVER_QRIteration(; kwargs...)
2620
end
2721
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2822
return ROCSOLVER_DivideAndConquer(; kwargs...)
2923
end
3024

31-
_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
32-
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
33-
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) =
34-
YArocSOLVER.unmqr!(side, trans, A, τ, C)
25+
for f in (:geqrf!, :ungqr!, :unmqr!)
26+
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
27+
end
28+
3529
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
3630
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
3731
# not yet supported

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
7+
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
9-
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
9+
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
1111
using CUDA, CUDA.CUBLAS
1212
using CUDA: i32
@@ -15,13 +15,7 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
19-
return CUSOLVER_HouseholderQR(; kwargs...)
20-
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
22-
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
23-
return LQViaTransposedQR(qr_alg)
24-
end
18+
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
2519
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2620
return CUSOLVER_QRIteration(; kwargs...)
2721
end
@@ -32,14 +26,12 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
3226
return CUSOLVER_DivideAndConquer(; kwargs...)
3327
end
3428

29+
for f in (:geqrf!, :ungqr!, :unmqr!)
30+
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
31+
end
32+
3533
_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
3634
YACUSOLVER.Xgeev!(A, D, V)
37-
_gpu_geqrf!(A::StridedCuMatrix) =
38-
YACUSOLVER.geqrf!(A)
39-
_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) =
40-
YACUSOLVER.ungqr!(A, τ)
41-
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) =
42-
YACUSOLVER.unmqr!(side, trans, A, τ, C)
4335
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
4436
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
4537
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge
5-
using MatrixAlgebraKit: left_orth_alg
4+
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
65
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
76
using LinearAlgebra: I, Diagonal, lmul!
87

@@ -57,81 +56,65 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
5756
return eigvals!(Hermitian(A); sortby = real)
5857
end
5958

60-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
61-
return GLA_HouseholderQR(; kwargs...)
62-
end
63-
64-
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
65-
check_input(qr_full!, A, QR, alg)
66-
Q, R = QR
67-
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
68-
end
69-
70-
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
71-
check_input(qr_compact!, A, QR, alg)
72-
Q, R = QR
73-
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
74-
end
75-
76-
function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR)
77-
check_input(qr_null!, A, N, alg)
78-
return _gla_householder_qr_null!(A, N; alg.kwargs...)
79-
end
80-
81-
function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false)
82-
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
83-
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
59+
function MatrixAlgebraKit.householder_qr!(
60+
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
61+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
62+
)
63+
blocksize == 1 ||
64+
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
65+
pivoted &&
66+
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
8467

8568
m, n = size(A)
86-
k = min(m, n)
69+
minmn = min(m, n)
70+
computeR = length(R) > 0
71+
72+
# compute QR
8773
Q̃, R̃ = qr!(A)
8874
lmul!(Q̃, MatrixAlgebraKit.one!(Q))
8975

9076
if positive
91-
@inbounds for j in 1:k
77+
@inbounds for j in 1:minmn
9278
s = sign_safe(R̃[j, j])
9379
@simd for i in 1:m
9480
Q[i, j] *= s
9581
end
9682
end
9783
end
9884

99-
computeR = length(R) > 0
10085
if computeR
10186
if positive
10287
@inbounds for j in n:-1:1
103-
@simd for i in 1:min(k, j)
88+
@simd for i in 1:min(minmn, j)
10489
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
10590
end
106-
@simd for i in (min(k, j) + 1):size(R, 1)
91+
@simd for i in (min(minmn, j) + 1):size(R, 1)
10792
R[i, j] = zero(eltype(R))
10893
end
10994
end
11095
else
111-
R[1:k, :] .=
112-
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
96+
R[1:minmn, :] .=
97+
MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
11398
end
11499
end
115100
return Q, R
116101
end
117102

118-
function _gla_householder_qr_null!(
119-
A::AbstractMatrix, N::AbstractMatrix;
120-
positive = true, blocksize = 1, pivoted = false
103+
function MatrixAlgebraKit.householder_qr_null!(
104+
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
105+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
121106
)
122-
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
123-
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
107+
blocksize == 1 ||
108+
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
109+
pivoted &&
110+
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
111+
124112
m, n = size(A)
125113
minmn = min(m, n)
126-
fill!(N, zero(eltype(N)))
114+
zero!(N)
127115
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
128116
Q̃, = qr!(A)
129-
lmul!(Q̃, N)
130-
return N
131-
end
132-
133-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
134-
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
117+
return lmul!(Q̃, N)
135118
end
136119

137120
MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg)

src/algorithms.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Finally, the same behavior is obtained when the keyword arguments are
8888
passed as the third positional argument in the form of a `NamedTuple`.
8989
""" select_algorithm
9090

91-
function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
91+
@inline function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
9292
if isnothing(alg)
9393
return default_algorithm(f, A; kwargs...)
9494
elseif alg isa Symbol
@@ -143,6 +143,59 @@ If this is not possible, for example when the output size is not known a priori
143143
this function may return `nothing`.
144144
""" initialize_output
145145

146+
147+
# Drivers
148+
# -------
149+
"""
150+
abstract type Driver
151+
152+
Supertype used for customizing various implementations of the same algorithm.
153+
"""
154+
abstract type Driver end
155+
156+
"""
157+
DefaultDriver <: Driver
158+
159+
Select a default driver at runtime, based on the input matrix.
160+
"""
161+
struct DefaultDriver <: Driver end
162+
163+
"""
164+
LAPACK <: Driver
165+
166+
Driver to select LAPACK as the implementation strategy.
167+
"""
168+
struct LAPACK <: Driver end
169+
170+
"""
171+
CUSOLVER <: Driver
172+
173+
Driver to select CUSOLVER as the implementation strategy.
174+
"""
175+
struct CUSOLVER <: Driver end
176+
177+
"""
178+
ROCSOLVER <: Driver
179+
180+
Driver to select ROCSOLVER as the implementation strategy.
181+
"""
182+
struct ROCSOLVER <: Driver end
183+
184+
"""
185+
GLA <: Driver
186+
187+
Driver to select GenericLinearAlgebra.jl as the implementation strategy.
188+
"""
189+
struct GLA <: Driver end
190+
191+
"""
192+
Native <: Driver
193+
194+
Driver to select a native implementation in MatrixAlgebraKit as the implementation strategy.
195+
"""
196+
struct Native <: Driver end
197+
198+
146199
# Truncation strategy
147200
# -------------------
148201
"""

src/common/householder.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
11
const IndexRange{T <: Integer} = Base.AbstractRange{T}
22

33
# Elementary Householder reflection
4-
struct Householder{T, V <: AbstractVector, R <: IndexRange}
4+
struct HouseholderReflection{T, V <: AbstractVector, R <: IndexRange}
55
β::T
66
v::V
77
r::R
88
end
9-
Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r)
9+
Base.adjoint(H::HouseholderReflection) = HouseholderReflection(conj(H.β), H.v, H.r)
1010

1111
function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
1212
i = findfirst(==(k), r)
1313
i == nothing && error("k = $k should be in the range r = $r")
1414
β, v, ν = _householder!(x[r], i)
15-
return Householder(β, v, r), ν
15+
return HouseholderReflection(β, v, r), ν
1616
end
1717
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
1818
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
1919
i = findfirst(==(k), r)
2020
i == nothing && error("k = $k should be in the range r = $r")
2121
β, v, ν = _householder!(A[r, col], i)
22-
return Householder(β, v, r), ν
22+
return HouseholderReflection(β, v, r), ν
2323
end
2424
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
2525
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
2626
i = findfirst(==(k), r)
2727
i == nothing && error("k = $k should be in the range r = $r")
2828
β, v, ν = _householder!(conj!(A[row, r]), i)
29-
return Householder(β, v, r), ν
29+
return HouseholderReflection(β, v, r), ν
3030
end
3131

3232
# generate Householder vector based on vector v, such that applying the reflection
@@ -66,7 +66,7 @@ function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
6666
return β, v, ν
6767
end
6868

69-
function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
69+
function LinearAlgebra.lmul!(H::HouseholderReflection, x::AbstractVector)
7070
v = H.v
7171
r = H.r
7272
β = H.β
@@ -87,7 +87,7 @@ function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
8787
end
8888
return x
8989
end
90-
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
90+
function LinearAlgebra.lmul!(H::HouseholderReflection, A::AbstractMatrix; cols = axes(A, 2))
9191
v = H.v
9292
r = H.r
9393
β = H.β
@@ -110,7 +110,7 @@ function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2
110110
end
111111
return A
112112
end
113-
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
113+
function LinearAlgebra.rmul!(A::AbstractMatrix, H::HouseholderReflection; rows = axes(A, 1))
114114
v = H.v
115115
r = H.r
116116
β = H.β

0 commit comments

Comments
 (0)