Skip to content

Commit 7f09c66

Browse files
committed
Attempting to wrap randomized SVD
1 parent 1220c32 commit 7f09c66

5 files changed

Lines changed: 118 additions & 3 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,66 @@ for (bname, fname, elty, relty) in
247247
end
248248
end
249249

250+
# Wrapper for randomized SVD
251+
function Xgesvdr!(A::StridedCuMatrix{T},
252+
S::StridedCuVector=similar(A, real(T), min(size(A)...)),
253+
U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)),
254+
Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2));
255+
k::Int=length(S),
256+
p::Int=min(size(A)...)-k-1,
257+
niters::Int=1) where {T<:BlasFloat}
258+
chkstride1(A, U, S, Vᴴ)
259+
m, n = size(A)
260+
minmn = min(m, n)
261+
jobu = length(U) == 0 ? 'N' : 'S'
262+
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
263+
k = min(size(S)...)
264+
R = eltype(S)
265+
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
266+
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
267+
R == real(T) ||
268+
throw(ArgumentError("S does not have the matching real `eltype` of A"))
269+
270+
Ṽ = similar(Vᴴ, (n, n))
271+
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
272+
lda = max(1, stride(A, 2))
273+
ldu = max(1, stride(Ũ, 2))
274+
ldv = max(1, stride(Ṽ, 2))
275+
params = CUSOLVER.CuSolverParameters()
276+
dh = CUSOLVER.dense_handle()
277+
278+
function bufferSize()
279+
out_cpu = Ref{Csize_t}(0)
280+
out_gpu = Ref{Csize_t}(0)
281+
CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters,
282+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
283+
T, out_gpu, out_cpu)
284+
285+
return out_gpu[], out_cpu[]
286+
end
287+
CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
288+
bufferSize()...) do buffer_gpu, buffer_cpu
289+
return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters,
290+
T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
291+
T, buffer_gpu, sizeof(buffer_gpu),
292+
buffer_cpu, sizeof(buffer_cpu),
293+
dh.info)
294+
end
295+
296+
flag = @allowscalar dh.info[1]
297+
CUSOLVER.chklapackerror(BlasInt(flag))
298+
if Ũ !== U && length(U) > 0
299+
U .= view(Ũ, 1:m, 1:size(U, 2))
300+
end
301+
if length(Vᴴ) > 0
302+
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
303+
end
304+
Ũ !== U && CUDA.unsafe_free!(Ũ)
305+
CUDA.unsafe_free!(Ṽ)
306+
307+
return S, U, Vᴴ
308+
end
309+
250310
# for (jname, bname, fname, elty, relty) in
251311
# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
252312
# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3333
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3434
LAPACK_DivideAndConquer, LAPACK_Jacobi,
3535
LQViaTransposedQR,
36-
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi,
36+
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized,
3737
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi
3838
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3939

src/implementations/svd.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
6666
return initialize_output(svd_compact!, A, alg.alg)
6767
end
6868

69+
6970
# Implementation
7071
# --------------
7172
function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
@@ -174,6 +175,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
174175
return S
175176
end
176177

178+
#YACUSOLVER.Xgesvdr!(A, view(S, 1:k, 1), U, Vᴴ; alg.kwargs...)
177179
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
178180
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
179181
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
@@ -185,17 +187,20 @@ end
185187
###
186188
const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
187189
CUSOLVER_SVDPolar,
188-
CUSOLVER_Jacobi}
190+
CUSOLVER_Jacobi,
191+
CUSOLVER_Randomized}
189192
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
190193
ROCSOLVER_Jacobi}
191194
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
192195

193196
const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration}
194197
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
195198
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
199+
const GPU_Randomized = Union{CUSOLVER_Randomized}
196200

197201
_gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ)))
198202
_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
203+
_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
199204
_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
200205

201206
# GPU SVD implementation

src/interface/decompositions.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ a general matrix using the Jacobi algorithm.
151151
"""
152152
@algdef CUSOLVER_Jacobi
153153

154+
"""
155+
CUSOLVER_Randomized(; p, niters)
156+
157+
Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of
158+
a general matrix using the randomized SVD algorithm.
159+
160+
!!! note
161+
Randomized SVD cannot compute all singular values of the input matrix `A`, only the first `k` where
162+
`k < min(m, n)`. The remainder are used for oversampling. See the [CUSOLVER documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgesvdr)
163+
for more information.
164+
"""
165+
@algdef CUSOLVER_Randomized
166+
154167
# =========================
155168
# ROCSOLVER ALGORITHMS
156169
# =========================

test/cuda/svd.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,49 @@ end
7676
@test isapproxone(Vᴴ' * Vᴴ)
7777
@test all(isposdef, diagview(S))
7878

79-
Sc = similar(A, real(T), min(m, n))
79+
minmn = min(m, n)
80+
Sc = similar(A, real(T), minmn)
8081
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
8182
@test Sc === Sc2
8283
@test CuArray(diagview(S)) ≈ Sc
8384
# CuArray is necessary because norm of CuArray view with non-unit step is broken
8485
end
86+
k = min(m, n) - 20
87+
p = min(m, n) - k - 1
88+
algs = (CUSOLVER_Randomized(; k=k, p=p, niters=100),)
89+
@testset "algorithm $alg" for alg in algs
90+
A = CuArray(randn(rng, T, m, n))
91+
Uref, Sref, Vᴴref = svd_full(A, CUSOLVER_SVDPolar())
92+
U, S, Vᴴ = svd_full(A; alg)
93+
@test U isa CuMatrix{T} && size(U) == (m, m)
94+
@test S isa CuMatrix{real(T)} && size(S) == (m, n)
95+
@test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n)
96+
for col in 1:k
97+
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
98+
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
99+
end
100+
@test all(isposdef, view(diagview(S), 1:k))
101+
@test view(CuArray(diagview(S)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
102+
103+
Ac = similar(A)
104+
U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg)
105+
@test U2 === U
106+
@test S2 === S
107+
@test V2ᴴ === Vᴴ
108+
for col in 1:k
109+
@test view(collect(U), :, col) ≈ view(collect(Uref), :, col)
110+
@test view(collect(Vᴴ), col, :) ≈ view(collect(Vᴴref), col, :)
111+
end
112+
@test all(isposdef, view(diagview(S), 1:k))
113+
@test view(CuArray(diagview(S2)), 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
114+
115+
Sc = similar(A, real(T), k)
116+
Sc2 = svd_vals!(copy!(Ac, A), Sc, alg)
117+
@test Sc === Sc2
118+
@test view(Sc, 1:k) ≈ view(CuArray(diagview(Sref)), 1:k)
119+
@test view(CuArray(diagview(S)), 1:k) ≈ Sc
120+
# CuArray is necessary because norm of CuArray view with non-unit step is broken
121+
end
85122
end
86123
end
87124

0 commit comments

Comments
 (0)