Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ for (fname, elty, relty) in

AMDGPU.unsafe_free!(dev_residual)
AMDGPU.unsafe_free!(dev_n_sweeps)
return U, S, Vᴴ
return (S, U, Vᴴ)
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!
using CUDA
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -30,6 +30,7 @@ _gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

end
61 changes: 60 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in
if jobz == 'V'
adjoint!(Vᴴ, Ṽ)
end
return U, S, Vᴴ
return S, U, Vᴴ
end
end
end

# Wrapper for randomized SVD
function Xgesvdr!(A::StridedCuMatrix{T},
S::StridedCuVector=similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2));
k::Int=length(S),
Comment thread
lkdvos marked this conversation as resolved.
p::Int=min(size(A)...)-k-1,
niters::Int=1) where {T<:BlasFloat}
chkstride1(A, U, S, Vᴴ)
m, n = size(A)
minmn = min(m, n)
jobu = length(U) == 0 ? 'N' : 'S'
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
R = eltype(S)
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
R == real(T) ||
throw(ArgumentError("S does not have the matching real `eltype` of A"))

Ṽ = similar(Vᴴ, (n, n))
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
lda = max(1, stride(A, 2))
ldu = max(1, stride(Ũ, 2))
ldv = max(1, stride(Ṽ, 2))
params = CUSOLVER.CuSolverParameters()
dh = CUSOLVER.dense_handle()

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
T, out_gpu, out_cpu)

return out_gpu[], out_cpu[]
end
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
bufferSize()...) do buffer_gpu, buffer_cpu
return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
T, buffer_gpu, sizeof(buffer_gpu),
buffer_cpu, sizeof(buffer_cpu),
dh.info)
end

flag = @allowscalar dh.info[1]
CUSOLVER.chklapackerror(BlasInt(flag))
if Ũ !== U && length(U) > 0
U .= view(Ũ, 1:m, 1:size(U, 2))
end
if length(Vᴴ) > 0
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
end
Ũ !== U && CUDA.unsafe_free!(Ũ)
CUDA.unsafe_free!(Ṽ)

return S, U, Vᴴ
Comment thread
kshyatt marked this conversation as resolved.
end

# for (jname, bname, fname, elty, relty) in
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi,
LQViaTransposedQR,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

Expand Down
45 changes: 44 additions & 1 deletion src/common/gauge.jl
Comment thread
kshyatt marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,8 +1,51 @@
function gaugefix!(V::AbstractMatrix)
Comment thread
lkdvos marked this conversation as resolved.
for j in axes(V, 2)
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmaxabs(v)))
@inbounds v .*= s
end
return V
end

function gaugefix!(::Val{:full}, U, S, Vᴴ, m::Int, n::Int)
for j in 1:max(m, n)
if j <= min(m, n)
u = view(U, :, j)
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(u)))
u .*= s
v .*= conj(s)
elseif j <= m
u = view(U, :, j)
s = conj(sign(_argmaxabs(u)))
u .*= s
else
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(v)))
v .*= s
end
end
return (U, S, Vᴴ)
end

function gaugefix!(::Val{:compact}, U, S, Vᴴ, m::Int, n::Int)
for j in 1:size(U, 2)
u = view(U, :, j)
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(u)))
u .*= s
v .*= conj(s)
end
return (U, S, Vᴴ)
end

function gaugefix!(::Val{:trunc}, U, S, Vᴴ, m::Int, n::Int)
Comment thread
kshyatt marked this conversation as resolved.
Outdated
for j in 1:min(m, n)
u = view(U, :, j)
v = view(Vᴴ, j, :)
s = conj(sign(_argmaxabs(u)))
u .*= s
v .*= conj(s)
end
return (U, S, Vᴴ)
end
8 changes: 4 additions & 4 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function copy_input(::typeof(eig_vals), A::AbstractMatrix)
end
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
D, V = DV
Expand All @@ -19,7 +19,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV)
@check_scalar(V, A, complex)
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D)
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
@assert D isa AbstractVector
Expand Down Expand Up @@ -51,7 +51,7 @@ end
# --------------
# actual implementation
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
check_input(eig_full!, A, DV)
check_input(eig_full!, A, DV, alg)
D, V = DV
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand All @@ -66,7 +66,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
end

function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
check_input(eig_vals!, A, D)
check_input(eig_vals!, A, D, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
end
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
D, V = DV
Expand All @@ -19,7 +19,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV)
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D)
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
@assert D isa AbstractVector
@check_size(D, (n,))
Expand Down Expand Up @@ -48,7 +48,7 @@ end
# Implementation
# --------------
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
check_input(eigh_full!, A, DV)
check_input(eigh_full!, A, DV, alg)
D, V = DV
Dd = D.diag
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
Expand All @@ -70,7 +70,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
end

function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
check_input(eigh_vals!, A, D)
check_input(eigh_vals!, A, D, alg)
V = similar(A, (size(A, 1), 0))
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
YALAPACK.heevr!(A, D, V; alg.kwargs...)
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/gen_eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix
return copy_input(gen_eig_full, A, B)
end

function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV)
function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
Expand All @@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
@check_scalar(V, B, complex)
return nothing
end
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W)
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
Expand Down Expand Up @@ -57,7 +57,7 @@ end
# --------------
# actual implementation
function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_full!, A, B, WV)
check_input(gen_eig_full!, A, B, WV, alg)
W, V = WV
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand All @@ -72,7 +72,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig
end

function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm)
check_input(gen_eig_vals!, A, B, W)
check_input(gen_eig_vals!, A, B, W, alg)
V = similar(A, complex(eltype(A)), (size(A, 1), 0))
if alg isa LAPACK_Simple
isempty(alg.kwargs) ||
Expand Down
18 changes: 9 additions & 9 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function copy_input(::typeof(lq_null), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end

function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
m, n = size(A)
L, Q = LQ
@assert L isa AbstractMatrix && Q isa AbstractMatrix
Expand All @@ -20,7 +20,7 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ)
@check_scalar(Q, A)
return nothing
end
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
m, n = size(A)
minmn = min(m, n)
L, Q = LQ
Expand All @@ -31,7 +31,7 @@ function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ)
@check_scalar(Q, A)
return nothing
end
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ)
function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm)
m, n = size(A)
minmn = min(m, n)
@assert Nᴴ isa AbstractMatrix
Expand Down Expand Up @@ -66,36 +66,36 @@ end
# --------------
# actual implementation
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_full!, A, LQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_full!, A, LQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_compact!, A, LQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_compact!, A, LQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
check_input(lq_null!, A, Nᴴ)
check_input(lq_null!, A, Nᴴ, alg)
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
return Nᴴ
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
check_input(lq_null!, A, Nᴴ)
check_input(lq_null!, A, Nᴴ, alg)
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
return Nᴴ
end
Expand Down
Loading
Loading