Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ Zygote = "0.7"
julia = "1.10"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
31 changes: 25 additions & 6 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,49 @@ using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
using CUDA
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
using LinearAlgebra: BlasFloat

using CUDA: i32

Comment thread
kshyatt marked this conversation as resolved.
Outdated
include("yacusolver.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuMatrix}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

# include for block sector support
# Default QR algorithm for a 2-dimensional `Base.ReshapedArray` that wraps a
# 1-dimensional `UnitRange` view of a `CuVecOrMat` with `BlasFloat` element type.
# All keyword arguments are forwarded to `CUSOLVER_HouseholderQR`.
function MatrixAlgebraKit.default_qr_algorithm(
        ::Type{
            Base.ReshapedArray{
                T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}
            }
        };
        kwargs...
    ) where {T <: BlasFloat, A <: CuVecOrMat{T}}
    return CUSOLVER_HouseholderQR(; kwargs...)
end
# Default LQ algorithm for a 2-dimensional `Base.ReshapedArray` that wraps a
# 1-dimensional `UnitRange` view of a `CuVecOrMat` with `BlasFloat` element type.
# The keyword arguments configure a `CUSOLVER_HouseholderQR`, which is then wrapped
# in `LQViaTransposedQR`.
function MatrixAlgebraKit.default_lq_algorithm(
        ::Type{
            Base.ReshapedArray{
                T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}
            }
        };
        kwargs...
    ) where {T <: BlasFloat, A <: CuVecOrMat{T}}
    return LQViaTransposedQR(CUSOLVER_HouseholderQR(; kwargs...))
end
# Default SVD algorithm for a 2-dimensional `Base.ReshapedArray` that wraps a
# 1-dimensional `UnitRange` view of a `CuVecOrMat` with `BlasFloat` element type.
# All keyword arguments are forwarded to `CUSOLVER_Jacobi`.
function MatrixAlgebraKit.default_svd_algorithm(
        ::Type{
            Base.ReshapedArray{
                T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}
            }
        };
        kwargs...
    ) where {T <: BlasFloat, A <: CuVecOrMat{T}}
    return CUSOLVER_Jacobi(; kwargs...)
end
# Default general eigenvalue algorithm for a 2-dimensional `Base.ReshapedArray` that
# wraps a 1-dimensional `UnitRange` view of a `CuVecOrMat` with `BlasFloat` element
# type. All keyword arguments are forwarded to `CUSOLVER_Simple`.
function MatrixAlgebraKit.default_eig_algorithm(
        ::Type{
            Base.ReshapedArray{
                T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}
            }
        };
        kwargs...
    ) where {T <: BlasFloat, A <: CuVecOrMat{T}}
    return CUSOLVER_Simple(; kwargs...)
end
# Default Hermitian eigenvalue algorithm for a 2-dimensional `Base.ReshapedArray`
# that wraps a 1-dimensional `UnitRange` view of a `CuVecOrMat` with `BlasFloat`
# element type. All keyword arguments are forwarded to `CUSOLVER_DivideAndConquer`.
function MatrixAlgebraKit.default_eigh_algorithm(
        ::Type{
            Base.ReshapedArray{
                T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}
            }
        };
        kwargs...
    ) where {T <: BlasFloat, A <: CuVecOrMat{T}}
    return CUSOLVER_DivideAndConquer(; kwargs...)
end

# CUDA method of the `_gpu_geev!` hook: hands `A`, `D` and `V` straight to
# `YACUSOLVER.Xgeev!` and returns whatever that call returns.
function _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix)
    return YACUSOLVER.Xgeev!(A, D, V)
end
Expand Down
7 changes: 5 additions & 2 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in
)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
(m < n) && throw(ArgumentError(lazy"CUSOLVER's gesvd requires m ($m) ≥ n ($n)"))
minmn = min(m, n)
if length(U) == 0
jobu = 'N'
Expand Down Expand Up @@ -191,14 +191,17 @@ for (bname, fname, elty, relty) in
(:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64),
)
@eval begin
#! format: off
function gesvdj!(
A::StridedCuMatrix{$elty},
S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)),
U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
tol::$relty = eps($relty),
max_sweeps::Int = 100
max_sweeps::Int = 100,
kwargs...
)
#! format: on
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
Expand Down
5 changes: 3 additions & 2 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm)
check_input(eig_vals!, A, D, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa GPU_Simple
isempty(alg.kwargs) ||
throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments"))
# TODO: filter out kwargs whose value is `nothing`
#isempty(alg.kwargs) ||
# throw(ArgumentError("GPU_Simple (geev) does not accept any keyword arguments"))
Comment thread
kshyatt marked this conversation as resolved.
Outdated
_gpu_geev!(A, D, V)
end
return D
Expand Down
1 change: 1 addition & 0 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
@check_scalar(V, A)
return nothing
end

function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
check_hermitian(A, alg)
@assert isdiag(A)
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ end
# Reduce `x` to its entry of largest absolute value (the entry itself, not its
# index). The reduction starts from `zero(eltype(x))`, which is therefore the
# result for an empty `x`.
function _argmaxabs(x)
    return reduce(_largest, x; init = zero(eltype(x)))
end

# Return whichever of `x` and `y` has the larger absolute value; `x` wins ties.
function _largest(x, y)
    if abs(x) < abs(y)
        return y
    end
    return x
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
check_input(svd_vals!, A, S, alg)
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
if alg isa GPU_QRIteration
Expand Down
2 changes: 1 addition & 1 deletion test/amd/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64)
@test N isa ROCMatrix{T} && size(N) == (m, m - minmn)
@test V * C ≈ A
@test isisometric(V)
@test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T)
@test isisometric(N)
hV = collect(V)
hN = collect(N)
Expand Down
2 changes: 0 additions & 2 deletions test/amd/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using AMDGPU
k = min(m, n)
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n < m && svd_alg isa ROCSOLVER_QRIteration && continue
A = ROCArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
W, P = left_polar(A; alg)
Expand Down Expand Up @@ -52,7 +51,6 @@ end
k = min(m, n)
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n > m && svd_alg isa ROCSOLVER_QRIteration && continue
A = ROCArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
P, Wᴴ = right_polar(A; alg)
Expand Down
2 changes: 0 additions & 2 deletions test/cuda/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using CUDA
k = min(m, n)
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n < m && svd_alg isa CUSOLVER_QRIteration && continue
A = CuArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
W, P = left_polar(A; alg)
Expand Down Expand Up @@ -52,7 +51,6 @@ end
k = min(m, n)
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
@testset "algorithm $svd_alg" for svd_alg in svd_algs
n > m && svd_alg isa CUSOLVER_QRIteration && continue
A = CuArray(randn(rng, T, m, n))
alg = PolarViaSVD(svd_alg)
P, Wᴴ = right_polar(A; alg)
Expand Down
4 changes: 2 additions & 2 deletions test/cuda/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
m = 54
noisefactor = eps(real(T))^(3 / 4)
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
A = CuArray(randn(rng, T, m, m))
A = CuArray(randn(rng, T, m, m))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)
Expand Down Expand Up @@ -69,7 +69,7 @@ end
# test that W is closer to A than any other isometry
for k in 1:10
δA = CuArray(randn(rng, T, m, n))
W = project_isometric(A, alg)
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
end
Expand Down
Loading