Skip to content

Commit 4e208e3

Browse files
committed
change driver defaults for AMD/CUDA
1 parent e879f40 commit 4e208e3

2 files changed

Lines changed: 8 additions & 25 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,7 @@ using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
18-
return ROCSOLVER_HouseholderQR(; kwargs...)
19-
end
20-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
21-
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
22-
return LQViaTransposedQR(qr_alg)
23-
end
17+
MatrixAlgebraKit.default_householder_driver(::StridedROCMatrix{<:BlasFloat}) = ROCSOLVER()
2418
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2519
return ROCSOLVER_QRIteration(; kwargs...)
2620
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
19-
return CUSOLVER_HouseholderQR(; kwargs...)
20-
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
22-
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
23-
return LQViaTransposedQR(qr_alg)
24-
end
18+
MatrixAlgebraKit.default_householder_driver(::StridedCuMatrix{<:BlasFloat}) = CUSOLVER()
2519
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2620
return CUSOLVER_QRIteration(; kwargs...)
2721
end
@@ -32,22 +26,17 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
3226
return CUSOLVER_DivideAndConquer(; kwargs...)
3327
end
3428

35-
3629
# include for block sector support
37-
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
38-
return CUSOLVER_HouseholderQR(; kwargs...)
39-
end
40-
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
41-
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
42-
return LQViaTransposedQR(qr_alg)
43-
end
44-
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
30+
const BlockView{T, A} = Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}}
31+
32+
MatrixAlgebraKit.default_householder_driver(::BlockView{T, A}) where {T <: BlasFloat, A <: CuVecOrMat{T}} = CUSOLVER()
33+
function MatrixAlgebraKit.default_svd_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
4534
return CUSOLVER_Jacobi(; kwargs...)
4635
end
47-
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
36+
function MatrixAlgebraKit.default_eig_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
4837
return CUSOLVER_Simple(; kwargs...)
4938
end
50-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
39+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{BlockView{T, A}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
5140
return CUSOLVER_DivideAndConquer(; kwargs...)
5241
end
5342

0 commit comments

Comments
 (0)