@@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in
242242 if jobz == ' V'
243243 adjoint! (Vᴴ, Ṽ)
244244 end
245- return U, S , Vᴴ
245+ return S, U , Vᴴ
246246 end
247247 end
248248end
249249
250+ # Wrapper for randomized SVD
251+ function Xgesvdr! (A:: StridedCuMatrix{T} ,
252+ S:: StridedCuVector = similar (A, real (T), min (size (A)... )),
253+ U:: StridedCuMatrix{T} = similar (A, T, size (A, 1 ), min (size (A)... )),
254+ Vᴴ:: StridedCuMatrix{T} = similar (A, T, min (size (A)... ), size (A, 2 ));
255+ k:: Int = length (S),
256+ p:: Int = min (size (A)... )- k- 1 ,
257+ niters:: Int = 1 ) where {T<: BlasFloat }
258+ chkstride1 (A, U, S, Vᴴ)
259+ m, n = size (A)
260+ minmn = min (m, n)
261+ jobu = length (U) == 0 ? ' N' : ' S'
262+ jobv = length (Vᴴ) == 0 ? ' N' : ' S'
263+ R = eltype (S)
264+ k < minmn || throw (DimensionMismatch (" length of S ($k ) must be less than the smaller dimension of A ($minmn )" ))
265+ k + p < minmn || throw (DimensionMismatch (" length of S ($k ) plus oversampling ($p ) must be less than the smaller dimension of A ($minmn )" ))
266+ R == real (T) ||
267+ throw (ArgumentError (" S does not have the matching real `eltype` of A" ))
268+
269+ Ṽ = similar (Vᴴ, (n, n))
270+ Ũ = (size (U) == (m, m)) ? U : similar (U, (m, m))
271+ lda = max (1 , stride (A, 2 ))
272+ ldu = max (1 , stride (Ũ, 2 ))
273+ ldv = max (1 , stride (Ṽ, 2 ))
274+ params = CUSOLVER. CuSolverParameters ()
275+ dh = CUSOLVER. dense_handle ()
276+
277+ function bufferSize ()
278+ out_cpu = Ref {Csize_t} (0 )
279+ out_gpu = Ref {Csize_t} (0 )
280+ CUSOLVER. cusolverDnXgesvdr_bufferSize (dh, params, jobu, jobv, m, n, k, p, niters,
281+ T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
282+ T, out_gpu, out_cpu)
283+
284+ return out_gpu[], out_cpu[]
285+ end
286+ CUSOLVER. with_workspaces (dh. workspace_gpu, dh. workspace_cpu,
287+ bufferSize ()... ) do buffer_gpu, buffer_cpu
288+ return CUSOLVER. cusolverDnXgesvdr (dh, params, jobu, jobv, m, n, k, p, niters,
289+ T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
290+ T, buffer_gpu, sizeof (buffer_gpu),
291+ buffer_cpu, sizeof (buffer_cpu),
292+ dh. info)
293+ end
294+
295+ flag = @allowscalar dh. info[1 ]
296+ CUSOLVER. chklapackerror (BlasInt (flag))
297+ if Ũ != = U && length (U) > 0
298+ U .= view (Ũ, 1 : m, 1 : size (U, 2 ))
299+ end
300+ if length (Vᴴ) > 0
301+ Vᴴ .= view (Ṽ' , 1 : size (Vᴴ, 1 ), 1 : n)
302+ end
303+ Ũ != = U && CUDA. unsafe_free! (Ũ)
304+ CUDA. unsafe_free! (Ṽ)
305+
306+ return S, U, Vᴴ
307+ end
308+
250309# for (jname, bname, fname, elty, relty) in
251310# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
252311# (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64),
0 commit comments