From 2e4451b8f3bcac0240809bf46288223605173fa3 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 17 Apr 2025 15:24:46 +0200 Subject: [PATCH 01/16] first cuda commit - qr support --- Project.toml | 2 + .../MatrixAlgebraKitCUDAExt.jl | 9 + .../implementations/qr.jl | 104 ++++ ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 585 ++++++++++++++++++ 4 files changed, 700 insertions(+) create mode 100644 ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl create mode 100644 ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl create mode 100644 ext/MatrixAlgebraKitCUDAExt/yacusolver.jl diff --git a/Project.toml b/Project.toml index ec99e515f..7dd5762ea 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" +MatrixAlgebraKitCUDAExt = "CUDA" [compat] Aqua = "0.6, 0.7, 0.8" diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl new file mode 100644 index 000000000..6b2cfdba6 --- /dev/null +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -0,0 +1,9 @@ +module MatrixAlgebraKitChainRulesCoreExt + +using MatrixAlgebraKit +using CUDA + +include("yacusolver.jl") +inculde("implementations/qr.jl") + +end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl new file mode 100644 index 000000000..495b2981e --- /dev/null +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl @@ -0,0 +1,104 @@ +""" + CUSOLVER_HouseholderQR(; positive = false) + +Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of +a matrix using Householder reflectors. 
The keyword `positive=true` can be used to ensure that +the diagonal elements of `R` are non-negative. +""" +@algdef CUSOLVER_HouseholderQR + +# Outputs +# ------- +function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), A::AbstractMatrix, + ::CUSOLVER_HouseholderQR) + m, n = size(A) + Q = similar(A, (m, m)) + R = similar(A, (m, n)) + return (Q, R) +end +function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), A::AbstractMatrix, + ::CUSOLVER_HouseholderQR) + m, n = size(A) + minmn = min(m, n) + Q = similar(A, (m, minmn)) + R = similar(A, (minmn, n)) + return (Q, R) +end +function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), A::AbstractMatrix, + ::CUSOLVER_HouseholderQR) + m, n = size(A) + minmn = min(m, n) + N = similar(A, (m, m - minmn)) + return N +end + +# Implementation +# -------------- +# actual implementation +function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR) + check_input(qr_full!, A, QR) + Q, R = QR + _cusolver_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR) + check_input(qr_compact!, A, QR) + Q, R = QR + _cusolver_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::CUSOLVER_HouseholderQR) + check_input(qr_null!, A, N) + _cusolver_qr_null!(A, N; alg.kwargs...) 
+ return N +end + +function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; + positive=false, blocksize=1) + blocksize > 1 && + throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition")) + m, n = size(A) + minmn = min(m, n) + computeR = length(R) > 0 + inplaceQ = Q === A + if inplaceQ && (computeR || positive || m < n) + throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`")) + end + + A, τ = YACUSOLVER.geqrf!(A) + if inplaceQ + Q = YACUSOLVER.ungqr!(A, τ) + else + Q = YACUSOLVER.unmqr!('L', 'N', A, τ, one!(Q)) + end + # henceforth, τ is no longer needed and can be reused + + if positive # already fix Q even if we do not need R + τ .= sign_safe.(diagview(A)) + Q = rmul!(Q, Diagonal(τ)) + end + + if computeR + R̃ = triu!(view(A, axes(R)...)) + if positive + R̃ = lmul!(Diagonal(τ)', R̃) + end + copyto!(R, R̃) + end + return Q, R +end + +function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix; + positive=false, + pivoted=false, + blocksize=1) + blocksize > 1 && + throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition")) + m, n = size(A) + minmn = min(m, n) + fill!(N, zero(eltype(N))) + one!(view(N, (minmn + 1):m, 1:(m - minmn))) + A, τ = YACUSOLVER.geqrf!(A) + N = unmqr!('L', 'N', A, τ, N) + return N +end diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl new file mode 100644 index 000000000..ff1fc167e --- /dev/null +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -0,0 +1,585 @@ +module YACUSOLVER + +using LinearAlgebra +using LinearAlgebra: BlasInt, checksquare, chkstride1, require_one_based_indexing +using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo + +using CUDA +using CUDA.CUSOLVER: cusolverDnCreate + +# QR methods are implemented with full access to allocated arrays, so we do not need to redo this: 
+using CUDA.CUSOLVER: geqrf!, ormqr!, orgqr! +const unmqr! = ormqr! +const ungqr! = orgqr! + +# # Wrapper for SVD via QR Iteration +# for (bname, fname, elty, relty) in +# ((:cusolverDnSgesvd_bufferSize, :cusolverDnSgesvd, :Float32, :Float32), +# (:cusolverDnDgesvd_bufferSize, :cusolverDnDgesvd, :Float64, :Float64), +# (:cusolverDnCgesvd_bufferSize, :cusolverDnCgesvd, :ComplexF32, :Float32), +# (:cusolverDnZgesvd_bufferSize, :cusolverDnZgesvd, :ComplexF64, :Float64)) +# @eval begin +# function gesvd!(A::StridedCuMatrix{$elty}, +# S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), +# U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), +# min(size(A)...)), +# Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), +# size(A, 2))) +# chkstride1(A, U, Vᴴ, S) +# m, n = size(A) +# (m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n")) +# minmn = min(m, n) +# lda = max(1, stride(A, 2)) + +# if length(U) == 0 +# jobu = 'N' +# else +# size(U, 1) == m || +# throw(DimensionMismatch("row size mismatch between A and U")) +# if size(U, 2) == minmn +# if U === A +# jobu = 'O' +# else +# jobu = 'S' +# end +# elseif size(U, 2) == m +# jobu = 'A' +# else +# throw(DimensionMismatch("invalid column size of U")) +# end +# end +# if length(Vᴴ) == 0 +# jobvt = 'N' +# else +# size(Vᴴ, 2) == n || +# throw(DimensionMismatch("column size mismatch between A and Vᴴ")) +# if size(Vᴴ, 1) == minmn +# if Vᴴ === A +# jobvt = 'O' +# else +# jobvt = 'S' +# end +# elseif size(Vᴴ, 1) == n +# jobvt = 'A' +# else +# throw(DimensionMismatch("invalid row size of Vᴴ")) +# end +# end +# length(S) == minmn || +# throw(DimensionMismatch("length mismatch between A and S")) + +# lda = max(1, stride(A, 2)) +# ldu = max(1, stride(U, 2)) +# ldv = max(1, stride(Vᴴ, 2)) + +# dh = dense_handle() +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, m, n, out) +# return out[] * sizeof($elty) +# end +# rwork = CuArray{$relty}(undef, min(m, n) - 1) +# 
with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobu, jobvt, m, n, A, lda, S, U, ldu, Vᴴ, ldv, +# buffer, sizeof(buffer) ÷ sizeof($elty), rwork, dh.info) +# end +# unsafe_free!(rwork) + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# return (S, U, Vᴴ) +# end +# end +# end + +# # Wrapper for SVD via Jacobi +# for (bname, fname, elty, relty) in +# ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32), +# (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64), +# (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32), +# (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64)) +# @eval begin +# function gesvdj!(A::StridedCuMatrix{$elty}, +# S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), +# U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), +# min(size(A)...)), +# Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), +# size(A, 2)); +# tol::$relty=eps($relty), +# max_sweeps::Int=100) +# chkstride1(A, U, Vᴴ, S) +# m, n = size(A) +# minmn = min(m, n) +# lda = max(1, stride(A, 2)) + +# if length(U) == 0 && length(Vᴴ) == 0 +# jobz = 'N' +# econ = 0 +# else +# jobz = 'V' +# size(U, 1) == m || +# throw(DimensionMismatch("row size mismatch between A and U")) +# size(Vᴴ, 2) == n || +# throw(DimensionMismatch("column size mismatch between A and Vᴴ")) +# if size(U, 2) == size(Vᴴ, 1) == minmn +# econ = 1 +# elseif size(U, 2) == m && size(Vᴴ, 1) == n +# econ = 0 +# else +# throw(DimensionMismatch("invalid column size of U or row size of Vᴴ")) +# end +# end +# length(S) == minmn || +# throw(DimensionMismatch("length mismatch between A and S")) + +# if jobz == 'N' # it seems we still need the memory for U and Vᴴ +# U = similar(A, $elty, m, minmn) +# V = similar(A, $elty, n, minmn) +# else +# V = similar(Vᴴ') +# end +# lda = max(1, stride(A, 2)) +# ldu = max(1, stride(U, 2)) +# ldv = max(1, stride(V, 2)) + +# params = 
Ref{gesvdjInfo_t}(C_NULL) +# cusolverDnCreateGesvdjInfo(params) +# cusolverDnXgesvdjSetTolerance(params[], tol) +# cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) +# dh = dense_handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, +# out, params[]) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# cusolverDnDestroyGesvdjInfo(params[]) + +# if jobz != 'N' +# adjoint!(Vᴴ, V) +# end +# return U, S, Vᴴ +# end +# end +# end + +# for (jname, bname, fname, elty, relty) in +# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32), +# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64), +# (:hegvd!, :cusolverDnChegvd_bufferSize, :cusolverDnChegvd, :ComplexF32, :Float32), +# (:hegvd!, :cusolverDnZhegvd_bufferSize, :cusolverDnZhegvd, :ComplexF64, :Float64)) +# @eval begin +# function $jname(itype::Int, +# jobz::Char, +# uplo::Char, +# A::StridedCuMatrix{$elty}, +# B::StridedCuMatrix{$elty}) +# chkuplo(uplo) +# nA, nB = checksquare(A, B) +# if nB != nA +# throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!")) +# end +# n = nA +# lda = max(1, stride(A, 2)) +# ldb = max(1, stride(B, 2)) +# W = CuArray{$relty}(undef, n) +# dh = dense_handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, out) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# 
return W, A, B +# end +# end +# end +# end + +# for (jname, bname, fname, elty, relty) in +# ((:sygvj!, :cusolverDnSsygvj_bufferSize, :cusolverDnSsygvj, :Float32, :Float32), +# (:sygvj!, :cusolverDnDsygvj_bufferSize, :cusolverDnDsygvj, :Float64, :Float64), +# (:hegvj!, :cusolverDnChegvj_bufferSize, :cusolverDnChegvj, :ComplexF32, :Float32), +# (:hegvj!, :cusolverDnZhegvj_bufferSize, :cusolverDnZhegvj, :ComplexF64, :Float64)) +# @eval begin +# function $jname(itype::Int, +# jobz::Char, +# uplo::Char, +# A::StridedCuMatrix{$elty}, +# B::StridedCuMatrix{$elty}; +# tol::$relty=eps($relty), +# max_sweeps::Int=100) +# chkuplo(uplo) +# nA, nB = checksquare(A, B) +# if nB != nA +# throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!")) +# end +# n = nA +# lda = max(1, stride(A, 2)) +# ldb = max(1, stride(B, 2)) +# W = CuArray{$relty}(undef, n) +# params = Ref{syevjInfo_t}(C_NULL) +# cusolverDnCreateSyevjInfo(params) +# cusolverDnXsyevjSetTolerance(params[], tol) +# cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) +# dh = dense_handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, +# out, params[]) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# cusolverDnDestroySyevjInfo(params[]) + +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A, B +# end +# end +# end +# end + +# for (jname, bname, fname, elty, relty) in +# ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, +# :Float32, :Float32), +# (:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, +# :Float64, :Float64), +# (:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, +# :ComplexF32, 
:Float32), +# (:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, +# :ComplexF64, :Float64)) +# @eval begin +# function $jname(jobz::Char, +# uplo::Char, +# A::StridedCuArray{$elty}; +# tol::$relty=eps($relty), +# max_sweeps::Int=100) + +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = checksquare(A) +# lda = max(1, stride(A, 2)) +# batchSize = size(A, 3) +# W = CuArray{$relty}(undef, n, batchSize) +# params = Ref{syevjInfo_t}(C_NULL) + +# dh = dense_handle() +# resize!(dh.info, batchSize) + +# # Initialize the solver parameters +# cusolverDnCreateSyevjInfo(params) +# cusolverDnXsyevjSetTolerance(params[], tol) +# cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) + +# # Calculate the workspace size +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, jobz, uplo, n, A, lda, W, out, params[], batchSize) +# return out[] * sizeof($elty) +# end + +# # Run the solver +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobz, uplo, n, A, lda, W, buffer, +# sizeof(buffer) ÷ sizeof($elty), dh.info, params[], batchSize) +# end + +# # Copy the solver info and delete the device memory +# info = @allowscalar collect(dh.info) + +# # Double check the solver's exit status +# for i in 1:batchSize +# chkargsok(BlasInt(info[i])) +# end + +# cusolverDnDestroySyevjInfo(params[]) + +# # Return eigenvalues (in W) and possibly eigenvectors (in A) +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A +# end +# end +# end +# end + +# for (fname, elty) in ((:cusolverDnSpotrsBatched, :Float32), +# (:cusolverDnDpotrsBatched, :Float64), +# (:cusolverDnCpotrsBatched, :ComplexF32), +# (:cusolverDnZpotrsBatched, :ComplexF64)) +# @eval begin +# function potrsBatched!(uplo::Char, +# A::Vector{<:StridedCuMatrix{$elty}}, +# B::Vector{<:StridedCuVecOrMat{$elty}}) +# if length(A) != length(B) +# throw(DimensionMismatch("")) +# end +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = 
checksquare(A[1]) +# if size(B[1], 1) != n +# throw(DimensionMismatch("first dimension of B[i], $(size(B[1],1)), must match second dimension of A, $n")) +# end +# nrhs = size(B[1], 2) +# # cuSOLVER's Remark 1: only nrhs=1 is supported. +# if nrhs != 1 +# throw(ArgumentError("cuSOLVER only supports vectors for B")) +# end +# lda = max(1, stride(A[1], 2)) +# ldb = max(1, stride(B[1], 2)) +# batchSize = length(A) + +# Aptrs = unsafe_batch(A) +# Bptrs = unsafe_batch(B) + +# dh = dense_handle() + +# # Run the solver +# $fname(dh, uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, dh.info, batchSize) + +# # Copy the solver info and delete the device memory +# info = @allowscalar dh.info[1] +# chklapackerror(BlasInt(info)) + +# return B +# end +# end +# end + +# for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32), +# (:cusolverDnDpotrfBatched, :Float64), +# (:cusolverDnCpotrfBatched, :ComplexF32), +# (:cusolverDnZpotrfBatched, :ComplexF64)) +# @eval begin +# function potrfBatched!(uplo::Char, A::Vector{<:StridedCuMatrix{$elty}}) + +# # Set up information for the solver arguments +# chkuplo(uplo) +# n = checksquare(A[1]) +# lda = max(1, stride(A[1], 2)) +# batchSize = length(A) + +# Aptrs = unsafe_batch(A) + +# dh = dense_handle() +# resize!(dh.info, batchSize) + +# # Run the solver +# $fname(dh, uplo, n, Aptrs, lda, dh.info, batchSize) + +# # Copy the solver info and delete the device memory +# info = @allowscalar collect(dh.info) + +# # Double check the solver's exit status +# for i in 1:batchSize +# chkargsok(BlasInt(info[i])) +# end + +# # info[i] > 0 means the leading minor of order info[i] is not positive definite +# # LinearAlgebra.LAPACK does not throw Exception here +# # to simplify calls to isposdef! 
and factorize +# return A, info +# end +# end +# end + +# # gesv +# function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true, +# residual_history::Bool=false, irs_precision::String="AUTO", +# refinement_solver::String="CLASSICAL", +# maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, +# tol_inner=Float64 = 0.0) where {T<:BlasFloat} +# params = CuSolverIRSParameters() +# info = CuSolverIRSInformation() +# n = checksquare(A) +# nrhs = size(B, 2) +# lda = max(1, stride(A, 2)) +# ldb = max(1, stride(B, 2)) +# ldx = max(1, stride(X, 2)) +# niters = Ref{Cint}() +# dh = dense_handle() + +# if irs_precision == "AUTO" +# (T == Float32) && (irs_precision = "R_32F") +# (T == Float64) && (irs_precision = "R_64F") +# (T == ComplexF32) && (irs_precision = "C_32F") +# (T == ComplexF64) && (irs_precision = "C_64F") +# else +# (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || +# error("$irs_precision is not supported.")) +# (T == Float64) && +# (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || +# error("$irs_precision is not supported.")) +# (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || +# error("$irs_precision is not supported.")) +# (T == ComplexF64) && +# (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || +# error("$irs_precision is not supported.")) +# end +# cusolverDnIRSParamsSetSolverMainPrecision(params, T) +# cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision) +# cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver) +# (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol) +# (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner) +# (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters) +# (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner) +# fallback ? 
cusolverDnIRSParamsEnableFallback(params) : +# cusolverDnIRSParamsDisableFallback(params) +# residual_history && cusolverDnIRSInfosRequestResidual(info) + +# function bufferSize() +# buffer_size = Ref{Csize_t}(0) +# cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size) +# return buffer_size[] +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb, +# X, ldx, buffer, sizeof(buffer), niters, dh.info) +# end + +# # Copy the solver flag and delete the device memory +# flag = @allowscalar dh.info[1] +# chklapackerror(BlasInt(flag)) + +# return X, info +# end + +# for (jname, bname, fname, elty, relty) in +# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32), +# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64), +# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32), +# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64)) +# @eval begin +# function $jname(jobz::Char, +# uplo::Char, +# A::StridedCuMatrix{$elty}) +# chkuplo(uplo) +# n = checksquare(A) +# lda = max(1, stride(A, 2)) +# W = CuArray{$relty}(undef, n) +# dh = dense_handle() + +# function bufferSize() +# out = Ref{Cint}(0) +# $bname(dh, jobz, uplo, n, A, lda, W, out) +# return out[] * sizeof($elty) +# end + +# with_workspace(dh.workspace_gpu, bufferSize) do buffer +# return $fname(dh, jobz, uplo, n, A, lda, W, +# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) +# end + +# info = @allowscalar dh.info[1] +# chkargsok(BlasInt(info)) + +# if jobz == 'N' +# return W +# elseif jobz == 'V' +# return W, A +# end +# end +# end +# end + +# Wrapper for Hermitian Eigenvalue Problem +function heevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}, + W::StridedCuVector{T}) where {T<:BlasFloat} + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + dh = dense_handle() + + function bufferSize() + out = 
Ref{Cint}(0) + cusolverDnSsyevd_bufferSize(dh, jobz, uplo, n, A, lda, W, out) + return out[] * sizeof(T) + end + + with_workspace(dh.workspace_gpu, bufferSize) do buffer + return cusolverDnSsyevd(dh, jobz, uplo, n, A, lda, W, buffer, + sizeof(buffer) ÷ sizeof(T), dh.info) + end + + info = @allowscalar dh.info[1] + chkargsok(BlasInt(info)) + return W, A +end + +# Wrapper for Non-Hermitian Eigenvalue Problem +function geevd!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}, W::StridedCuVector{T}, + VL::StridedCuMatrix{T}, VR::StridedCuMatrix{T}) where {T<:BlasFloat} + n = checksquare(A) + lda = max(1, stride(A, 2)) + ldvl = max(1, stride(VL, 2)) + ldvr = max(1, stride(VR, 2)) + dh = dense_handle() + + function bufferSize() + out = Ref{Cint}(0) + cusolverDnSgeev_bufferSize(dh, jobvl, jobvr, n, A, lda, W, VL, ldvl, VR, ldvr, out) + return out[] * sizeof(T) + end + + with_workspace(dh.workspace_gpu, bufferSize) do buffer + return cusolverDnSgeev(dh, jobvl, jobvr, n, A, lda, W, VL, ldvl, VR, ldvr, buffer, + sizeof(buffer) ÷ sizeof(T), dh.info) + end + + info = @allowscalar dh.info[1] + chkargsok(BlasInt(info)) + return W, VL, VR +end + +# Wrapper for Randomized SVD (example implementation) +function randomized_svd!(A::StridedCuMatrix{T}, S::StridedCuVector{T}, + U::StridedCuMatrix{T}, V::StridedCuMatrix{T}, + rank::Int) where {T<:BlasFloat} + # Example implementation for randomized SVD + # Generate random projection matrix + Omega = CuArray{T}(randn(size(A, 2), rank)) + Y = A * Omega + Q, _ = qr(Y) + B = Q' * A + gesvdqr!(B, S, U, V) + U = Q * U + return S, U, V +end + +end \ No newline at end of file From 4b738a08db3006e3bfd7b23ba8671d8b0df9f6ef Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 17 Apr 2025 16:33:21 +0200 Subject: [PATCH 02/16] first fixes --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 2 +- ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git 
a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 6b2cfdba6..7ba462eb5 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -1,4 +1,4 @@ -module MatrixAlgebraKitChainRulesCoreExt +module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit using CUDA diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl index 495b2981e..a3d00f4c8 100644 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl @@ -7,6 +7,10 @@ the diagonal elements of `R` are non-negative. """ @algdef CUSOLVER_HouseholderQR +function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) + return CUSOLVER_HouseholderQR(; kwargs...) +end + # Outputs # ------- function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), A::AbstractMatrix, From 056341e3c08fc5504f300e79a56349435db5005e Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sun, 20 Apr 2025 23:33:01 +0200 Subject: [PATCH 03/16] working qr --- .../MatrixAlgebraKitCUDAExt.jl | 7 +- .../implementations/qr.jl | 9 +- ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 3 +- src/common/initialization.jl | 5 +- src/common/view.jl | 20 +++ src/implementations/lq.jl | 2 +- src/implementations/qr.jl | 2 +- test/cuda/qr.jl | 117 ++++++++++++++++++ 8 files changed, 155 insertions(+), 10 deletions(-) create mode 100644 test/cuda/qr.jl diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 7ba462eb5..57a74b2eb 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -1,9 +1,14 @@ module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit +using MatrixAlgebraKit: @algdef, Algorithm, check_input +using MatrixAlgebraKit: one!, zero!, 
uppertriangular!, lowertriangular! +using MatrixAlgebraKit: diagview, sign_safe using CUDA +using LinearAlgebra +using LinearAlgebra: BlasFloat include("yacusolver.jl") -inculde("implementations/qr.jl") +include("implementations/qr.jl") end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl index a3d00f4c8..ba354a48e 100644 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl @@ -78,14 +78,15 @@ function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; # henceforth, τ is no longer needed and can be reused if positive # already fix Q even if we do not need R + # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA τ .= sign_safe.(diagview(A)) - Q = rmul!(Q, Diagonal(τ)) + Q .= Q .* transpose(τ) end if computeR - R̃ = triu!(view(A, axes(R)...)) + R̃ = uppertriangular!(view(A, axes(R)...)) if positive - R̃ = lmul!(Diagonal(τ)', R̃) + R̃ .= conj.(τ) .* R̃ end copyto!(R, R̃) end @@ -103,6 +104,6 @@ function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix; fill!(N, zero(eltype(N))) one!(view(N, (minmn + 1):m, 1:(m - minmn))) A, τ = YACUSOLVER.geqrf!(A) - N = unmqr!('L', 'N', A, τ, N) + N = YACUSOLVER.unmqr!('L', 'N', A, τ, N) return N end diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index ff1fc167e..32da84e3c 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -1,10 +1,11 @@ module YACUSOLVER using LinearAlgebra -using LinearAlgebra: BlasInt, checksquare, chkstride1, require_one_based_indexing +using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo using CUDA +using CUDA: @allowscalar using CUDA.CUSOLVER: cusolverDnCreate # QR methods are 
implemented with full access to allocated arrays, so we do not need to redo this: diff --git a/src/common/initialization.jl b/src/common/initialization.jl index 6a5cba786..65aa4b38e 100644 --- a/src/common/initialization.jl +++ b/src/common/initialization.jl @@ -1,11 +1,12 @@ # TODO: Consider using zerovector! if using VectorInterface.jl -function zero!(A::AbstractMatrix) +function zero!(A::AbstractArray) A .= zero(eltype(A)) return A end function one!(A::AbstractMatrix) length(A) > 0 || return A - copyto!(A, LinearAlgebra.I) + zero!(A) + diagview(A) .= one(eltype(A)) return A end diff --git a/src/common/view.jl b/src/common/view.jl index c8ae1aa5f..00fd901eb 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -28,3 +28,23 @@ function uppertriangularind(A::AbstractMatrix) end return I end + +function uppertriangular!(A::AbstractMatrix) + Base.require_one_based_indexing(A) + m, n = size(A) + for i in 1:n + r = (i + 1):m + zero!(view(A, r, i)) + end + return A +end + +function lowertriangular!(A::AbstractMatrix) + Base.require_one_based_indexing(A) + m, n = size(A) + for i in 2:n + r = 1:(i - 1) + zero!(view(A, r, i)) + end + return A +end \ No newline at end of file diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index 165c63b56..69e337c8d 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -126,7 +126,7 @@ function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; end if computeL - L̃ = tril!(view(A, axes(L)...)) + L̃ = lowertriangular!(view(A, axes(L)...)) if positive @inbounds for j in 1:minmn s = conj(sign_safe(L̃[j, j])) diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 7e2e13eab..66c42622c 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -130,7 +130,7 @@ function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; end if computeR - R̃ = triu!(view(A, axes(R)...)) + R̃ = uppertriangular!(view(A, axes(R)...)) if positive @inbounds 
for j in n:-1:1 @simd for i in 1:min(minmn, j) diff --git a/test/cuda/qr.jl b/test/cuda/qr.jl new file mode 100644 index 000000000..ef16bd77b --- /dev/null +++ b/test/cuda/qr.jl @@ -0,0 +1,117 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I +using CUDA + +function isapproxone(A) + return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) +end + +@testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + Q, R = @constinferred qr_compact(A) + @test Q isa CuMatrix{T} && size(Q) == (m, minmn) + @test R isa CuMatrix{T} && size(R) == (minmn, n) + @test Q * R ≈ A + N = @constinferred qr_null(A) + @test N isa CuMatrix{T} && size(N) == (m, m - minmn) + @test isapproxone(Q' * Q) + @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) + @test isapproxone(N' * N) + + Ac = similar(A) + Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + N2 = @constinferred qr_null!(copy!(Ac, A), N) + @test N2 === N + + # noR + Q2 = similar(Q) + noR = similar(A, minmn, 0) + qr_compact!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # positive + qr_compact!(copy!(Ac, A), (Q, R); positive=true) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + @test all(>=(zero(real(T))), real(diagview(R))) + qr_compact!(copy!(Ac, A), (Q2, noR); positive=true) + @test Q == Q2 + + # explicit blocksize + qr_compact!(copy!(Ac, A), (Q, R); blocksize=1) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=1) + @test Q == Q2 + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=1) + qr_null!(copy!(Ac, A), N; blocksize=1) + @test maximum(abs, A' * N) < eps(real(T))^(2 / 3) + @test isapproxone(N' * N) + if n <= m + qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q + @test Q ≈ Q2 + 
@test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R); blocksize=1) + @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R); blocksize=8) + @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive=true) + @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=8) + end + end +end + +@testset "qr_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 63 + for n in m # (37, m, 63) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + Q, R = qr_full(A) + @test Q isa CuMatrix{T} && size(Q) == (m, m) + @test R isa CuMatrix{T} && size(R) == (m, n) + @test Q * R ≈ A + @test Q' * Q ≈ I + + Ac = similar(A) + Q2 = similar(Q) + noR = similar(A, m, 0) + Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + @test Q * R ≈ A + @test Q' * Q ≈ I + qr_full!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # unblocked algorithm + qr_full!(copy!(Ac, A), (Q, R); blocksize=1) + @test Q * R ≈ A + @test Q' * Q ≈ I + qr_full!(copy!(Ac, A), (Q2, noR); blocksize=1) + @test Q == Q2 + if n == m + qr_full!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q + @test Q ≈ Q2 + end + # other blocking + @test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize=8) + @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive=true) + + # positive + qr_full!(copy!(Ac, A), (Q, R); positive=true) + @test Q * R ≈ A + @test Q' * Q ≈ I + @test all(>=(zero(real(T))), real(diagview(R))) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true) + @test Q == Q2 + end +end From 6a1490e44ab1518c0e351cf2d025d64d9455ed89 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sun, 20 Apr 2025 23:41:12 +0200 Subject: [PATCH 04/16] small triangular! 
fix --- src/common/initialization.jl | 20 ++++++++++++++++++++ src/common/view.jl | 20 -------------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/common/initialization.jl b/src/common/initialization.jl index 65aa4b38e..b8bc5fafd 100644 --- a/src/common/initialization.jl +++ b/src/common/initialization.jl @@ -10,3 +10,23 @@ function one!(A::AbstractMatrix) diagview(A) .= one(eltype(A)) return A end + +function uppertriangular!(A::AbstractMatrix) + Base.require_one_based_indexing(A) + m, n = size(A) + for i in 1:n + r = (i + 1):m + zero!(view(A, r, i)) + end + return A +end + +function lowertriangular!(A::AbstractMatrix) + Base.require_one_based_indexing(A) + m, n = size(A) + for i in 2:n + r = 1:min(i - 1, m) + zero!(view(A, r, i)) + end + return A +end \ No newline at end of file diff --git a/src/common/view.jl b/src/common/view.jl index 00fd901eb..c8ae1aa5f 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -28,23 +28,3 @@ function uppertriangularind(A::AbstractMatrix) end return I end - -function uppertriangular!(A::AbstractMatrix) - Base.require_one_based_indexing(A) - m, n = size(A) - for i in 1:n - r = (i + 1):m - zero!(view(A, r, i)) - end - return A -end - -function lowertriangular!(A::AbstractMatrix) - Base.require_one_based_indexing(A) - m, n = size(A) - for i in 2:n - r = 1:(i - 1) - zero!(view(A, r, i)) - end - return A -end \ No newline at end of file From 71737d76a5216a9b7c46a3aa78b7944532a7cd39 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Tue, 22 Apr 2025 23:25:40 +0200 Subject: [PATCH 05/16] add LQViaTransposedQR, CUDA LQ and tests --- Project.toml | 1 + .../MatrixAlgebraKitCUDAExt.jl | 2 + .../implementations/lq.jl | 4 + .../implementations/qr.jl | 35 +----- src/MatrixAlgebraKit.jl | 4 +- src/algorithms.jl | 2 +- src/implementations/eig.jl | 4 +- src/implementations/eigh.jl | 4 +- src/implementations/lq.jl | 51 +++++++- src/implementations/polar.jl | 4 +- src/implementations/qr.jl | 6 +- 
src/implementations/schur.jl | 4 +- src/implementations/svd.jl | 6 +- src/interface/lq.jl | 10 ++ test/cuda/lq.jl | 119 ++++++++++++++++++ test/cuda/qr.jl | 45 ++++--- test/lq.jl | 74 ++++++++++- 17 files changed, 305 insertions(+), 70 deletions(-) create mode 100644 ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl create mode 100644 test/cuda/lq.jl diff --git a/Project.toml b/Project.toml index 7dd5762ea..798076856 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ MatrixAlgebraKitCUDAExt = "CUDA" Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" +CUDA = "5" JET = "0.9" LinearAlgebra = "1" SafeTestsets = "0.1" diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 57a74b2eb..3552abc5a 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -4,11 +4,13 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe +using MatrixAlgebraKit: LQViaTransposedQR using CUDA using LinearAlgebra using LinearAlgebra: BlasFloat include("yacusolver.jl") include("implementations/qr.jl") +include("implementations/lq.jl") end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl new file mode 100644 index 000000000..29db1248d --- /dev/null +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl @@ -0,0 +1,4 @@ +function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) + qr_alg = CUSOLVER_HouseholderQR(; kwargs...) 
+ return LQViaTransposedQR(qr_alg) +end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl index ba354a48e..9006361e6 100644 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl @@ -11,31 +11,6 @@ function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs. return CUSOLVER_HouseholderQR(; kwargs...) end -# Outputs -# ------- -function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), A::AbstractMatrix, - ::CUSOLVER_HouseholderQR) - m, n = size(A) - Q = similar(A, (m, m)) - R = similar(A, (m, n)) - return (Q, R) -end -function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), A::AbstractMatrix, - ::CUSOLVER_HouseholderQR) - m, n = size(A) - minmn = min(m, n) - Q = similar(A, (m, minmn)) - R = similar(A, (minmn, n)) - return (Q, R) -end -function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), A::AbstractMatrix, - ::CUSOLVER_HouseholderQR) - m, n = size(A) - minmn = min(m, n) - N = similar(A, (m, m - minmn)) - return N -end - # Implementation # -------------- # actual implementation @@ -80,13 +55,15 @@ function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; if positive # already fix Q even if we do not need R # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA τ .= sign_safe.(diagview(A)) - Q .= Q .* transpose(τ) + Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q + Qf .= Qf .* transpose(τ) end if computeR R̃ = uppertriangular!(view(A, axes(R)...)) if positive - R̃ .= conj.(τ) .* R̃ + R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R + R̃f .= conj.(τ) .* R̃f end copyto!(R, R̃) end @@ -94,9 +71,7 @@ function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; end function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix; - positive=false, - pivoted=false, - blocksize=1) + positive=false, blocksize=1) 
blocksize > 1 && throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition")) m, n = size(A) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 567c1510f..e68473868 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -2,12 +2,12 @@ module MatrixAlgebraKit using LinearAlgebra: LinearAlgebra using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? -using LinearAlgebra: mul!, rmul!, lmul! +using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! using LinearAlgebra: sylvester using LinearAlgebra: isposdef, ishermitian using LinearAlgebra: Diagonal, diag, diagind using LinearAlgebra: UpperTriangular, LowerTriangular -using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!, tril!, rdiv!, ldiv! +using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt export isisometry, isunitary diff --git a/src/algorithms.jl b/src/algorithms.jl index af1831030..7dbfe46e1 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -146,7 +146,7 @@ macro algdef(name) return $name{typeof(kw)}(kw) end function Base.show(io::IO, alg::$name) - return _show_alg(io, alg) + return ($_show_alg)(io, alg) end Core.@__doc__ $name diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 6a48e1950..6d0efccad 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -30,14 +30,14 @@ end # Outputs # ------- -function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::LAPACK_EigAlgorithm) +function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later Tc = complex(eltype(A)) D = Diagonal(similar(A, Tc, n)) V = similar(A, Tc, (n, n)) return (D, V) end -function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::LAPACK_EigAlgorithm) +function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later 
Tc = complex(eltype(A)) D = similar(A, Tc, n) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index a1c6c779d..5178748ec 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -29,13 +29,13 @@ end # Outputs # ------- -function initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::LAPACK_EighAlgorithm) +function initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later D = Diagonal(similar(A, real(eltype(A)), n)) V = similar(A, (n, n)) return (D, V) end -function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::LAPACK_EighAlgorithm) +function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later D = similar(A, real(eltype(A)), n) return D diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index 69e337c8d..afdbb3f0e 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -42,20 +42,20 @@ end # Outputs # ------- -function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::LAPACK_HouseholderLQ) +function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) L = similar(A, (m, n)) Q = similar(A, (n, n)) return (L, Q) end -function initialize_output(::typeof(lq_compact!), A::AbstractMatrix, ::LAPACK_HouseholderLQ) +function initialize_output(::typeof(lq_compact!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) L = similar(A, (m, minmn)) Q = similar(A, (minmn, n)) return (L, Q) end -function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::LAPACK_HouseholderLQ) +function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) Nᴴ = similar(A, (n - minmn, n)) @@ -71,17 +71,34 @@ function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) _lapack_lq!(A, L, Q; alg.kwargs...) 
return L, Q end +function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) + check_input(lq_full!, A, LQ) + L, Q = LQ + lq_via_qr!(A, L, Q, alg.qr_alg) + return L, Q +end function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) check_input(lq_compact!, A, LQ) L, Q = LQ _lapack_lq!(A, L, Q; alg.kwargs...) return L, Q end +function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) + check_input(lq_compact!, A, LQ) + L, Q = LQ + lq_via_qr!(A, L, Q, alg.qr_alg) + return L, Q +end function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ) check_input(lq_null!, A, Nᴴ) _lapack_lq_null!(A, Nᴴ; alg.kwargs...) return Nᴴ end +function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR) + check_input(lq_null!, A, Nᴴ) + lq_null_via_qr!(A, Nᴴ, alg.qr_alg) + return Nᴴ +end function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive=false, @@ -158,3 +175,31 @@ function _lapack_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix; end return Nᴴ end + +# LQ via transposition and QR +function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, + qr_alg::AbstractAlgorithm) + m, n = size(A) + minmn = min(m, n) + At = adjoint!(similar(A'), A) + Qt = (A === Q) ? 
At : similar(Q') + Lt = similar(L') + if size(Q) == (n, n) + Qt, Lt = qr_full!(At, (Qt, Lt), qr_alg) + else + Qt, Lt = qr_compact!(At, (Qt, Lt), qr_alg) + end + adjoint!(Q, Qt) + !isempty(L) && adjoint!(L, Lt) + return L, Q +end + +function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractAlgorithm) + m, n = size(A) + minmn = min(m, n) + At = adjoint!(similar(A'), A) + Nt = similar(N') + Nt = qr_null!(At, Nt, qr_alg) + !isempty(N) && adjoint!(N, Nt) + return N +end \ No newline at end of file diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index 2604aab56..fabb261c2 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -30,13 +30,13 @@ end # Outputs # ------- -function initialize_output(::typeof(left_polar!), A::AbstractMatrix, ::PolarViaSVD) +function initialize_output(::typeof(left_polar!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) W = similar(A) P = similar(A, (n, n)) return (W, P) end -function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::PolarViaSVD) +function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) P = similar(A, (m, m)) Wᴴ = similar(A) diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 66c42622c..1d30b4b1d 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -42,20 +42,20 @@ end # Outputs # ------- -function initialize_output(::typeof(qr_full!), A::AbstractMatrix, ::LAPACK_HouseholderQR) +function initialize_output(::typeof(qr_full!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) Q = similar(A, (m, m)) R = similar(A, (m, n)) return (Q, R) end -function initialize_output(::typeof(qr_compact!), A::AbstractMatrix, ::LAPACK_HouseholderQR) +function initialize_output(::typeof(qr_compact!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) Q = similar(A, (m, minmn)) R = similar(A, (minmn, n)) return (Q, R) end -function 
initialize_output(::typeof(qr_null!), A::AbstractMatrix, ::LAPACK_HouseholderQR) +function initialize_output(::typeof(qr_null!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) N = similar(A, (m, m - minmn)) diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 55a1bdfaa..541ae97fb 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -28,13 +28,13 @@ end # Outputs # ------- -function initialize_output(::typeof(schur_full!), A::AbstractMatrix, ::LAPACK_EigAlgorithm) +function initialize_output(::typeof(schur_full!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later Z = similar(A, (n, n)) vals = similar(A, complex(eltype(A)), n) return (A, Z, vals) end -function initialize_output(::typeof(schur_vals!), A::AbstractMatrix, ::LAPACK_EigAlgorithm) +function initialize_output(::typeof(schur_vals!), A::AbstractMatrix, ::AbstractAlgorithm) n = size(A, 1) # square check will happen later vals = similar(A, complex(eltype(A)), n) return vals diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index a7d4c9e56..031494f83 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -44,14 +44,14 @@ end # Outputs # ------- -function initialize_output(::typeof(svd_full!), A::AbstractMatrix, ::LAPACK_SVDAlgorithm) +function initialize_output(::typeof(svd_full!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) U = similar(A, (m, m)) S = similar(A, real(eltype(A)), (m, n)) # TODO: Rectangular diagonal type? 
Vᴴ = similar(A, (n, n)) return (U, S, Vᴴ) end -function initialize_output(::typeof(svd_compact!), A::AbstractMatrix, ::LAPACK_SVDAlgorithm) +function initialize_output(::typeof(svd_compact!), A::AbstractMatrix, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) U = similar(A, (m, minmn)) @@ -59,7 +59,7 @@ function initialize_output(::typeof(svd_compact!), A::AbstractMatrix, ::LAPACK_S Vᴴ = similar(A, (minmn, n)) return (U, S, Vᴴ) end -function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::LAPACK_SVDAlgorithm) +function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm) diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 6f1ed12f2..338302ef4 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -81,3 +81,13 @@ for f in (:lq_full!, :lq_compact!, :lq_null!) return default_lq_algorithm(A; kwargs...) end end + +# Alternative algorithm (necessary for CUDA) +struct LQViaTransposedQR{A<:AbstractAlgorithm} <: AbstractAlgorithm + qr_alg::A +end +function Base.show(io::IO, alg::LQViaTransposedQR) + print(io, "LQViaTransposedQR(") + _show_alg(io, alg.qr_alg) + return print(io, ")") +end diff --git a/test/cuda/lq.jl b/test/cuda/lq.jl new file mode 100644 index 000000000..02e8b5066 --- /dev/null +++ b/test/cuda/lq.jl @@ -0,0 +1,119 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using Test +using TestExtras +using StableRNGs +using CUDA + +function isapproxone(A) + return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) +end + +@testset "lq_compact! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + L, Q = @constinferred lq_compact(A) + @test L isa CuMatrix{T} && size(L) == (m, minmn) + @test Q isa CuMatrix{T} && size(Q) == (minmn, n) + @test L * Q ≈ A + @test isapproxone(Q * Q') + Nᴴ = @constinferred lq_null(A) + @test Nᴴ isa CuMatrix{T} && size(Nᴴ) == (n - minmn, n) + @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) + @test isapproxone(Nᴴ * Nᴴ') + + Ac = similar(A) + L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ) + @test Nᴴ2 === Nᴴ + + # noL + noL = similar(A, 0, minmn) + Q2 = similar(Q) + lq_compact!(copy!(Ac, A), (noL, Q2)) + @test Q == Q2 + + # positive + lq_compact!(copy!(Ac, A), (L, Q); positive=true) + @test L * Q ≈ A + @test isapproxone(Q * Q') + @test all(>=(zero(real(T))), real(diagview(L))) + lq_compact!(copy!(Ac, A), (noL, Q2); positive=true) + @test Q == Q2 + + # explicit blocksize + lq_compact!(copy!(Ac, A), (L, Q); blocksize=1) + @test L * Q ≈ A + @test isapproxone(Q * Q') + lq_compact!(copy!(Ac, A), (noL, Q2); blocksize=1) + @test Q == Q2 + lq_null!(copy!(Ac, A), Nᴴ; blocksize=1) + @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) + @test isapproxone(Nᴴ * Nᴴ') + if m <= n + lq_compact!(copy!(Q2, A), (noL, Q2); blocksize=1) # in-place Q + @test Q ≈ Q2 + # these do not work because of the in-place Q + @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize=1) + @test_throws ArgumentError lq_compact!(copy!(Q2, A), (noL, Q2); positive=true) + end + # no blocked CUDA + @test_throws ArgumentError lq_compact!(copy!(Q2, A), (L, Q2); blocksize=8) + end +end + +@testset "lq_full! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + L, Q = lq_full(A) + @test L isa CuMatrix{T} && size(L) == (m, n) + @test Q isa CuMatrix{T} && size(Q) == (n, n) + @test L * Q ≈ A + @test isapproxone(Q * Q') + + Ac = similar(A) + L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + @test L * Q ≈ A + @test isapproxone(Q * Q') + + # noL + noL = similar(A, 0, n) + Q2 = similar(Q) + lq_full!(copy!(Ac, A), (noL, Q2)) + @test Q == Q2 + + # positive + lq_full!(copy!(Ac, A), (L, Q); positive=true) + @test L * Q ≈ A + @test isapproxone(Q * Q') + @test all(>=(zero(real(T))), real(diagview(L))) + lq_full!(copy!(Ac, A), (noL, Q2); positive=true) + @test Q == Q2 + + # explicit blocksize + lq_full!(copy!(Ac, A), (L, Q); blocksize=1) + @test L * Q ≈ A + @test isapproxone(Q * Q') + lq_full!(copy!(Ac, A), (noL, Q2); blocksize=1) + @test Q == Q2 + if n == m + lq_full!(copy!(Q2, A), (noL, Q2); blocksize=1) # in-place Q + @test Q ≈ Q2 + # these do not work because of the in-place Q + @test_throws ArgumentError lq_full!(copy!(Q2, A), (L, Q2); blocksize=1) + @test_throws ArgumentError lq_full!(copy!(Q2, A), (noL, Q2); positive=true) + end + # no blocked CUDA + @test_throws ArgumentError lq_full!(copy!(Ac, A), (L, Q); blocksize=8) + end +end diff --git a/test/cuda/qr.jl b/test/cuda/qr.jl index ef16bd77b..374d07064 100644 --- a/test/cuda/qr.jl +++ b/test/cuda/qr.jl @@ -3,7 +3,6 @@ using MatrixAlgebraKit: diagview using Test using TestExtras using StableRNGs -using LinearAlgebra: diag, I using CUDA function isapproxone(A) @@ -61,25 +60,26 @@ end if n <= m qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q @test Q ≈ Q2 - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R); blocksize=1) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R); blocksize=8) + # these do not work because of the in-place Q 
+ @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, R2)) @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); positive=true) - @test_throws ArgumentError qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=8) end + # no blocked CUDA + @test_throws ArgumentError qr_compact!(copy!(Ac, A), (Q2, R); blocksize=8) end end @testset "qr_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 63 - for n in m # (37, m, 63) + for n in (37, m, 63) minmn = min(m, n) A = CuArray(randn(rng, T, m, n)) Q, R = qr_full(A) @test Q isa CuMatrix{T} && size(Q) == (m, m) @test R isa CuMatrix{T} && size(R) == (m, n) @test Q * R ≈ A - @test Q' * Q ≈ I + @test isapproxone(Q' * Q) Ac = similar(A) Q2 = similar(Q) @@ -88,30 +88,37 @@ end @test Q2 === Q @test R2 === R @test Q * R ≈ A - @test Q' * Q ≈ I + @test isapproxone(Q' * Q) qr_full!(copy!(Ac, A), (Q2, noR)) @test Q == Q2 - # unblocked algorithm + # noR + noR = similar(A, m, 0) + Q2 = similar(Q) + qr_full!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # positive + qr_full!(copy!(Ac, A), (Q, R); positive=true) + @test Q * R ≈ A + @test isapproxone(Q' * Q) + @test all(>=(zero(real(T))), real(diagview(R))) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true) + @test Q == Q2 + + # explicit blocksize qr_full!(copy!(Ac, A), (Q, R); blocksize=1) @test Q * R ≈ A - @test Q' * Q ≈ I + @test isapproxone(Q' * Q) qr_full!(copy!(Ac, A), (Q2, noR); blocksize=1) @test Q == Q2 if n == m qr_full!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q @test Q ≈ Q2 + @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, R2)) + @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive=true) end - # other blocking + # no blocked CUDA @test_throws ArgumentError qr_full!(copy!(Ac, A), (Q, R); blocksize=8) - @test_throws ArgumentError qr_full!(copy!(Q2, A), (Q2, noR); positive=true) - - # positive - qr_full!(copy!(Ac, A), (Q, R); positive=true) - @test Q * R ≈ A - @test Q' * Q ≈ I - @test 
all(>=(zero(real(T))), real(diagview(R))) - qr_full!(copy!(Ac, A), (Q2, noR); positive=true) - @test Q == Q2 end end diff --git a/test/lq.jl b/test/lq.jl index f12ea99a9..581f1782b 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -3,6 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: diag, I +using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderQR @testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -32,6 +33,19 @@ using LinearAlgebra: diag, I lq_compact!(copy!(Ac, A), (noL, Q2)) @test Q == Q2 + # Transposed QR algorithm + qr_alg = LAPACK_HouseholderQR() + lq_alg = LQViaTransposedQR(qr_alg) + L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q), lq_alg) + @test L2 === L + @test Q2 === Q + Nᴴ2 = @constinferred lq_null!(copy!(Ac, A), Nᴴ, lq_alg) + @test Nᴴ2 === Nᴴ + noL = similar(A, 0, minmn) + Q2 = similar(Q) + lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + # unblocked algorithm lq_compact!(copy!(Ac, A), (L, Q); blocksize=1) @test L * Q ≈ A @@ -56,8 +70,22 @@ using LinearAlgebra: diag, I lq_null!(copy!(Ac, A), Nᴴ; blocksize=8) @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) @test isisometry(Nᴴ; side=:right) + @test Nᴴ * Nᴴ' ≈ I + + qr_alg = LAPACK_HouseholderQR(; blocksize=1) + lq_alg = LQViaTransposedQR(qr_alg) + lq_compact!(copy!(Ac, A), (L, Q), lq_alg) + @test L * Q ≈ A + @test Q * Q' ≈ I + lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + lq_null!(copy!(Ac, A), Nᴴ, lq_alg) + @test maximum(abs, A * Nᴴ') < eps(real(T))^(2 / 3) + @test Nᴴ * Nᴴ' ≈ I + # pivoted @test_throws ArgumentError lq_compact!(copy!(Ac, A), (L, Q); pivoted=true) + # positive lq_compact!(copy!(Ac, A), (L, Q); positive=true) @test L * Q ≈ A @@ -65,6 +93,7 @@ using LinearAlgebra: diag, I @test all(>=(zero(real(T))), real(diag(L))) lq_compact!(copy!(Ac, A), (noL, Q2); positive=true) @test Q == Q2 + # positive and blocksize 1 lq_compact!(copy!(Ac, A), (L, Q); positive=true, 
blocksize=1) @test L * Q ≈ A @@ -72,6 +101,14 @@ using LinearAlgebra: diag, I @test all(>=(zero(real(T))), real(diag(L))) lq_compact!(copy!(Ac, A), (noL, Q2); positive=true, blocksize=1) @test Q == Q2 + qr_alg = LAPACK_HouseholderQR(; positive=true, blocksize=1) + lq_alg = LQViaTransposedQR(qr_alg) + lq_compact!(copy!(Ac, A), (L, Q), lq_alg) + @test L * Q ≈ A + @test Q * Q' ≈ I + @test all(>=(zero(real(T))), real(diag(L))) + lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 end end @@ -99,6 +136,19 @@ end lq_full!(copy!(Ac, A), (noL, Q2)) @test Q == Q2 + # Transposed QR algorithm + qr_alg = LAPACK_HouseholderQR() + lq_alg = LQViaTransposedQR(qr_alg) + L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q), lq_alg) + @test L2 === L + @test Q2 === Q + @test L * Q ≈ A + @test Q * Q' ≈ I + noL = similar(A, 0, n) + Q2 = similar(Q) + lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + # unblocked algorithm lq_full!(copy!(Ac, A), (L, Q); blocksize=1) @test L * Q ≈ A @@ -109,7 +159,19 @@ end lq_full!(copy!(Q2, A), (noL, Q2); blocksize=1) # in-place Q @test Q ≈ Q2 end - # # other blocking + qr_alg = LAPACK_HouseholderQR(; blocksize=1) + lq_alg = LQViaTransposedQR(qr_alg) + lq_full!(copy!(Ac, A), (L, Q), lq_alg) + @test L * Q ≈ A + @test Q * Q' ≈ I + lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + if n == m + lq_full!(copy!(Q2, A), (noL, Q2), lq_alg) # in-place Q + @test Q ≈ Q2 + end + + # other blocking lq_full!(copy!(Ac, A), (L, Q); blocksize=18) @test L * Q ≈ A @test isunitary(Q) @@ -124,6 +186,16 @@ end @test all(>=(zero(real(T))), real(diag(L))) lq_full!(copy!(Ac, A), (noL, Q2); positive=true) @test Q == Q2 + + qr_alg = LAPACK_HouseholderQR(; positive=true) + lq_alg = LQViaTransposedQR(qr_alg) + lq_full!(copy!(Ac, A), (L, Q), lq_alg) + @test L * Q ≈ A + @test Q * Q' ≈ I + @test all(>=(zero(real(T))), real(diag(L))) + lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + # positive and blocksize 1 lq_full!(copy!(Ac, A), (L, Q); 
positive=true, blocksize=1) @test L * Q ≈ A From 88678f37aad1f523d54a04bdf870ec3341fc7763 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 1 May 2025 12:18:21 +0200 Subject: [PATCH 06/16] first svd support --- Project.toml | 3 +- .../MatrixAlgebraKitCUDAExt.jl | 13 +- .../implementations/lq.jl | 4 - .../implementations/qr.jl | 17 +- .../implementations/svd.jl | 106 ++++++ ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 303 +++++++++--------- src/MatrixAlgebraKit.jl | 6 +- .../decompositions.jl | 44 ++- test/cuda/lq.jl | 4 +- test/cuda/qr.jl | 4 +- test/cuda/svd.jl | 120 +++++++ test/cuda/utilities.jl | 3 + test/runtests.jl | 13 + 13 files changed, 460 insertions(+), 180 deletions(-) delete mode 100644 ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl create mode 100644 ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl rename src/{implementations => interface}/decompositions.jl (74%) create mode 100644 test/cuda/svd.jl create mode 100644 test/cuda/utilities.jl diff --git a/Project.toml b/Project.toml index 798076856..efc996b4d 100644 --- a/Project.toml +++ b/Project.toml @@ -39,5 +39,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", - "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"] diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 3552abc5a..29a222319 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -11,6 +11,17 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") include("implementations/qr.jl") -include("implementations/lq.jl") +include("implementations/svd.jl") + +function 
MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) + return CUSOLVER_HouseholderQR(; kwargs...) +end +function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) + qr_alg = CUSOLVER_HouseholderQR(; kwargs...) + return LQViaTransposedQR(qr_alg) +end +function MatrixAlgebraKit.default_svd_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) + return CUSOLVER_QRIteration(; kwargs...) +end end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl deleted file mode 100644 index 29db1248d..000000000 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/lq.jl +++ /dev/null @@ -1,4 +0,0 @@ -function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) - qr_alg = CUSOLVER_HouseholderQR(; kwargs...) - return LQViaTransposedQR(qr_alg) -end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl index 9006361e6..636cf82e3 100644 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl @@ -1,19 +1,4 @@ -""" - CUSOLVER_HouseholderQR(; positive = false) - -Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of -a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that -the diagonal elements of `R` are non-negative. -""" -@algdef CUSOLVER_HouseholderQR - -function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) - return CUSOLVER_HouseholderQR(; kwargs...) 
-end - -# Implementation -# -------------- -# actual implementation +# CUSOLVER QR implementation function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR) check_input(qr_full!, A, QR) Q, R = QR diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl new file mode 100644 index 000000000..46376a345 --- /dev/null +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl @@ -0,0 +1,106 @@ +const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration, + CUSOLVER_SVDPolar, + CUSOLVER_Jacobi} + +# CUSOLVER SVD implementation +function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm) + check_input(svd_full!, A, USVᴴ) + U, S, Vᴴ = USVᴴ + fill!(S, zero(eltype(S))) + m, n = size(A) + minmn = min(m, n) + if alg isa CUSOLVER_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments")) + YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) + elseif alg isa CUSOLVER_SVDPolar + YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) 
+ # elseif alg isa LAPACK_Bisection + # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) + # elseif alg isa LAPACK_Jacobi + # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + diagview(S) .= view(S, 1:minmn, 1) + view(S, 2:minmn, 1) .= zero(eltype(S)) + # TODO: make this controllable using a `gaugefix` keyword argument + for j in 1:max(m, n) + if j <= minmn + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + elseif j <= m + u = view(U, :, j) + s = conj(sign(_argmaxabs(u))) + u .*= s + else + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(v))) + v .*= s + end + end + return USVᴴ +end + +function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm) + check_input(svd_compact!, A, USVᴴ) + U, S, Vᴴ = USVᴴ + if alg isa CUSOLVER_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments")) + YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ) + elseif alg isa CUSOLVER_SVDPolar + YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) + # elseif alg isa LAPACK_DivideAndConquer + # isempty(alg.kwargs) || + # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) + # YALAPACK.gesdd!(A, S.diag, U, Vᴴ) + # elseif alg isa LAPACK_Bisection + # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...) 
+ # elseif alg isa LAPACK_Jacobi + # isempty(alg.kwargs) || + # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) + # YALAPACK.gesvj!(A, S.diag, U, Vᴴ) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + # TODO: make this controllable using a `gaugefix` keyword argument + for j in 1:size(U, 2) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return USVᴴ +end +_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x))) +_largest(x, y) = abs(x) < abs(y) ? y : x + +function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm) + check_input(svd_vals!, A, S) + U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + if alg isa CUSOLVER_QRIteration + isempty(alg.kwargs) || + throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments")) + YACUSOLVER.gesvd!(A, S, U, Vᴴ) + elseif alg isa CUSOLVER_SVDPolar + YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) + # elseif alg isa LAPACK_DivideAndConquer + # isempty(alg.kwargs) || + # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) + # YALAPACK.gesdd!(A, S, U, Vᴴ) + # elseif alg isa LAPACK_Bisection + # YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...) 
+ # elseif alg isa LAPACK_Jacobi + # isempty(alg.kwargs) || + # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) + # YALAPACK.gesvj!(A, S, U, Vᴴ) + else + throw(ArgumentError("Unsupported SVD algorithm")) + end + return S +end \ No newline at end of file diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 32da84e3c..117f3aeed 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -6,113 +6,188 @@ using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdia using CUDA using CUDA: @allowscalar -using CUDA.CUSOLVER: cusolverDnCreate +using CUDA.CUSOLVER # QR methods are implemented with full access to allocated arrays, so we do not need to redo this: using CUDA.CUSOLVER: geqrf!, ormqr!, orgqr! const unmqr! = ormqr! const ungqr! = orgqr! -# # Wrapper for SVD via QR Iteration -# for (bname, fname, elty, relty) in -# ((:cusolverDnSgesvd_bufferSize, :cusolverDnSgesvd, :Float32, :Float32), -# (:cusolverDnDgesvd_bufferSize, :cusolverDnDgesvd, :Float64, :Float64), -# (:cusolverDnCgesvd_bufferSize, :cusolverDnCgesvd, :ComplexF32, :Float32), -# (:cusolverDnZgesvd_bufferSize, :cusolverDnZgesvd, :ComplexF64, :Float64)) -# @eval begin -# function gesvd!(A::StridedCuMatrix{$elty}, -# S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), -# U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), -# min(size(A)...)), -# Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), -# size(A, 2))) -# chkstride1(A, U, Vᴴ, S) -# m, n = size(A) -# (m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n")) -# minmn = min(m, n) -# lda = max(1, stride(A, 2)) +# Wrapper for SVD via QR Iteration +for (bname, fname, elty, relty) in + ((:cusolverDnSgesvd_bufferSize, :cusolverDnSgesvd, :Float32, :Float32), + (:cusolverDnDgesvd_bufferSize, :cusolverDnDgesvd, :Float64, :Float64), + (:cusolverDnCgesvd_bufferSize, :cusolverDnCgesvd, 
:ComplexF32, :Float32), + (:cusolverDnZgesvd_bufferSize, :cusolverDnZgesvd, :ComplexF64, :Float64)) + @eval begin + #! format: off + function gesvd!(A::StridedCuMatrix{$elty}, + S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), + U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)), + Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2))) + #! format: on + chkstride1(A, U, Vᴴ, S) + m, n = size(A) + (m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n")) + minmn = min(m, n) + if length(U) == 0 + jobu = 'N' + else + size(U, 1) == m || + throw(DimensionMismatch("row size mismatch between A and U")) + if size(U, 2) == minmn + if U === A + jobu = 'O' + else + jobu = 'S' + end + elseif size(U, 2) == m + jobu = 'A' + else + throw(DimensionMismatch("invalid column size of U")) + end + end + if length(Vᴴ) == 0 + jobvt = 'N' + else + size(Vᴴ, 2) == n || + throw(DimensionMismatch("column size mismatch between A and Vᴴ")) + if size(Vᴴ, 1) == minmn + if Vᴴ === A + jobvt = 'O' + else + jobvt = 'S' + end + elseif size(Vᴴ, 1) == n + jobvt = 'A' + else + throw(DimensionMismatch("invalid row size of Vᴴ")) + end + end + length(S) == minmn || + throw(DimensionMismatch("length mismatch between A and S")) + + lda = max(1, stride(A, 2)) + ldu = max(1, stride(U, 2)) + ldv = max(1, stride(Vᴴ, 2)) + + dh = CUSOLVER.dense_handle() + function bufferSize() + out = Ref{Cint}(0) + CUSOLVER.$bname(dh, m, n, out) + return out[] * sizeof($elty) + end + rwork = CuArray{$relty}(undef, min(m, n) - 1) + CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer + return CUSOLVER.$fname(dh, jobu, jobvt, m, n, + A, lda, S, U, ldu, Vᴴ, ldv, + buffer, sizeof(buffer) ÷ sizeof($elty), rwork, + dh.info) + end + CUDA.unsafe_free!(rwork) + + info = @allowscalar dh.info[1] + CUSOLVER.chkargsok(BlasInt(info)) + + return (S, U, Vᴴ) + end + end +end -# if length(U) == 0 -# jobu = 'N' -# else -# size(U, 1) == m || -# throw(DimensionMismatch("row size 
mismatch between A and U")) -# if size(U, 2) == minmn -# if U === A -# jobu = 'O' -# else -# jobu = 'S' -# end -# elseif size(U, 2) == m -# jobu = 'A' -# else -# throw(DimensionMismatch("invalid column size of U")) -# end -# end -# if length(Vᴴ) == 0 -# jobvt = 'N' -# else -# size(Vᴴ, 2) == n || -# throw(DimensionMismatch("column size mismatch between A and Vᴴ")) -# if size(Vᴴ, 1) == minmn -# if Vᴴ === A -# jobvt = 'O' -# else -# jobvt = 'S' -# end -# elseif size(Vᴴ, 1) == n -# jobvt = 'A' -# else -# throw(DimensionMismatch("invalid row size of Vᴴ")) -# end -# end -# length(S) == minmn || -# throw(DimensionMismatch("length mismatch between A and S")) +function Xgesvdp!(A::StridedCuMatrix{T}, + S::StridedCuVector=similar(A, real(T), min(size(A)...)), + U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)), + Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2)); + tol=norm(A) * eps(real(T))) where {T<:BlasFloat} + chkstride1(A, U, S, Vᴴ) + m, n = size(A) + minmn = min(m, n) + if length(U) == length(Vᴴ) == 0 + jobz = 'N' + econ = 1 + else + jobz = 'V' + size(U, 1) == m || + throw(DimensionMismatch("row size mismatch between A and U")) + size(Vᴴ, 2) == n || + throw(DimensionMismatch("column size mismatch between A and Vᴴ")) + if size(U, 2) == size(Vᴴ, 1) == minmn + econ = 1 + elseif size(U, 2) == m && size(Vᴴ, 1) == n + econ = 0 + else + throw(DimensionMismatch("invalid column size of U or row size of Vᴴ")) + end + end + R = eltype(S) + length(S) == minmn || + throw(DimensionMismatch("length mismatch between A and S")) + R == real(T) || + throw(ArgumentError("S does not have the matching real `eltype` of A")) + + Ṽ = similar(Vᴴ, (n, n)) + Ũ = (size(U) == (m, m)) ? 
U : similar(U, (m, m)) + lda = max(1, stride(A, 2)) + ldu = max(1, stride(Ũ, 2)) + ldv = max(1, stride(Ṽ, 2)) + h_err_sigma = Ref{Cdouble}(0) + params = CUSOLVER.CuSolverParameters() + dh = CUSOLVER.dense_handle() -# lda = max(1, stride(A, 2)) -# ldu = max(1, stride(U, 2)) -# ldv = max(1, stride(Vᴴ, 2)) + function bufferSize() + out_cpu = Ref{Csize_t}(0) + out_gpu = Ref{Csize_t}(0) + CUSOLVER.cusolverDnXgesvdp_bufferSize(dh, params, jobz, econ, m, n, + T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + T, out_gpu, out_cpu) -# dh = dense_handle() -# function bufferSize() -# out = Ref{Cint}(0) -# $bname(dh, m, n, out) -# return out[] * sizeof($elty) -# end -# rwork = CuArray{$relty}(undef, min(m, n) - 1) -# with_workspace(dh.workspace_gpu, bufferSize) do buffer -# return $fname(dh, jobu, jobvt, m, n, A, lda, S, U, ldu, Vᴴ, ldv, -# buffer, sizeof(buffer) ÷ sizeof($elty), rwork, dh.info) -# end -# unsafe_free!(rwork) + return out_gpu[], out_cpu[] + end + CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu, + bufferSize()...) do buffer_gpu, buffer_cpu + return CUSOLVER.cusolverDnXgesvdp(dh, params, jobz, econ, m, n, + T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + T, buffer_gpu, sizeof(buffer_gpu), + buffer_cpu, sizeof(buffer_cpu), + dh.info, h_err_sigma) + end + err = h_err_sigma[] + if err > tol + warn("Xgesvdp! 
did not attain the requested tolerance: error = $err > tolerance = $tol")
+    end

-#         info = @allowscalar dh.info[1]
-#         chkargsok(BlasInt(info))
+    flag = @allowscalar dh.info[1]
+    CUSOLVER.chklapackerror(BlasInt(flag))
+    if Ũ !== U && length(U) > 0
+        U .= view(Ũ, 1:m, 1:size(U, 2))
+    end
+    if length(Vᴴ) > 0
+        Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
+    end
+    Ũ !== U && CUDA.unsafe_free!(Ũ)
+    CUDA.unsafe_free!(Ṽ)

-#             return (S, U, Vᴴ)
-#         end
-#     end
-# end
+    return S, U, Vᴴ
+end

-# # Wrapper for SVD via Jacobi
+# Wrapper for SVD via Jacobi
 # for (bname, fname, elty, relty) in
 #     ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32),
 #      (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64),
 #      (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32),
 #      (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64))
 #     @eval begin
+#         #! format: off
 #         function gesvdj!(A::StridedCuMatrix{$elty},
 #                          S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
-#                          U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1),
-#                                                            min(size(A)...)),
-#                          Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...),
-#                                                             size(A, 2));
+#                          U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
+#                          Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
 #                          tol::$relty=eps($relty),
 #                          max_sweeps::Int=100)
+#         #! format: on
 #             chkstride1(A, U, Vᴴ, S)
 #             m, n = size(A)
 #             minmn = min(m, n)
-#             lda = max(1, stride(A, 2))
 
 #             if length(U) == 0 && length(Vᴴ) == 0
 #                 jobz = 'N'
@@ -519,68 +594,4 @@ const ungqr! = orgqr!
# end # end -# Wrapper for Hermitian Eigenvalue Problem -function heevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}, - W::StridedCuVector{T}) where {T<:BlasFloat} - chkuplo(uplo) - n = checksquare(A) - lda = max(1, stride(A, 2)) - dh = dense_handle() - - function bufferSize() - out = Ref{Cint}(0) - cusolverDnSsyevd_bufferSize(dh, jobz, uplo, n, A, lda, W, out) - return out[] * sizeof(T) - end - - with_workspace(dh.workspace_gpu, bufferSize) do buffer - return cusolverDnSsyevd(dh, jobz, uplo, n, A, lda, W, buffer, - sizeof(buffer) ÷ sizeof(T), dh.info) - end - - info = @allowscalar dh.info[1] - chkargsok(BlasInt(info)) - return W, A -end - -# Wrapper for Non-Hermitian Eigenvalue Problem -function geevd!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}, W::StridedCuVector{T}, - VL::StridedCuMatrix{T}, VR::StridedCuMatrix{T}) where {T<:BlasFloat} - n = checksquare(A) - lda = max(1, stride(A, 2)) - ldvl = max(1, stride(VL, 2)) - ldvr = max(1, stride(VR, 2)) - dh = dense_handle() - - function bufferSize() - out = Ref{Cint}(0) - cusolverDnSgeev_bufferSize(dh, jobvl, jobvr, n, A, lda, W, VL, ldvl, VR, ldvr, out) - return out[] * sizeof(T) - end - - with_workspace(dh.workspace_gpu, bufferSize) do buffer - return cusolverDnSgeev(dh, jobvl, jobvr, n, A, lda, W, VL, ldvl, VR, ldvr, buffer, - sizeof(buffer) ÷ sizeof(T), dh.info) - end - - info = @allowscalar dh.info[1] - chkargsok(BlasInt(info)) - return W, VL, VR -end - -# Wrapper for Randomized SVD (example implementation) -function randomized_svd!(A::StridedCuMatrix{T}, S::StridedCuVector{T}, - U::StridedCuMatrix{T}, V::StridedCuMatrix{T}, - rank::Int) where {T<:BlasFloat} - # Example implementation for randomized SVD - # Generate random projection matrix - Omega = CuArray{T}(randn(size(A, 2), rank)) - Y = A * Omega - Q, _ = qr(Y) - B = Q' * A - gesvdqr!(B, S, U, V) - U = Q * U - return S, U, V -end - end \ No newline at end of file diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index e68473868..5f872bb55 
100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -29,7 +29,9 @@ export left_orth!, right_orth!, left_null!, right_null! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, - LAPACK_DivideAndConquer, LAPACK_Jacobi + LAPACK_DivideAndConquer, LAPACK_Jacobi, + LQViaTransposedQR, + CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered VERSION >= v"1.11.0-DEV.469" && @@ -46,6 +48,7 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") +include("interface/decompositions.jl") include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") @@ -55,7 +58,6 @@ include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") -include("implementations/decompositions.jl") include("implementations/truncation.jl") include("implementations/qr.jl") include("implementations/lq.jl") diff --git a/src/implementations/decompositions.jl b/src/interface/decompositions.jl similarity index 74% rename from src/implementations/decompositions.jl rename to src/interface/decompositions.jl index d886588ad..bff490a9a 100644 --- a/src/implementations/decompositions.jl +++ b/src/interface/decompositions.jl @@ -1,8 +1,8 @@ # TODO: module Decompositions? 
-# ========== -# ALGORITHMS -# ========== +# ================= +# LAPACK ALGORITHMS +# ================= # reference for naming LAPACK algorithms: # https://www.netlib.org/lapack/explore-html/topics.html @@ -112,3 +112,41 @@ const LAPACK_SVDAlgorithm = Union{LAPACK_QRIteration, LAPACK_Bisection, LAPACK_DivideAndConquer, LAPACK_Jacobi} + +# ========================= +# CUSOLVER ALGORITHMS +# ========================= +""" + CUSOLVER_HouseholderQR(; positive = false) + +Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of +a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that +the diagonal elements of `R` are non-negative. +""" +@algdef CUSOLVER_HouseholderQR + +""" + CUSOLVER_QRIteration() + +Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +QR Iteration algorithm. +""" +@algdef CUSOLVER_QRIteration + +""" + CUSOLVER_SVDPolar() + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix by using Halley's iterative algorithm to compute the polar decompositon, +followed by the hermitian eigenvalue decomposition of the positive definite factor. +""" +@algdef CUSOLVER_SVDPolar + +""" + CUSOLVER_Jacobi() + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix using the Jacobi algorithm. +""" +@algdef CUSOLVER_Jacobi diff --git a/test/cuda/lq.jl b/test/cuda/lq.jl index 02e8b5066..cc1def1ce 100644 --- a/test/cuda/lq.jl +++ b/test/cuda/lq.jl @@ -5,9 +5,7 @@ using TestExtras using StableRNGs using CUDA -function isapproxone(A) - return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) -end +include("utilities.jl") @testset "lq_compact! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/cuda/qr.jl b/test/cuda/qr.jl index 374d07064..4ada45652 100644 --- a/test/cuda/qr.jl +++ b/test/cuda/qr.jl @@ -5,9 +5,7 @@ using TestExtras using StableRNGs using CUDA -function isapproxone(A) - return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) -end +include("utilities.jl") @testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl new file mode 100644 index 000000000..1c11cd5ae --- /dev/null +++ b/test/cuda/svd.jl @@ -0,0 +1,120 @@ +using MatrixAlgebraKit +using MatrixAlgebraKit: diagview +using LinearAlgebra: Diagonal, isposdef +using Test +using TestExtras +using StableRNGs +using CUDA + +include("utilities.jl") + +@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63) + k = min(m, n) + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar()) + @testset "algorithm $alg" for alg in algs + n > m && alg isa CUSOLVER_QRIteration && continue # not supported + minmn = min(m, n) + A = CuArray(randn(rng, T, m, n)) + + U, S, Vᴴ = svd_compact(A; alg) + @test U isa CuMatrix{T} && size(U) == (m, minmn) + @test S isa Diagonal{real(T),<:CuVector} && size(S) == (minmn, minmn) + @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(Vᴴ * Vᴴ') + @test isposdef(S) + + Ac = similar(A) + U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(Vᴴ * Vᴴ') + @test isposdef(S) + + Sd = svd_vals(A, alg) + @test CuArray(diagview(S)) ≈ Sd + # CuArray is necessary because norm of CuArray view with non-unit step is broken + end + end +end + +@testset "svd_full! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63) + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar()) + @testset "algorithm $alg" for alg in algs + n > m && alg isa CUSOLVER_QRIteration && continue # not supported + A = CuArray(randn(rng, T, m, n)) + U, S, Vᴴ = svd_full(A; alg) + @test U isa CuMatrix{T} && size(U) == (m, m) + @test S isa CuMatrix{real(T)} && size(S) == (m, n) + @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) + + Ac = similar(A) + U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U2 === U + @test S2 === S + @test V2ᴴ === Vᴴ + @test U * S * Vᴴ ≈ A + @test isapproxone(U' * U) + @test isapproxone(U * U') + @test isapproxone(Vᴴ * Vᴴ') + @test isapproxone(Vᴴ' * Vᴴ) + @test all(isposdef, diagview(S)) + + Sc = similar(A, real(T), min(m, n)) + Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) + @test Sc === Sc2 + @test CuArray(diagview(S)) ≈ Sc + # CuArray is necessary because norm of CuArray view with non-unit step is broken + end + end +end + +# @testset "svd_trunc! 
for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +# rng = StableRNG(123) +# m = 54 +# if LinearAlgebra.LAPACK.version() < v"3.12.0" +# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) +# else +# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), +# LAPACK_Jacobi()) +# end + +# @testset "size ($m, $n)" for n in (37, m, 63) +# @testset "algorithm $alg" for alg in algs +# n > m && alg isa LAPACK_Jacobi && continue # not supported +# A = randn(rng, T, m, n) +# S₀ = svd_vals(A) +# minmn = min(m, n) +# r = minmn - 2 + +# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# @test length(S1.diag) == r +# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + +# s = 1 + sqrt(eps(real(T))) +# trunc2 = trunctol(s * S₀[r + 1]) + +# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) +# @test length(S2.diag) == r +# @test U1 ≈ U2 +# @test S1 ≈ S2 +# @test V1ᴴ ≈ V2ᴴ +# end +# end +# end diff --git a/test/cuda/utilities.jl b/test/cuda/utilities.jl new file mode 100644 index 000000000..61518b556 --- /dev/null +++ b/test/cuda/utilities.jl @@ -0,0 +1,3 @@ +function isapproxone(A) + return (size(A, 1) == size(A, 2)) && (A ≈ MatrixAlgebraKit.one!(similar(A))) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5f0d29913..15cc10132 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,6 +32,19 @@ end include("chainrules.jl") end +using CUDA +if CUDA.functional() + @safetestset "CUDA QR" begin + include("cuda/qr.jl") + end + @safetestset "CUDA LQ" begin + include("cuda/lq.jl") + end + @safetestset "CUDA SVD" begin + include("cuda/svd.jl") + end +end + @safetestset "MatrixAlgebraKit.jl" begin @safetestset "Code quality (Aqua.jl)" begin using MatrixAlgebraKit From 5df81c1c6f03e6953c0a36a4178a8a0e53dce07c Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Mon, 12 May 2025 00:53:37 +0200 Subject: [PATCH 07/16] Jacobi SVD algorithm --- .../implementations/svd.jl 
| 10 +- ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 141 +++++++++--------- test/cuda/svd.jl | 4 +- 3 files changed, 77 insertions(+), 78 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl b/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl index 46376a345..f99760314 100644 --- a/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl +++ b/ext/MatrixAlgebraKitCUDAExt/implementations/svd.jl @@ -15,6 +15,8 @@ function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgori YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa CUSOLVER_SVDPolar YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) + elseif alg isa CUSOLVER_Jacobi + YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) # elseif alg isa LAPACK_Bisection # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) # elseif alg isa LAPACK_Jacobi @@ -54,16 +56,14 @@ function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlg YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ) elseif alg isa CUSOLVER_SVDPolar YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) + elseif alg isa CUSOLVER_Jacobi + YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...) # elseif alg isa LAPACK_DivideAndConquer # isempty(alg.kwargs) || # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) # YALAPACK.gesdd!(A, S.diag, U, Vᴴ) # elseif alg isa LAPACK_Bisection # YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...) - # elseif alg isa LAPACK_Jacobi - # isempty(alg.kwargs) || - # throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) - # YALAPACK.gesvj!(A, S.diag, U, Vᴴ) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -89,6 +89,8 @@ function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm) YACUSOLVER.gesvd!(A, S, U, Vᴴ) elseif alg isa CUSOLVER_SVDPolar YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) 
+ elseif alg isa CUSOLVER_Jacobi + YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...) # elseif alg isa LAPACK_DivideAndConquer # isempty(alg.kwargs) || # throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 117f3aeed..b020d7565 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -171,84 +171,81 @@ function Xgesvdp!(A::StridedCuMatrix{T}, end # Wrapper for SVD via Jacobi -# for (bname, fname, elty, relty) in -# ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32), -# (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64), -# (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32), -# (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64)) -# @eval begin -# #! format: off -# function gesvdj!(A::StridedCuMatrix{$elty}, -# S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), -# U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)), -# Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2)); -# tol::$relty=eps($relty), -# max_sweeps::Int=100) -# #! 
format: on -# chkstride1(A, U, Vᴴ, S) -# m, n = size(A) -# minmn = min(m, n) - -# if length(U) == 0 && length(Vᴴ) == 0 -# jobz = 'N' -# econ = 0 -# else -# jobz = 'V' -# size(U, 1) == m || -# throw(DimensionMismatch("row size mismatch between A and U")) -# size(Vᴴ, 2) == n || -# throw(DimensionMismatch("column size mismatch between A and Vᴴ")) -# if size(U, 2) == size(Vᴴ, 1) == minmn -# econ = 1 -# elseif size(U, 2) == m && size(Vᴴ, 1) == n -# econ = 0 -# else -# throw(DimensionMismatch("invalid column size of U or row size of Vᴴ")) -# end -# end -# length(S) == minmn || -# throw(DimensionMismatch("length mismatch between A and S")) - -# if jobz == 'N' # it seems we still need the memory for U and Vᴴ -# U = similar(A, $elty, m, minmn) -# V = similar(A, $elty, n, minmn) -# else -# V = similar(Vᴴ') -# end -# lda = max(1, stride(A, 2)) -# ldu = max(1, stride(U, 2)) -# ldv = max(1, stride(V, 2)) +for (bname, fname, elty, relty) in + ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32), + (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64), + (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32), + (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64)) + @eval begin + #! format: off + function gesvdj!(A::StridedCuMatrix{$elty}, + S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)), + U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)), + Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2)); + tol::$relty=eps($relty), + max_sweeps::Int=100) + #! 
format: on + chkstride1(A, U, Vᴴ, S) + m, n = size(A) + minmn = min(m, n) -# params = Ref{gesvdjInfo_t}(C_NULL) -# cusolverDnCreateGesvdjInfo(params) -# cusolverDnXgesvdjSetTolerance(params[], tol) -# cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) -# dh = dense_handle() + if length(U) == 0 && length(Vᴴ) == 0 + jobz = 'N' + econ = 0 + else + jobz = 'V' + size(U, 1) == m || + throw(DimensionMismatch("row size mismatch between A and U")) + size(Vᴴ, 2) == n || + throw(DimensionMismatch("column size mismatch between A and Vᴴ")) + if size(U, 2) == size(Vᴴ, 1) == minmn + econ = 1 + elseif size(U, 2) == m && size(Vᴴ, 1) == n + econ = 0 + else + throw(DimensionMismatch("invalid column size of U or row size of Vᴴ")) + end + end + length(S) == minmn || + throw(DimensionMismatch("length mismatch between A and S")) -# function bufferSize() -# out = Ref{Cint}(0) -# $bname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, -# out, params[]) -# return out[] * sizeof($elty) -# end + Ṽ = (jobz == 'V') ? similar(Vᴴ') : similar(Vᴴ, (n, minmn)) + Ũ = (jobz == 'V') ? 
U : similar(U, (m, minmn)) + lda = max(1, stride(A, 2)) + ldu = max(1, stride(Ũ, 2)) + ldv = max(1, stride(Ṽ, 2)) -# with_workspace(dh.workspace_gpu, bufferSize) do buffer -# return $fname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, -# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) -# end + params = Ref{CUSOLVER.gesvdjInfo_t}(C_NULL) + CUSOLVER.cusolverDnCreateGesvdjInfo(params) + CUSOLVER.cusolverDnXgesvdjSetTolerance(params[], tol) + CUSOLVER.cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) + dh = CUSOLVER.dense_handle() -# info = @allowscalar dh.info[1] -# chkargsok(BlasInt(info)) + function bufferSize() + out = Ref{Cint}(0) + CUSOLVER.$bname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv, + out, params[]) + return out[] * sizeof($elty) + end -# cusolverDnDestroyGesvdjInfo(params[]) + CUSOLVER.with_workspace(dh.workspace_gpu, bufferSize) do buffer + return CUSOLVER.$fname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, + params[]) + end -# if jobz != 'N' -# adjoint!(Vᴴ, V) -# end -# return U, S, Vᴴ -# end -# end -# end + info = @allowscalar dh.info[1] + CUSOLVER.chkargsok(BlasInt(info)) + + CUSOLVER.cusolverDnDestroyGesvdjInfo(params[]) + + if jobz == 'V' + adjoint!(Vᴴ, Ṽ) + end + return U, S, Vᴴ + end + end +end # for (jname, bname, fname, elty, relty) in # ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32), diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 1c11cd5ae..66098c4ba 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -13,7 +13,7 @@ include("utilities.jl") m = 54 @testset "size ($m, $n)" for n in (37, m, 63) k = min(m, n) - algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar()) + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs n > m && alg isa CUSOLVER_QRIteration && continue # not supported minmn = min(m, n) @@ -49,7 +49,7 @@ end rng = StableRNG(123) m = 54 @testset "size ($m, 
$n)" for n in (37, m, 63) - algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar()) + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) @testset "algorithm $alg" for alg in algs n > m && alg isa CUSOLVER_QRIteration && continue # not supported A = CuArray(randn(rng, T, m, n)) From 76c232960f8e45d1d7f4c8765fc40091950ec7db Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 9 Aug 2025 09:28:51 -0400 Subject: [PATCH 08/16] Fix default_algorithm for CUDA matrices --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 29a222319..01a721f72 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -5,6 +5,7 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR +using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm using CUDA using LinearAlgebra using LinearAlgebra: BlasFloat @@ -13,15 +14,15 @@ include("yacusolver.jl") include("implementations/qr.jl") include("implementations/svd.jl") -function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} return CUSOLVER_HouseholderQR(; kwargs...) end -function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} qr_alg = CUSOLVER_HouseholderQR(; kwargs...) return LQViaTransposedQR(qr_alg) end -function MatrixAlgebraKit.default_svd_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...) 
+function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} return CUSOLVER_QRIteration(; kwargs...) end -end \ No newline at end of file +end From 31adc075dc2e78328d6687fb3990271cc415a4e3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 06:47:45 -0400 Subject: [PATCH 09/16] Try running GPU tests with buildkite --- .buildkite/pipeline.yml | 18 +++++++++ test/runtests.jl | 89 ++++++++++++++++++++++------------------- 2 files changed, 65 insertions(+), 42 deletions(-) create mode 100644 .buildkite/pipeline.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 000000000..27419d389 --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,18 @@ +env: + SECRET_CODECOV_SECRET: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw==" + +steps: + - label: "Julia v1" + plugins: + - JuliaCI/julia#v1: + version: "1" + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + codecov: true + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 + +# TODO add lts support diff --git a/test/runtests.jl b/test/runtests.jl index 15cc10132..0146a5bc0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,35 +1,52 @@ using SafeTestsets -@safetestset "Algorithms" begin - include("algorithms.jl") -end -@safetestset "Truncate" begin - include("truncate.jl") -end -@safetestset "QR / LQ Decomposition" begin - include("qr.jl") - include("lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("svd.jl") -end -@safetestset "Hermitian Eigenvalue 
Decomposition" begin - include("eigh.jl") -end -@safetestset "General Eigenvalue Decomposition" begin - include("eig.jl") -end -@safetestset "Schur Decomposition" begin - include("schur.jl") -end -@safetestset "Polar Decomposition" begin - include("polar.jl") -end -@safetestset "Image and Null Space" begin - include("orthnull.jl") -end -@safetestset "ChainRules" begin - include("chainrules.jl") +# don't run all tests on GPU, only the GPU +# specific ones +is_buildkite = get(ENV, "BUILDKITE", false) +if !isbuildkite + @safetestset "Algorithms" begin + include("algorithms.jl") + end + @safetestset "Truncate" begin + include("truncate.jl") + end + @safetestset "QR / LQ Decomposition" begin + include("qr.jl") + include("lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("eigh.jl") + end + @safetestset "General Eigenvalue Decomposition" begin + include("eig.jl") + end + @safetestset "Schur Decomposition" begin + include("schur.jl") + end + @safetestset "Polar Decomposition" begin + include("polar.jl") + end + @safetestset "Image and Null Space" begin + include("orthnull.jl") + end + @safetestset "ChainRules" begin + include("chainrules.jl") + end + @safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua + Aqua.test_all(MatrixAlgebraKit) + end + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET + JET.test_package(MatrixAlgebraKit; target_defined_modules=true) + end + end end using CUDA @@ -45,15 +62,3 @@ if CUDA.functional() end end -@safetestset "MatrixAlgebraKit.jl" begin - @safetestset "Code quality (Aqua.jl)" begin - using MatrixAlgebraKit - using Aqua - Aqua.test_all(MatrixAlgebraKit) - end - @safetestset "Code linting (JET.jl)" begin - using MatrixAlgebraKit - using JET - JET.test_package(MatrixAlgebraKit; target_defined_modules=true) - end -end From 
9ca46a0dc9cd4863d0d1637557bb21f38d7111b0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 06:57:41 -0400 Subject: [PATCH 10/16] Dumb typo --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0146a5bc0..2a26133d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using SafeTestsets # don't run all tests on GPU, only the GPU # specific ones is_buildkite = get(ENV, "BUILDKITE", false) -if !isbuildkite +if !is_buildkite @safetestset "Algorithms" begin include("algorithms.jl") end From d8ada043b31d44018f25152ee93ff38ef57e44d9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 13:07:47 +0200 Subject: [PATCH 11/16] Another silly ENV mistake --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 2a26133d5..13f33f097 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using SafeTestsets # don't run all tests on GPU, only the GPU # specific ones -is_buildkite = get(ENV, "BUILDKITE", false) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite @safetestset "Algorithms" begin include("algorithms.jl") From 8327a5767239928e69033cd36185d57da7734976 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 13:22:45 +0200 Subject: [PATCH 12/16] Cover extension directory also in coverage --- .buildkite/pipeline.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 27419d389..64a926bef 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -9,6 +9,9 @@ steps: - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext agents: queue: "juliagpu" cuda: "*" From 2c2a2a6375be19c2924dc3a87fb5434d19b8b9d6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 13:41:29 +0200 Subject: [PATCH 13/16] Fix envvar name --- 
.buildkite/pipeline.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 64a926bef..f648b3abf 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,5 @@ env: - SECRET_CODECOV_SECRET: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw==" + SECRET_CODECOV_TOKEN: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw==" steps: - label: "Julia v1" @@ -8,7 +8,6 @@ steps: version: "1" - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: - codecov: true dirs: - src - ext From da70492b0c1b4245916a95137e827f502a654a2b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 15:47:21 +0200 Subject: [PATCH 14/16] Soothe JET --- src/implementations/lq.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index afdbb3f0e..6c81ca126 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -181,7 +181,7 @@ function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, qr_alg::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) - At = adjoint!(similar(A'), A) + At = adjoint!(similar(A'), A)::AbstractMatrix Qt = (A === 
Q) ? At : similar(Q') Lt = similar(L') if size(Q) == (n, n) @@ -197,9 +197,9 @@ end function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) - At = adjoint!(similar(A'), A) + At = adjoint!(similar(A'), A)::AbstractMatrix Nt = similar(N') Nt = qr_null!(At, Nt, qr_alg) !isempty(N) && adjoint!(N, Nt) return N -end \ No newline at end of file +end From e732330a5dddba03f8d7c05142873b4ac4c8de6b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 16:15:59 +0200 Subject: [PATCH 15/16] Add GPU tests for LTS release --- .buildkite/pipeline.yml | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f648b3abf..42eb56e02 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -17,4 +17,18 @@ steps: if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 30 -# TODO add lts support +steps: + - label: "Julia LTS" + plugins: + - JuliaCI/julia#v1: + version: "lts" + - JuliaCI/julia-test#v1: ~ + - JuliaCI/julia-coverage#v1: + dirs: + - src + - ext + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 30 From 4fe347c900b72db80364d6fcf3ab7b77bcc86a4c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 Aug 2025 16:29:25 +0200 Subject: [PATCH 16/16] Set LTS version manually --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 42eb56e02..970e7600a 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -21,7 +21,7 @@ steps: - label: "Julia LTS" plugins: - JuliaCI/julia#v1: - version: "lts" + version: "1.10" # "lts" isn't valid - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: dirs: