@@ -171,84 +171,81 @@ function Xgesvdp!(A::StridedCuMatrix{T},
171171end
172172
# Wrapper for SVD via Jacobi
#
# Generates one `gesvdj!` method per supported element type, each forwarding to the
# matching `CUSOLVER.cusolverDn?gesvdj` routine (`bname` is its workspace-size query,
# `fname` the solver itself, `elty` the matrix element type and `relty` the real type
# used for the singular values and the tolerance).
#
#     gesvdj!(A, [S, U, Vᴴ]; tol=eps(relty), max_sweeps=100) -> (U, S, Vᴴ)
#
# Arguments
# - `A`:  m×n input matrix handed to the solver.
# - `S`:  output vector of length `min(m, n)`.
# - `U`, `Vᴴ`: output matrices. If both are empty, no vectors are requested
#   (`jobz = 'N'`). Otherwise `U` must have `m` rows and `Vᴴ` must have `n` columns,
#   and their remaining dimensions select the economy (`min(m, n)`) or full (`m` / `n`)
#   variant.
#
# Keywords
# - `tol`: tolerance forwarded to `cusolverDnXgesvdjSetTolerance`.
# - `max_sweeps`: sweep limit forwarded to `cusolverDnXgesvdjSetMaxSweeps`.
#
# Throws `DimensionMismatch` when the sizes of `S`, `U` or `Vᴴ` are inconsistent with `A`.
for (bname, fname, elty, relty) in
    ((:cusolverDnSgesvdj_bufferSize, :cusolverDnSgesvdj, :Float32, :Float32),
     (:cusolverDnDgesvdj_bufferSize, :cusolverDnDgesvdj, :Float64, :Float64),
     (:cusolverDnCgesvdj_bufferSize, :cusolverDnCgesvdj, :ComplexF32, :Float32),
     (:cusolverDnZgesvdj_bufferSize, :cusolverDnZgesvdj, :ComplexF64, :Float64))
    @eval begin
        #! format: off
        function gesvdj!(A::StridedCuMatrix{$elty},
                         S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
                         U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
                         Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
                         tol::$relty=eps($relty),
                         max_sweeps::Int=100)
        #! format: on
            chkstride1(A, U, Vᴴ, S)
            m, n = size(A)
            minmn = min(m, n)

            # Decide between values-only and values+vectors, and between the economy
            # and the full factorization, from the shapes of the output arguments.
            if length(U) == 0 && length(Vᴴ) == 0
                jobz = 'N'
                econ = 0
            else
                jobz = 'V'
                size(U, 1) == m ||
                    throw(DimensionMismatch("row size mismatch between A and U"))
                size(Vᴴ, 2) == n ||
                    throw(DimensionMismatch("column size mismatch between A and Vᴴ"))
                if size(U, 2) == size(Vᴴ, 1) == minmn
                    econ = 1
                elseif size(U, 2) == m && size(Vᴴ, 1) == n
                    econ = 0
                else
                    throw(DimensionMismatch("invalid column size of U or row size of Vᴴ"))
                end
            end
            length(S) == minmn ||
                throw(DimensionMismatch("length mismatch between A and S"))

            # The solver is given V rather than Vᴴ, so it writes into a scratch matrix
            # that is adjointed into Vᴴ afterwards. For jobz == 'N' scratch buffers are
            # still passed for both factors (it seems the routine needs the memory
            # regardless).
            Ṽ = (jobz == 'V') ? similar(Vᴴ') : similar(Vᴴ, (n, minmn))
            Ũ = (jobz == 'V') ? U : similar(U, (m, minmn))
            lda = max(1, stride(A, 2))
            ldu = max(1, stride(Ũ, 2))
            ldv = max(1, stride(Ṽ, 2))

            params = Ref{CUSOLVER.gesvdjInfo_t}(C_NULL)
            CUSOLVER.cusolverDnCreateGesvdjInfo(params)
            # Everything between creation and destruction of the parameter object can
            # throw (workspace query, solver call, argument check), so the destroy call
            # lives in `finally` to avoid leaking the object on error.
            try
                CUSOLVER.cusolverDnXgesvdjSetTolerance(params[], tol)
                CUSOLVER.cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)
                dh = CUSOLVER.dense_handle()

                # Workspace size in bytes: the query reports a count of `$elty` elements.
                function bufferSize()
                    out = Ref{Cint}(0)
                    CUSOLVER.$bname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
                                    out, params[])
                    return out[] * sizeof($elty)
                end

                CUSOLVER.with_workspace(dh.workspace_gpu, bufferSize) do buffer
                    return CUSOLVER.$fname(dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
                                           buffer, sizeof(buffer) ÷ sizeof($elty), dh.info,
                                           params[])
                end

                info = @allowscalar dh.info[1]
                CUSOLVER.chkargsok(BlasInt(info))
            finally
                CUSOLVER.cusolverDnDestroyGesvdjInfo(params[])
            end

            # Only copy back when vectors were requested; otherwise Vᴴ is left untouched.
            if jobz == 'V'
                adjoint!(Vᴴ, Ṽ)
            end
            return U, S, Vᴴ
        end
    end
end
252249
253250# for (jname, bname, fname, elty, relty) in
254251# ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32),
0 commit comments