Skip to content

Commit dc44f9c

Browse files
committed
default_driver
1 parent 27beef0 commit dc44f9c

7 files changed

Lines changed: 37 additions & 26 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER()
18-
MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER()
19-
MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER()
17+
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER()
18+
2019
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
2120
return QRIteration(; kwargs...)
2221
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
19-
MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER()
20-
MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER()
21-
MatrixAlgebraKit.default_svd_polar_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER()
18+
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
19+
2220
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2321
return QRIteration(; kwargs...)
2422
end

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import MatrixAlgebraKit: gesvd!
77
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
88
using LinearAlgebra: I, Diagonal, lmul!
99

10-
MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}) = GLA()
10+
const GlaFloat = Union{BigFloat, Complex{BigFloat}}
11+
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
12+
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()
1113

12-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
14+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
1315
return QRIteration(; kwargs...)
1416
end
1517

@@ -30,7 +32,7 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
3032
return S, U, Vᴴ
3133
end
3234

33-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
35+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
3436
return GLA_QRIteration(; kwargs...)
3537
end
3638

src/algorithms.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
212212
"""
213213
struct Native <: Driver end
214214

215+
# In order to avoid amibiguities, this method is implemented in a tiered way
216+
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
217+
# default_driver(Talg, TA) -> default_driver(TA)
218+
# This is to try and minimize ambiguity while allowing overloading at multiple levels
219+
@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A))
220+
@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A))
221+
@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA)
222+
223+
# defaults
224+
default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback
225+
default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK()
226+
227+
# wrapper types
228+
@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
229+
@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
230+
@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A)
231+
@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A)
215232

216233
# Truncation strategy
217234
# -------------------

src/implementations/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ end
115115
@inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
116116
householder_lq!(driver, A, L, Q; kwargs...)
117117
householder_lq!(::DefaultDriver, A, L, Q; kwargs...) =
118-
householder_lq!(default_householder_driver(A), A, L, Q; kwargs...)
118+
householder_lq!(default_driver(Householder, A), A, L, Q; kwargs...)
119119
householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
120120
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
121121
function householder_lq!(
@@ -221,7 +221,7 @@ end
221221
@inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
222222
householder_lq_null!(driver, A, Nᴴ; kwargs...)
223223
householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) =
224-
householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...)
224+
householder_lq_null!(default_driver(Householder, A), A, Nᴴ; kwargs...)
225225
householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
226226
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
227227
function householder_lq_null!(

src/implementations/qr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ end
117117
@inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
118118
householder_qr!(driver, A, Q, R; kwargs...)
119119
householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
120-
householder_qr!(default_householder_driver(A), A, Q, R; kwargs...)
120+
householder_qr!(default_driver(Householder, A), A, Q, R; kwargs...)
121121
function householder_qr!(
122122
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
123123
positive::Bool = true, pivoted::Bool = false,
@@ -248,7 +248,7 @@ end
248248
@inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
249249
householder_qr_null!(driver, A, N; kwargs...)
250250
householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
251-
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
251+
householder_qr_null!(default_driver(Householder, A), A, N; kwargs...)
252252
function householder_qr_null!(
253253
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
254254
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0

src/implementations/svd.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,18 +171,13 @@ for (f, f_lapack!, Alg) in (
171171

172172
# driver
173173
@eval begin
174-
@inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) =
175-
$f_svd!(driver, A, U, S, Vᴴ; kwargs...)
176-
@inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) =
177-
$f_svd_full!(driver, A, U, S, Vᴴ; kwargs...)
178-
@inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) =
179-
$f_svd_vals!(driver, A, S; kwargs...)
180-
@inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) =
181-
$f_svd!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...)
182-
@inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) =
183-
$f_svd_full!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...)
184-
@inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) =
185-
$f_svd_vals!($(Symbol(:default_, f, :_driver))(A), A, S; kwargs...)
174+
@inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd!(driver, A, U, S, Vᴴ; kwargs...)
175+
@inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_full!(driver, A, U, S, Vᴴ; kwargs...)
176+
@inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_vals!(driver, A, S; kwargs...)
177+
178+
@inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...)
179+
@inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd_full!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...)
180+
@inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = $f_svd_vals!(default_driver($Alg, A), A, S; kwargs...)
186181
end
187182

188183
# Implementation

0 commit comments

Comments
 (0)