1- module YACUSOLVER
1+ module YAcuSOLVER
22
33using LinearAlgebra
44using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing
55using LinearAlgebra. LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo
66
77using CUDA
88using CUDA: @allowscalar , i32
9- using CUDA. CUSOLVER
9+ using CUDA. cuSOLVER
1010
1111# QR methods are implemented with full access to allocated arrays, so we do not need to redo this:
12- using CUDA. CUSOLVER : geqrf!, ormqr!, orgqr!
12+ using CUDA. cuSOLVER : geqrf!, ormqr!, orgqr!
1313const unmqr! = ormqr!
1414const ungqr! = orgqr!
1515
@@ -30,7 +30,7 @@ for (bname, fname, elty, relty) in
3030 )
3131 chkstride1 (A, U, Vᴴ, S)
3232 m, n = size (A)
33- (m < n) && throw (ArgumentError (lazy "CUSOLVER 's gesvd requires m ($m) ≥ n ($n)" ))
33+ (m < n) && throw (ArgumentError (lazy "cuSOLVER 's gesvd requires m ($m) ≥ n ($n)" ))
3434 minmn = min (m, n)
3535 if length (U) == 0
3636 jobu = ' N'
@@ -73,15 +73,15 @@ for (bname, fname, elty, relty) in
7373 ldu = max (1 , stride (U, 2 ))
7474 ldv = max (1 , stride (Vᴴ, 2 ))
7575
76- dh = CUSOLVER . dense_handle ()
76+ dh = cuSOLVER . dense_handle ()
7777 function bufferSize ()
7878 out = Ref {Cint} (0 )
79- CUSOLVER .$ bname (dh, m, n, out)
79+ cuSOLVER .$ bname (dh, m, n, out)
8080 return out[] * sizeof ($ elty)
8181 end
8282 rwork = CuArray {$relty} (undef, min (m, n) - 1 )
8383 CUDA. with_workspace (dh. workspace_gpu, bufferSize) do buffer
84- return CUSOLVER .$ fname (
84+ return cuSOLVER .$ fname (
8585 dh, jobu, jobvt, m, n,
8686 A, lda, S, U, ldu, Vᴴ, ldv,
8787 buffer, sizeof (buffer) ÷ sizeof ($ elty), rwork,
@@ -91,7 +91,7 @@ for (bname, fname, elty, relty) in
9191 CUDA. unsafe_free! (rwork)
9292
9393 info = @allowscalar dh. info[1 ]
94- CUSOLVER . chkargsok (BlasInt (info))
94+ cuSOLVER . chkargsok (BlasInt (info))
9595
9696 return (S, U, Vᴴ)
9797 end
@@ -137,25 +137,25 @@ function gesvdp!(
137137 ldu = max (1 , stride (Ũ, 2 ))
138138 ldv = max (1 , stride (Ṽ, 2 ))
139139 h_err_sigma = Ref {Cdouble} (0 )
140- params = CUSOLVER . CuSolverParameters ()
141- dh = CUSOLVER . dense_handle ()
140+ params = cuSOLVER . CuSolverParameters ()
141+ dh = cuSOLVER . dense_handle ()
142142
143143 function bufferSize ()
144144 out_cpu = Ref {Csize_t} (0 )
145145 out_gpu = Ref {Csize_t} (0 )
146- CUSOLVER . cusolverDnXgesvdp_bufferSize (
146+ cuSOLVER . cusolverDnXgesvdp_bufferSize (
147147 dh, params, jobz, econ, m, n,
148148 T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
149149 T, out_gpu, out_cpu
150150 )
151151
152152 return out_gpu[], out_cpu[]
153153 end
154- CUSOLVER . with_workspaces (
154+ cuSOLVER . with_workspaces (
155155 dh. workspace_gpu, dh. workspace_cpu,
156156 bufferSize ()...
157157 ) do buffer_gpu, buffer_cpu
158- return CUSOLVER . cusolverDnXgesvdp (
158+ return cuSOLVER . cusolverDnXgesvdp (
159159 dh, params, jobz, econ, m, n,
160160 T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
161161 T, buffer_gpu, sizeof (buffer_gpu),
@@ -167,7 +167,7 @@ function gesvdp!(
167167 err > tol && @warn " gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol "
168168
169169 flag = @allowscalar dh. info[1 ]
170- CUSOLVER . chklapackerror (BlasInt (flag))
170+ cuSOLVER . chklapackerror (BlasInt (flag))
171171 if Ũ != = U && length (U) > 0
172172 U .= view (Ũ, 1 : m, 1 : size (U, 2 ))
173173 end
@@ -230,33 +230,33 @@ for (bname, fname, elty, relty) in
230230 ldu = max (1 , stride (Ũ, 2 ))
231231 ldv = max (1 , stride (Ṽ, 2 ))
232232
233- params = Ref {CUSOLVER .gesvdjInfo_t} (C_NULL )
234- CUSOLVER . cusolverDnCreateGesvdjInfo (params)
235- CUSOLVER . cusolverDnXgesvdjSetTolerance (params[], tol)
236- CUSOLVER . cusolverDnXgesvdjSetMaxSweeps (params[], max_sweeps)
237- dh = CUSOLVER . dense_handle ()
233+ params = Ref {cuSOLVER .gesvdjInfo_t} (C_NULL )
234+ cuSOLVER . cusolverDnCreateGesvdjInfo (params)
235+ cuSOLVER . cusolverDnXgesvdjSetTolerance (params[], tol)
236+ cuSOLVER . cusolverDnXgesvdjSetMaxSweeps (params[], max_sweeps)
237+ dh = cuSOLVER . dense_handle ()
238238
239239 function bufferSize ()
240240 out = Ref {Cint} (0 )
241- CUSOLVER .$ bname (
241+ cuSOLVER .$ bname (
242242 dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
243243 out, params[]
244244 )
245245 return out[] * sizeof ($ elty)
246246 end
247247
248- CUSOLVER . with_workspace (dh. workspace_gpu, bufferSize) do buffer
249- return CUSOLVER .$ fname (
248+ cuSOLVER . with_workspace (dh. workspace_gpu, bufferSize) do buffer
249+ return cuSOLVER .$ fname (
250250 dh, jobz, econ, m, n, A, lda, S, Ũ, ldu, Ṽ, ldv,
251251 buffer, sizeof (buffer) ÷ sizeof ($ elty), dh. info,
252252 params[]
253253 )
254254 end
255255
256256 info = @allowscalar dh. info[1 ]
257- CUSOLVER . chkargsok (BlasInt (info))
257+ cuSOLVER . chkargsok (BlasInt (info))
258258
259- CUSOLVER . cusolverDnDestroyGesvdjInfo (params[])
259+ cuSOLVER . cusolverDnDestroyGesvdjInfo (params[])
260260
261261 if jobz == ' V'
262262 adjoint! (Vᴴ, Ṽ)
@@ -292,25 +292,25 @@ function gesvdr!(
292292 lda = max (1 , stride (A, 2 ))
293293 ldu = max (1 , stride (Ũ, 2 ))
294294 ldv = max (1 , stride (Ṽ, 2 ))
295- params = CUSOLVER . CuSolverParameters ()
296- dh = CUSOLVER . dense_handle ()
295+ params = cuSOLVER . CuSolverParameters ()
296+ dh = cuSOLVER . dense_handle ()
297297
298298 function bufferSize ()
299299 out_cpu = Ref {Csize_t} (0 )
300300 out_gpu = Ref {Csize_t} (0 )
301- CUSOLVER . cusolverDnXgesvdr_bufferSize (
301+ cuSOLVER . cusolverDnXgesvdr_bufferSize (
302302 dh, params, jobu, jobv, m, n, k, p, niters,
303303 T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
304304 T, out_gpu, out_cpu
305305 )
306306
307307 return out_gpu[], out_cpu[]
308308 end
309- CUSOLVER . with_workspaces (
309+ cuSOLVER . with_workspaces (
310310 dh. workspace_gpu, dh. workspace_cpu,
311311 bufferSize ()...
312312 ) do buffer_gpu, buffer_cpu
313- return CUSOLVER . cusolverDnXgesvdr (
313+ return cuSOLVER . cusolverDnXgesvdr (
314314 dh, params, jobu, jobv, m, n, k, p, niters,
315315 T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv,
316316 T, buffer_gpu, sizeof (buffer_gpu),
@@ -320,7 +320,7 @@ function gesvdr!(
320320 end
321321
322322 flag = @allowscalar dh. info[1 ]
323- CUSOLVER . chklapackerror (BlasInt (flag))
323+ cuSOLVER . chklapackerror (BlasInt (flag))
324324 if Ũ != = U && length (U) > 0
325325 U .= view (Ũ, 1 : m, 1 : size (U, 2 ))
326326 end
@@ -361,8 +361,8 @@ for (celty, elty) in ((:ComplexF32, :Float32), (:ComplexF64, :Float64), (:Comple
361361 VL = similar (A, n, 0 )
362362 lda = max (1 , stride (A, 2 ))
363363 ldvl = max (1 , stride (VL, 2 ))
364- params = CUSOLVER . CuSolverParameters ()
365- dh = CUSOLVER . dense_handle ()
364+ params = cuSOLVER . CuSolverParameters ()
365+ dh = cuSOLVER . dense_handle ()
366366
367367 if $ elty <: Real
368368 D2 = reinterpret ($ elty, D)
@@ -377,22 +377,22 @@ for (celty, elty) in ((:ComplexF32, :Float32), (:ComplexF64, :Float64), (:Comple
377377 function bufferSize ()
378378 out_cpu = Ref {Csize_t} (0 )
379379 out_gpu = Ref {Csize_t} (0 )
380- CUSOLVER . cusolverDnXgeev_bufferSize (
380+ cuSOLVER . cusolverDnXgeev_bufferSize (
381381 dh, params, jobvl, jobvr, n, $ elty, A,
382382 lda, $ elty, D2, $ elty, VL, ldvl, $ elty, VR, ldvr,
383383 $ elty, out_gpu, out_cpu
384384 )
385385 return out_gpu[], out_cpu[]
386386 end
387387 CUDA. with_workspaces (dh. workspace_gpu, dh. workspace_cpu, bufferSize ()... ) do buffer_gpu, buffer_cpu
388- CUSOLVER . cusolverDnXgeev (
388+ cuSOLVER . cusolverDnXgeev (
389389 dh, params, jobvl, jobvr, n, $ elty, A, lda, $ elty,
390390 D2, $ elty, VL, ldvl, $ elty, VR, ldvr, $ elty, buffer_gpu,
391391 sizeof (buffer_gpu), buffer_cpu, sizeof (buffer_cpu), dh. info
392392 )
393393 end
394394 flag = @allowscalar dh. info[1 ]
395- CUSOLVER . chkargsok (BlasInt (flag))
395+ cuSOLVER . chkargsok (BlasInt (flag))
396396 if eltype (A) <: Real
397397 work = CuVector {$elty} (undef, n)
398398 DR = view (D2, 1 : n)
@@ -711,10 +711,10 @@ end
711711# end
712712
713713for (bname, fname, elty, relty) in (
714- (:(CUSOLVER . cusolverDnSsyevj_bufferSize), :(CUSOLVER . cusolverDnSsyevj), :Float32 , :Float32 ),
715- (:(CUSOLVER . cusolverDnDsyevj_bufferSize), :(CUSOLVER . cusolverDnDsyevj), :Float64 , :Float64 ),
716- (:(CUSOLVER . cusolverDnCheevj_bufferSize), :(CUSOLVER . cusolverDnCheevj), :ComplexF32 , :Float32 ),
717- (:(CUSOLVER . cusolverDnZheevj_bufferSize), :(CUSOLVER . cusolverDnZheevj), :ComplexF64 , :Float64 ),
714+ (:(cuSOLVER . cusolverDnSsyevj_bufferSize), :(cuSOLVER . cusolverDnSsyevj), :Float32 , :Float32 ),
715+ (:(cuSOLVER . cusolverDnDsyevj_bufferSize), :(cuSOLVER . cusolverDnDsyevj), :Float64 , :Float64 ),
716+ (:(cuSOLVER . cusolverDnCheevj_bufferSize), :(cuSOLVER . cusolverDnCheevj), :ComplexF32 , :Float32 ),
717+ (:(cuSOLVER . cusolverDnZheevj_bufferSize), :(cuSOLVER . cusolverDnZheevj), :ComplexF64 , :Float64 ),
718718 )
719719 @eval begin
720720 function heevj! (
@@ -728,18 +728,18 @@ for (bname, fname, elty, relty) in (
728728 chkuplo (uplo)
729729 n = checksquare (A)
730730 lda = max (1 , stride (A, 2 ))
731- dh = CUSOLVER . dense_handle ()
731+ dh = cuSOLVER . dense_handle ()
732732 length (W) == n || throw (DimensionMismatch (" size mismatch between A and W" ))
733733 if length (V) == 0
734734 jobz = ' N'
735735 else
736736 size (V) == (n, n) || throw (DimensionMismatch (" size mismatch between A and V" ))
737737 jobz = ' V'
738738 end
739- params = Ref {CUSOLVER .syevjInfo_t} (C_NULL )
740- CUSOLVER . cusolverDnCreateSyevjInfo (params)
741- CUSOLVER . cusolverDnXsyevjSetTolerance (params[], tol)
742- CUSOLVER . cusolverDnXsyevjSetMaxSweeps (params[], max_sweeps)
739+ params = Ref {cuSOLVER .syevjInfo_t} (C_NULL )
740+ cuSOLVER . cusolverDnCreateSyevjInfo (params)
741+ cuSOLVER . cusolverDnXsyevjSetTolerance (params[], tol)
742+ cuSOLVER . cusolverDnXsyevjSetMaxSweeps (params[], max_sweeps)
743743 function bufferSize ()
744744 out = Ref {Cint} (0 )
745745 $ bname (dh, jobz, uplo, n, A, lda, W, out, params[])
@@ -772,7 +772,7 @@ function heevd!(
772772 chkuplo (uplo)
773773 n = checksquare (A)
774774 lda = max (1 , stride (A, 2 ))
775- dh = CUSOLVER . dense_handle ()
775+ dh = cuSOLVER . dense_handle ()
776776 length (W) == n || throw (DimensionMismatch (" size mismatch between A and W" ))
777777 if length (V) == 0
778778 jobz = ' N'
@@ -781,19 +781,19 @@ function heevd!(
781781 jobz = ' V'
782782 end
783783
784- params = CUSOLVER . CuSolverParameters ()
784+ params = cuSOLVER . CuSolverParameters ()
785785 function bufferSize ()
786786 out_cpu = Ref {Csize_t} (0 )
787787 out_gpu = Ref {Csize_t} (0 )
788- CUSOLVER . cusolverDnXsyevd_bufferSize (dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
788+ cuSOLVER . cusolverDnXsyevd_bufferSize (dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu)
789789 return out_gpu[], out_cpu[]
790790 end
791791
792- CUSOLVER . with_workspaces (
792+ cuSOLVER . with_workspaces (
793793 dh. workspace_gpu, dh. workspace_cpu,
794794 bufferSize ()...
795795 ) do buffer_gpu, buffer_cpu
796- return CUSOLVER . cusolverDnXsyevd (
796+ return cuSOLVER . cusolverDnXsyevd (
797797 dh, params, jobz, uplo, n, T, A, lda, Tr, W,
798798 T, buffer_gpu, sizeof (buffer_gpu), buffer_cpu,
799799 sizeof (buffer_cpu), dh. info
0 commit comments