Skip to content

Commit ffbfa85

Browse files
committed
Some GPU fixes
1 parent 213fe3c commit ffbfa85

5 files changed

Lines changed: 89 additions & 57 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,27 @@ end
5252
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
5353
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)
5454

55-
# Sketched SVD via cuSOLVER's gesvdr kernel
55+
# Sketched SVD via cuSOLVER's gesvdr kernel.
56+
# The full m×m / n×n shapes of U / Vᴴ allow YACUSOLVER.gesvdr! to reuse them as cuSOLVER workspace.
57+
# `alg` is accepted but unused: cuSOLVER's gesvdr fuses the inner SVD itself.
5658
function gesvdr!(
5759
::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix;
5860
sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm()
5961
)
6062
isempty(A) && return U, S, Vᴴ
61-
m, n = size(A); minmn = min(m, n)
62-
k = trunc.howmany
63-
1 ≤ k ≤ minmn ||
64-
throw(ArgumentError("trunc.howmany=$k must satisfy 1 ≤ k ≤ min(size(A))=$minmn"))
65-
p = sketch.howmany - k
66-
p ≥ 0 || throw(
67-
ArgumentError(
68-
"sketch.howmany=$(sketch.howmany) must be ≥ trunc.howmany=$k"
69-
)
70-
)
71-
p = min(p, minmn - k - 1)
72-
niters = sketch.numiter - 1
73-
74-
Uk = view(U, :, 1:k)
75-
Vᴴk = view(Vᴴ, 1:k, :)
76-
Sk = view(diagview(S), 1:k)
77-
YACUSOLVER.gesvdr!(A, Sk, Uk, Vᴴk; k, p, niters)
78-
return Uk, S, Vᴴk
63+
m, n = size(A)
64+
sketch_amount = min(sketch.howmany, m, n)
65+
k = min(trunc.howmany, m, n)
66+
p = max(sketch_amount - k, 0)
67+
numiter = sketch.numiter
68+
69+
V = Vᴴ # gesvdr returns V, but this has to be the same size so we will use this as workspace
70+
71+
YACUSOLVER.gesvdr!(A, diagview(S), U, V; k, p, numiter)
72+
73+
# Truncate requires Vᴴ, so we adjoint here
74+
USVᴴtrunc, _ = MatrixAlgebraKit.truncate(MatrixAlgebraKit.svd_trunc!, (U, S, V'), trunc)
75+
return USVᴴtrunc
7976
end
8077

8178
geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -266,41 +266,53 @@ for (bname, fname, elty, relty) in
266266
end
267267
end
268268

269-
# Wrapper for randomized SVD
269+
# Wrapper for randomized SVD.
270+
# Caller must supply full-size buffers: U is (m, m) and V is (n, n); both are passed
271+
# directly to cuSOLVER as its output storage. No adjoint is taken here: the caller is
272+
# responsible for converting V to Vᴴ after cuSOLVER returns.
273+
# !!! Warning: this function takes in/returns V instead of Vᴴ
270274
function gesvdr!(
271275
A::StridedCuMatrix{T},
272276
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
273-
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
274-
Vᴴ::StridedCuMatrix{T} = similar(A, T, min(size(A)...), size(A, 2));
277+
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), size(A, 1)),
278+
V::StridedCuMatrix{T} = similar(A, T, size(A, 2), size(A, 2));
275279
k::Int = length(S),
276280
p::Int = min(size(A)...) - k - 1,
277-
niters::Int = 1
281+
numiter::Int = 1,
278282
) where {T <: BlasFloat}
279-
chkstride1(A, U, S, Vᴴ)
283+
chkstride1(A, U, S, V)
280284
m, n = size(A)
281285
minmn = min(m, n)
282-
jobu = length(U) == 0 ? 'N' : 'S'
283-
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
284286
R = eltype(S)
285-
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
286-
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
287287
R == real(T) ||
288288
throw(ArgumentError("S does not have the matching real `eltype` of A"))
289-
290-
Ṽ = similar(Vᴴ, (n, n))
291-
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
289+
length(S) == minmn ||
290+
throw(DimensionMismatch("length of S ($(length(S))) must equal min(size(A)) = $minmn"))
291+
size(U) == (m, m) ||
292+
throw(DimensionMismatch("U must have shape (m, m) = ($m, $m); got $(size(U))"))
293+
size(V) == (n, n) ||
294+
throw(DimensionMismatch("V must have shape (n, n) = ($n, $n); got $(size(V))"))
295+
k < minmn ||
296+
throw(DimensionMismatch("rank k ($k) must be less than min(size(A)) = $minmn"))
297+
k + p < minmn ||
298+
throw(DimensionMismatch("k + p ($(k + p)) must be less than min(size(A)) = $minmn"))
299+
300+
isempty(A) && return S, U, V
301+
302+
jobu = 'S'
303+
jobv = 'S'
292304
lda = max(1, stride(A, 2))
293-
ldu = max(1, stride(Ũ, 2))
294-
ldv = max(1, stride(Ṽ, 2))
305+
ldu = max(1, stride(U, 2))
306+
ldv = max(1, stride(V, 2))
295307
params = cuSOLVER.CuSolverParameters()
296308
dh = cuSOLVER.dense_handle()
297309

298310
function bufferSize()
299311
out_cpu = Ref{Csize_t}(0)
300312
out_gpu = Ref{Csize_t}(0)
301313
cuSOLVER.cusolverDnXgesvdr_bufferSize(
302-
dh, params, jobu, jobv, m, n, k, p, niters,
303-
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
314+
dh, params, jobu, jobv, m, n, k, p, numiter,
315+
T, A, lda, R, S, T, U, ldu, T, V, ldv,
304316
T, out_gpu, out_cpu
305317
)
306318

@@ -311,8 +323,8 @@ function gesvdr!(
311323
bufferSize()...
312324
) do buffer_gpu, buffer_cpu
313325
return cuSOLVER.cusolverDnXgesvdr(
314-
dh, params, jobu, jobv, m, n, k, p, niters,
315-
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
326+
dh, params, jobu, jobv, m, n, k, p, numiter,
327+
T, A, lda, R, S, T, U, ldu, T, V, ldv,
316328
T, buffer_gpu, sizeof(buffer_gpu),
317329
buffer_cpu, sizeof(buffer_cpu),
318330
dh.info
@@ -321,16 +333,8 @@ function gesvdr!(
321333

322334
flag = @allowscalar dh.info[1]
323335
cuSOLVER.chklapackerror(BlasInt(flag))
324-
if Ũ !== U && length(U) > 0
325-
U .= view(Ũ, 1:m, 1:size(U, 2))
326-
end
327-
if length(Vᴴ) > 0
328-
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
329-
end
330-
Ũ !== U && CUDA.unsafe_free!(Ũ)
331-
CUDA.unsafe_free!(Ṽ)
332336

333-
return S, U, Vᴴ
337+
return S, U, V
334338
end
335339

336340
# Wrapper for general eigensolver

src/implementations/svd.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,36 @@ function svd_trunc_no_error!(A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgor
311311
return gesvdr!(alg.driver, A, S, U, Vᴴ; alg.sketch, alg.alg, alg.trunc)
312312
end
313313

314+
# CUSOLVER's gesvdr kernel requires full U and Vᴴ
315+
function initialize_output(
316+
::typeof(svd_trunc_no_error!), A::AbstractMatrix,
317+
alg::SketchedAlgorithm{<:AbstractAlgorithm, <:SketchingStrategy, <:TruncationStrategy, CUSOLVER},
318+
)
319+
m, n = size(A)
320+
minmn = min(m, n)
321+
T = float(eltype(A))
322+
U = similar(A, T, (m, m))
323+
S = Diagonal(similar(A, real(T), (minmn,)))
324+
Vᴴ = similar(A, T, (n, n))
325+
return (U, S, Vᴴ)
326+
end
327+
328+
function check_input(
329+
::typeof(svd_trunc_no_error!), A::AbstractMatrix, (U, S, Vᴴ),
330+
alg::SketchedAlgorithm{<:AbstractAlgorithm, <:SketchingStrategy, <:TruncationStrategy, CUSOLVER},
331+
)
332+
m, n = size(A)
333+
minmn = min(m, n)
334+
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
335+
@check_size(U, (m, m))
336+
@check_scalar(U, A)
337+
@check_size(S, (minmn, minmn))
338+
@check_scalar(S, A, real)
339+
@check_size(Vᴴ, (n, n))
340+
@check_scalar(Vᴴ, A)
341+
return nothing
342+
end
343+
314344
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm)
315345
U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg)
316346
Na = norm(A)
@@ -399,7 +429,7 @@ function _cusolver_randomized_to_sketched(alg::CUSOLVER_Randomized)
399429
niters = alg.kwargs.niters
400430
return SketchedAlgorithm(
401431
QRIteration(),
402-
GaussianSketching(k + p; numiter = niters + 1),
432+
GaussianSketching(k + p; numiter = niters),
403433
truncrank(k);
404434
driver = CUSOLVER(),
405435
)
@@ -415,7 +445,7 @@ end
415445
@inline function select_algorithm(::typeof(svd_trunc!), A, alg::CUSOLVER_Randomized; kwargs...)
416446
Base.depwarn(
417447
"`CUSOLVER_Randomized` is deprecated; use \
418-
`SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters+1), truncrank(k); driver=CUSOLVER())` instead.",
448+
`SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters), truncrank(k); driver=CUSOLVER())` instead.",
419449
:select_algorithm,
420450
)
421451
isempty(kwargs) ||

src/yalapack.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2351,7 +2351,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
23512351
throw(DimensionMismatch("length mismatch between A ($n) and S ($(length(S)))"))
23522352

23532353
lda = max(1, stride(A, 2))
2354-
mv = Ref{BlasInt}() # unused
2354+
mv = Ref{BlasInt}(0) # unused by LAPACK when JOBV='V', but must satisfy MV ≥ 0 input check
23552355
if jobv == 'V'
23562356
if U !== A
23572357
V = view(U, 1:n, 1:n) # use U as V storage

test/decompositions/svd.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ if !is_buildkite
3636
SketchedAlgorithm(; sketch = GaussianSketching(m ÷ 2, numiter = 4), trunc = truncrank(m ÷ 4)),
3737
]
3838
TestSuite.test_sketched_svd(T, (m, n), algs; rtol)
39-
TestSuite.test_sketched_svd(T, (n, m), algs)
39+
TestSuite.test_sketched_svd(T, (n, m), algs; rtol)
4040
end
4141

4242
# Generic floats:
@@ -57,7 +57,7 @@ end
5757

5858
# CUDA tests
5959
# ------------
60-
if false # CUDA.functional()
60+
if CUDA.functional()
6161
# LAPACK algorithms:
6262
for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27)
6363
TestSuite.seed_rng!(123)
@@ -66,18 +66,19 @@ if false # CUDA.functional()
6666
TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS)
6767
end
6868

69-
# Randomized SVD:
69+
# Sketched SVD:
7070
for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27)
7171
TestSuite.seed_rng!(123)
7272
k = 5
7373
p = min(m, n) - k - 2
7474
p > 0 || continue
75-
cusolver_sketch = SketchedAlgorithm(
76-
GaussianSketching(k; oversampling = p, niters = 20),
77-
DefaultAlgorithm(),
78-
MatrixAlgebraKit.CUSOLVER(),
75+
rtol = sqrt(TestSuite.precision(T)) # extra square root
76+
cusolver_sketch = SketchedAlgorithm(;
77+
sketch = GaussianSketching(k + p; numiter = 20),
78+
trunc = truncrank(k),
79+
driver = MatrixAlgebraKit.CUSOLVER(),
7980
)
80-
TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (cusolver_sketch,))
81+
TestSuite.test_sketched_svd(CuMatrix{T}, (m, n), (cusolver_sketch,); rtol)
8182
end
8283

8384
# Diagonal:

0 commit comments

Comments
 (0)