Skip to content

Commit 4d091cb

Browse files
Jutho and Katharine Hyatt authored
WIP CUDA support (#20)
* first cuda commit - qr support * first fixes * working qr * small triangular! fix * add LQViaTransposedQR, CUDA LQ and tests * first svd support * Jacobi SVD algorithm * Fix default_algorithm for CUDA matrices * Try running GPU tests with buildkite --------- Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>
1 parent e38e535 commit 4d091cb

24 files changed

Lines changed: 1474 additions & 71 deletions

File tree

.buildkite/pipeline.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
env:
2+
SECRET_CODECOV_TOKEN: "MH6hHjQi7vG2V1Yfotv5/z5Dkx1k5SdyGYlGTFXiQr22XksJgsXaBuvFKUrjC7JwcpBsOVU8103LuMKl3m7VJ35WzHZrOssYycVbdGcb2kloc6xvUOsN2R5BrhCQ4Pii0l6ZeVRjCnZVkcmb0Rf4glGFyfibCrqniry8RLhblsuFKFsijRK4OxiWYEs1IvUulN+ER8tEsEtw4+ZqC5nbLGMSnUG/saPkDQOVIBscvikbKEnBcCXBheGPktF+Y/cy/1Xa+FiBPoZcypwTeAjKG1g0MqyHXjaYekb/7fekaj+hukGaeJSCXxY8KEb2IZCh+Y36Tp6y6qsIp/AdtEnCpQ==;U2FsdGVkX18WQxvGLspPwzC4aDe+U7TXU+itebTbgh8LUkE6GukxxReHYiDZ6IrBiVvSGTVJMquW0c8KsOI1pw=="
3+
4+
# Test jobs. Both run the package tests with coverage collection (over `src` and
# `ext`) on agents from the "juliagpu" queue that advertise a CUDA device.
# All jobs must be entries of this single `steps:` list: repeating the top-level
# `steps:` key would create a duplicate mapping key.
steps:
  - label: "Julia v1"
    plugins:
      - JuliaCI/julia#v1:
          version: "1"
      - JuliaCI/julia-test#v1: ~
      - JuliaCI/julia-coverage#v1:
          dirs:
            - src
            - ext
    agents:
      queue: "juliagpu"
      cuda: "*"
    if: build.message !~ /\[skip tests\]/
    timeout_in_minutes: 30

  - label: "Julia LTS"
    plugins:
      - JuliaCI/julia#v1:
          version: "1.10" # "lts" isn't valid
      - JuliaCI/julia-test#v1: ~
      - JuliaCI/julia-coverage#v1:
          dirs:
            - src
            - ext
    agents:
      queue: "juliagpu"
      cuda: "*"
    if: build.message !~ /\[skip tests\]/
    timeout_in_minutes: 30

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[weakdeps]
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1112

1213
[extensions]
1314
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
15+
MatrixAlgebraKitCUDAExt = "CUDA"
1416

1517
[compat]
1618
Aqua = "0.6, 0.7, 0.8"
1719
ChainRulesCore = "1"
1820
ChainRulesTestUtils = "1"
21+
CUDA = "5"
1922
JET = "0.9"
2023
LinearAlgebra = "1"
2124
SafeTestsets = "0.1"
@@ -36,5 +39,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
3639
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3740

3841
[targets]
39-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
40-
"ChainRulesTestUtils", "StableRNGs", "Zygote"]
42+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA"]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Extension module that adds CUSOLVER-based implementations to MatrixAlgebraKit and
# registers them as the default QR, LQ and SVD algorithms for `StridedCuMatrix` types.
module MatrixAlgebraKitCUDAExt

using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input, diagview, sign_safe
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: LQViaTransposedQR
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm
using CUDA
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yacusolver.jl")
include("implementations/qr.jl")
include("implementations/svd.jl")

# QR default for CUDA matrices: Householder QR, with all keywords forwarded.
function MatrixAlgebraKit.default_qr_algorithm(::Type{M};
                                               kwargs...) where {M<:StridedCuMatrix}
    return CUSOLVER_HouseholderQR(; kwargs...)
end

# LQ default for CUDA matrices: wrap the Householder QR algorithm (built from the
# forwarded keywords) in `LQViaTransposedQR`.
function MatrixAlgebraKit.default_lq_algorithm(::Type{M};
                                               kwargs...) where {M<:StridedCuMatrix}
    return LQViaTransposedQR(CUSOLVER_HouseholderQR(; kwargs...))
end

# SVD default for CUDA matrices: QR iteration, with all keywords forwarded.
function MatrixAlgebraKit.default_svd_algorithm(::Type{M};
                                                kwargs...) where {M<:StridedCuMatrix}
    return CUSOLVER_QRIteration(; kwargs...)
end

end
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# CUSOLVER QR implementation
2+
# Full QR decomposition of `A` with the CUSOLVER Householder algorithm.
# `QR` is destructured into the output matrices `(Q, R)`, which are filled by
# `_cusolver_qr!` using the keyword arguments stored in `alg`, and returned as a tuple.
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR)
    check_input(qr_full!, A, QR)
    (Q, R) = QR
    options = alg.kwargs
    _cusolver_qr!(A, Q, R; options...)
    return (Q, R)
end
8+
# Compact QR decomposition of `A` with the CUSOLVER Householder algorithm.
# Shares `_cusolver_qr!` with `qr_full!`; only the `check_input` call (made with
# `qr_compact!`) differs between the two entry points.
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::CUSOLVER_HouseholderQR)
    check_input(qr_compact!, A, QR)
    (Q, R) = QR
    options = alg.kwargs
    _cusolver_qr!(A, Q, R; options...)
    return (Q, R)
end
14+
# Null-space computation via the CUSOLVER Householder QR: after `check_input`,
# `N` is filled by `_cusolver_qr_null!` (keywords taken from `alg`) and returned.
function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::CUSOLVER_HouseholderQR)
    check_input(qr_null!, A, N)
    options = alg.kwargs
    _cusolver_qr_null!(A, N; options...)
    return N
end
19+
20+
# Householder QR of `A` through the YACUSOLVER wrappers, writing into `Q` and `R`.
#
# Arguments
# - `A`: input matrix; it is overwritten by the factorization.
# - `Q`: output for the orthogonal/unitary factor. May be `A` itself (in-place `Q`),
#   but only if `A` is tall (`m >= n`), `R` is empty and `positive=false`.
# - `R`: output for the triangular factor; it is only computed when `length(R) > 0`.
#
# Keywords
# - `positive=false`: if `true`, rescale the columns of `Q` and the rows of `R` by
#   the phases of the diagonal of `R`, so that the diagonal of `R` is non-negative
#   (NOTE(review): relies on `sign_safe` returning a unit phase also for zero
#   entries — confirm).
# - `blocksize=1`: accepted for interface compatibility; any value `> 1` throws.
#
# Returns `(Q, R)`. Throws `ArgumentError` for `blocksize > 1` or for an unsupported
# in-place `Q` request.
function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
                       positive=false, blocksize=1)
    blocksize > 1 &&
        throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition"))
    m, n = size(A)
    minmn = min(m, n)
    computeR = length(R) > 0
    inplaceQ = Q === A
    if inplaceQ && (computeR || positive || m < n)
        throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`"))
    end

    A, τ = YACUSOLVER.geqrf!(A)
    if inplaceQ
        Q = YACUSOLVER.ungqr!(A, τ)
    else
        # build Q explicitly by applying the reflectors to an identity stored in `Q`
        Q = YACUSOLVER.unmqr!('L', 'N', A, τ, one!(Q))
    end
    # henceforth, τ is no longer needed and can be reused

    if positive # already fix Q even if we do not need R
        # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA
        τ .= sign_safe.(diagview(A))
        Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q
        Qf .= Qf .* transpose(τ)
    end

    if computeR
        # `R̃` is the part of `A` with the shape of `R`, with everything below the
        # diagonal zeroed out in place
        R̃ = uppertriangular!(view(A, axes(R)...))
        if positive
            R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R
            R̃f .= conj.(τ) .* R̃f
        end
        copyto!(R, R̃)
    end
    return Q, R
end
57+
58+
# Fill `N` with a basis obtained from the QR factorization of `A`: `N` is first set
# to zero with an identity block in rows `min(m, n)+1:m`, after which the
# reflectors computed by `geqrf!` are applied to it from the left via `unmqr!`.
# `A` is overwritten. `positive` is accepted but not used here; `blocksize > 1`
# throws an `ArgumentError`.
function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix;
                            positive=false, blocksize=1)
    if blocksize > 1
        throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition"))
    end
    nrows, ncols = size(A)
    k = min(nrows, ncols)
    fill!(N, zero(eltype(N)))
    lowerblock = view(N, (k + 1):nrows, 1:(nrows - k))
    one!(lowerblock)
    A, τ = YACUSOLVER.geqrf!(A)
    N = YACUSOLVER.unmqr!('L', 'N', A, τ, N)
    return N
end
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Union of the CUSOLVER algorithm types handled by the SVD methods in this file.
const CUSOLVER_SVDAlgorithm = Union{
    CUSOLVER_QRIteration,
    CUSOLVER_SVDPolar,
    CUSOLVER_Jacobi,
}
4+
5+
# CUSOLVER SVD implementation
6+
# Full SVD of the CUDA matrix `A` with one of the CUSOLVER algorithms.
#
# `USVᴴ` is destructured into `(U, S, Vᴴ)`. `S` is treated as a dense matrix: it is
# zeroed, its first column is handed to the solver to receive the singular values,
# and these are then moved onto the diagonal. Afterwards the phase of every singular
# vector is fixed such that its largest-magnitude entry is real and positive.
#
# Returns `USVᴴ`. Throws `ArgumentError` if a `CUSOLVER_QRIteration` algorithm carries
# keyword arguments.
function MatrixAlgebraKit.svd_full!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_full!, A, USVᴴ)
    U, S, Vᴴ = USVᴴ
    fill!(S, zero(eltype(S)))
    m, n = size(A)
    minmn = min(m, n)
    # the first column of `S` temporarily stores the vector of singular values
    if alg isa CUSOLVER_QRIteration
        isempty(alg.kwargs) ||
            throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
        YACUSOLVER.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    # move the singular values to the diagonal and clear the scratch column
    # (entry (1, 1) belongs to both and is kept)
    diagview(S) .= view(S, 1:minmn, 1)
    view(S, 2:minmn, 1) .= zero(eltype(S))
    # TODO: make this controllable using a `gaugefix` keyword argument
    for j in 1:max(m, n)
        if j <= minmn
            # paired singular vectors: opposite phases leave `U * S * Vᴴ` unchanged
            u = view(U, :, j)
            v = view(Vᴴ, j, :)
            s = conj(sign(_argmaxabs(u)))
            u .*= s
            v .*= conj(s)
        elseif j <= m
            # extra columns of `U` (only when m > n)
            u = view(U, :, j)
            s = conj(sign(_argmaxabs(u)))
            u .*= s
        else
            # extra rows of `Vᴴ` (only when n > m)
            v = view(Vᴴ, j, :)
            s = conj(sign(_argmaxabs(v)))
            v .*= s
        end
    end
    return USVᴴ
end
49+
50+
# Compact SVD of the CUDA matrix `A` with one of the CUSOLVER algorithms.
#
# `USVᴴ` is destructured into `(U, S, Vᴴ)`; the singular values are written into
# `S.diag`. Afterwards each pair of singular vectors is rephased such that the
# largest-magnitude entry of the column of `U` is real and positive, with the
# opposite phase applied to the matching row of `Vᴴ`.
#
# Returns `USVᴴ`. Throws `ArgumentError` if a `CUSOLVER_QRIteration` algorithm carries
# keyword arguments.
function MatrixAlgebraKit.svd_compact!(A::CuMatrix, USVᴴ, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_compact!, A, USVᴴ)
    U, S, Vᴴ = USVᴴ
    if alg isa CUSOLVER_QRIteration
        if !isempty(alg.kwargs)
            throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
        end
        YACUSOLVER.gesvd!(A, S.diag, U, Vᴴ)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    # TODO: make this controllable using a `gaugefix` keyword argument
    for col in 1:size(U, 2)
        ucol = view(U, :, col)
        vrow = view(Vᴴ, col, :)
        phase = conj(sign(_argmaxabs(ucol)))
        ucol .*= phase
        vrow .*= conj(phase)
    end
    return USVᴴ
end
80+
# Return the element of `x` with the largest absolute value (the element itself,
# not its index); an empty `x` gives `zero(eltype(x))`.
function _argmaxabs(x)
    return reduce(_largest, x; init=zero(eltype(x)))
end

# Return whichever of `x` and `y` has the larger absolute value; ties keep `x`.
function _largest(x, y)
    return abs(y) > abs(x) ? y : x
end
82+
83+
# Singular values of the CUDA matrix `A` with one of the CUSOLVER algorithms.
#
# The values are written into `S`, which is returned. Empty `(0, 0)` arrays are
# passed in place of `U` and `Vᴴ` (presumably signalling to the YACUSOLVER wrappers
# that no singular vectors are requested — confirm against `yacusolver.jl`).
# Throws `ArgumentError` if a `CUSOLVER_QRIteration` algorithm carries keyword
# arguments.
function MatrixAlgebraKit.svd_vals!(A::CuMatrix, S, alg::CUSOLVER_SVDAlgorithm)
    check_input(svd_vals!, A, S)
    U = similar(A, (0, 0))
    Vᴴ = similar(A, (0, 0))
    if alg isa CUSOLVER_QRIteration
        if !isempty(alg.kwargs)
            throw(ArgumentError("CUSOLVER_QRIteration does not accept any keyword arguments"))
        end
        YACUSOLVER.gesvd!(A, S, U, Vᴴ)
    elseif alg isa CUSOLVER_SVDPolar
        YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
    elseif alg isa CUSOLVER_Jacobi
        YACUSOLVER.gesvdj!(A, S, U, Vᴴ; alg.kwargs...)
    else
        throw(ArgumentError("Unsupported SVD algorithm"))
    end
    return S
end

0 commit comments

Comments
 (0)