Skip to content

Commit 1f5fea2

Browse files
committed
Updates for TensorKit compatibility
1 parent c9727e4 commit 1f5fea2

3 files changed

Lines changed: 28 additions & 9 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,42 @@ import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gp
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1111
using CUDA
1212
using LinearAlgebra
13-
using LinearAlgebra: BlasFloat
13+
using LinearAlgebra: BlasFloat, eigvals!
1414

1515
include("yacusolver.jl")
1616

17-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
17+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
1818
return CUSOLVER_HouseholderQR(; kwargs...)
1919
end
20-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
20+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
2121
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2222
return LQViaTransposedQR(qr_alg)
2323
end
24-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
24+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
2525
return CUSOLVER_QRIteration(; kwargs...)
2626
end
27-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
27+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
2828
return CUSOLVER_Simple(; kwargs...)
2929
end
30-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix}
30+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
31+
return CUSOLVER_DivideAndConquer(; kwargs...)
32+
end
33+
34+
# include for block sector support
35+
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
36+
return CUSOLVER_HouseholderQR(; kwargs...)
37+
end
38+
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
39+
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
40+
return LQViaTransposedQR(qr_alg)
41+
end
42+
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
43+
return CUSOLVER_QRIteration(; kwargs...)
44+
end
45+
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
46+
return CUSOLVER_Simple(; kwargs...)
47+
end
48+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
3149
return CUSOLVER_DivideAndConquer(; kwargs...)
3250
end
3351

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for (bname, fname, elty, relty) in
2828
#! format: on
2929
chkstride1(A, U, Vᴴ, S)
3030
m, n = size(A)
31-
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
31+
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ($m) ≥ n ($n)"))
3232
minmn = min(m, n)
3333
if length(U) == 0
3434
jobu = 'N'

src/implementations/eig.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
102102
check_input(eig_vals!, A, D, alg)
103103
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
104104
if alg isa GPU_Simple
105-
isempty(alg.kwargs) ||
106-
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
105+
# TODO filter out nothing kwargs
106+
#isempty(alg.kwargs) ||
107+
# throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
107108
_gpu_geev!(A, D, V)
108109
end
109110
return D

0 commit comments

Comments
 (0)