Skip to content

Commit bf62657

Browse files
committed
change driver defaults for AMD/CUDA
1 parent c454434 commit bf62657

2 files changed

Lines changed: 2 additions & 14 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,7 @@ using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
18-
return ROCSOLVER_HouseholderQR(; kwargs...)
19-
end
20-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
21-
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
22-
return LQViaTransposedQR(qr_alg)
23-
end
17+
MatrixAlgebraKit.default_householder_driver(::StridedROCMatrix{<:BlasFloat}) = ROCSOLVER()
2418
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2519
return ROCSOLVER_QRIteration(; kwargs...)
2620
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
19-
return CUSOLVER_HouseholderQR(; kwargs...)
20-
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
22-
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
23-
return LQViaTransposedQR(qr_alg)
24-
end
18+
MatrixAlgebraKit.default_householder_driver(::StridedCuMatrix{<:BlasFloat}) = CUSOLVER()
2519
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2620
return CUSOLVER_QRIteration(; kwargs...)
2721
end

0 commit comments

Comments
 (0)