@@ -266,41 +266,53 @@ for (bname, fname, elty, relty) in
266266 end
267267end
268268
269- # Wrapper for randomized SVD
269+ # Wrapper for randomized SVD.
270+ # Caller must supply full-size buffers: U is (m, m) and Vᴴ is (n, n); both are reused
271+ # directly as cuSOLVER's workspace, and Vᴴ is converted in place from V to Vᴴ on the
272+ # leading k rows after cuSOLVER returns.
273+ # !!! Warning: this function takes in/returns V instead of Vᴴ
270274function gesvdr! (
271275 A:: StridedCuMatrix{T} ,
272276 S:: StridedCuVector = similar (A, real (T), min (size (A)... )),
273- U:: StridedCuMatrix{T} = similar (A, T, size (A, 1 ), min ( size (A) ... )),
274- Vᴴ :: StridedCuMatrix{T} = similar (A, T, min ( size (A) ... ), size (A, 2 ));
277+ U:: StridedCuMatrix{T} = similar (A, T, size (A, 1 ), size (A, 1 )),
278+ V :: StridedCuMatrix{T} = similar (A, T, size (A, 2 ), size (A, 2 ));
275279 k:: Int = length (S),
276280 p:: Int = min (size (A)... ) - k - 1 ,
277- niters :: Int = 1
281+ numiter :: Int = 1 ,
278282 ) where {T <: BlasFloat }
279- chkstride1 (A, U, S, Vᴴ )
283+ chkstride1 (A, U, S, V )
280284 m, n = size (A)
281285 minmn = min (m, n)
282- jobu = length (U) == 0 ? ' N' : ' S'
283- jobv = length (Vᴴ) == 0 ? ' N' : ' S'
284286 R = eltype (S)
285- k < minmn || throw (DimensionMismatch (" length of S ($k ) must be less than the smaller dimension of A ($minmn )" ))
286- k + p < minmn || throw (DimensionMismatch (" length of S ($k ) plus oversampling ($p ) must be less than the smaller dimension of A ($minmn )" ))
287287 R == real (T) ||
288288 throw (ArgumentError (" S does not have the matching real `eltype` of A" ))
289-
290- Ṽ = similar (Vᴴ, (n, n))
291- Ũ = (size (U) == (m, m)) ? U : similar (U, (m, m))
289+ length (S) == minmn ||
290+ throw (DimensionMismatch (" length of S ($(length (S)) ) must equal min(size(A)) = $minmn " ))
291+ size (U) == (m, m) ||
292+ throw (DimensionMismatch (" U must have shape (m, m) = ($m , $m ); got $(size (U)) " ))
293+ size (V) == (n, n) ||
294+ throw (DimensionMismatch (" V must have shape (n, n) = ($n , $n ); got $(size (V)) " ))
295+ k < minmn ||
296+ throw (DimensionMismatch (" rank k ($k ) must be less than min(size(A)) = $minmn " ))
297+ k + p < minmn ||
298+ throw (DimensionMismatch (" k + p ($(k + p) ) must be less than min(size(A)) = $minmn " ))
299+
300+ isempty (A) && return S, U, V
301+
302+ jobu = ' S'
303+ jobv = ' S'
292304 lda = max (1 , stride (A, 2 ))
293- ldu = max (1 , stride (Ũ , 2 ))
294- ldv = max (1 , stride (Ṽ , 2 ))
305+ ldu = max (1 , stride (U , 2 ))
306+ ldv = max (1 , stride (V , 2 ))
295307 params = cuSOLVER. CuSolverParameters ()
296308 dh = cuSOLVER. dense_handle ()
297309
298310 function bufferSize ()
299311 out_cpu = Ref {Csize_t} (0 )
300312 out_gpu = Ref {Csize_t} (0 )
301313 cuSOLVER. cusolverDnXgesvdr_bufferSize (
302- dh, params, jobu, jobv, m, n, k, p, niters ,
303- T, A, lda, R, S, T, Ũ , ldu, T, Ṽ , ldv,
314+ dh, params, jobu, jobv, m, n, k, p, numiter ,
315+ T, A, lda, R, S, T, U , ldu, T, V , ldv,
304316 T, out_gpu, out_cpu
305317 )
306318
@@ -311,8 +323,8 @@ function gesvdr!(
311323 bufferSize ()...
312324 ) do buffer_gpu, buffer_cpu
313325 return cuSOLVER. cusolverDnXgesvdr (
314- dh, params, jobu, jobv, m, n, k, p, niters ,
315- T, A, lda, R, S, T, Ũ , ldu, T, Ṽ , ldv,
326+ dh, params, jobu, jobv, m, n, k, p, numiter ,
327+ T, A, lda, R, S, T, U , ldu, T, V , ldv,
316328 T, buffer_gpu, sizeof (buffer_gpu),
317329 buffer_cpu, sizeof (buffer_cpu),
318330 dh. info
@@ -321,16 +333,8 @@ function gesvdr!(
321333
322334 flag = @allowscalar dh. info[1 ]
323335 cuSOLVER. chklapackerror (BlasInt (flag))
324- if Ũ != = U && length (U) > 0
325- U .= view (Ũ, 1 : m, 1 : size (U, 2 ))
326- end
327- if length (Vᴴ) > 0
328- Vᴴ .= view (Ṽ' , 1 : size (Vᴴ, 1 ), 1 : n)
329- end
330- Ũ != = U && CUDA. unsafe_free! (Ũ)
331- CUDA. unsafe_free! (Ṽ)
332336
333- return S, U, Vᴴ
337+ return S, U, V
334338end
335339
336340# Wrapper for general eigensolver
0 commit comments