@@ -16,17 +16,19 @@ const ungqr! = orgqr!
1616
1717# Wrapper for SVD via QR Iteration
1818for (fname, elty, relty) in
19- ((:rocsolver_sgesvd , :Float32 , :Float32 ),
20- (:rocsolver_dgesvd , :Float64 , :Float64 ),
21- (:rocsolver_cgesvd , :ComplexF32 , :Float32 ),
22- (:rocsolver_zgesvd , :ComplexF64 , :Float64 ))
19+ (
20+ (:rocsolver_sgesvd , :Float32 , :Float32 ),
21+ (:rocsolver_dgesvd , :Float64 , :Float64 ),
22+ (:rocsolver_cgesvd , :ComplexF32 , :Float32 ),
23+ (:rocsolver_zgesvd , :ComplexF64 , :Float64 ),
24+ )
2325 @eval begin
24- # ! format: off
25- function gesvd! ( A:: StridedROCMatrix{$elty} ,
26- S:: StridedROCVector{$relty} = similar (A, $ relty, min (size (A)... )),
27- U:: StridedROCMatrix{$elty} = similar (A, $ elty, size (A, 1 ), min (size (A)... )),
28- Vᴴ:: StridedROCMatrix{$elty} = similar (A, $ elty, min (size (A)... ), size (A, 2 ) ))
29- # ! format: on
26+ function gesvd! (
27+ A:: StridedROCMatrix{$elty} ,
28+ S:: StridedROCVector{$relty} = similar (A, $ relty, min (size (A)... )),
29+ U:: StridedROCMatrix{$elty} = similar (A, $ elty, size (A, 1 ), min (size (A)... )),
30+ Vᴴ:: StridedROCMatrix{$elty} = similar (A, $ elty, min (size (A)... ), size (A, 2 ))
31+ )
3032 chkstride1 (A, U, Vᴴ, S)
3133 m, n = size (A)
3234 (m < n) && throw (ArgumentError (" rocSOLVER's gesvd requires m ≥ n" ))
@@ -72,13 +74,15 @@ for (fname, elty, relty) in
7274 ldu = max (1 , stride (U, 2 ))
7375 ldv = max (1 , stride (Vᴴ, 2 ))
7476
75- rwork = ROCArray {$relty} (undef, minmn - 1 )
76- dh = rocBLAS. handle ()
77+ rwork = ROCArray {$relty} (undef, minmn - 1 )
78+ dh = rocBLAS. handle ()
7779 dev_info = ROCVector {Cint} (undef, 1 )
78- rocSOLVER.$ fname (dh, jobu, jobvt, m, n,
79- A, lda, S, U, ldu, Vᴴ, ldv,
80- rwork, convert (rocSOLVER. rocblas_workmode, ' I' ),
81- dev_info)
80+ rocSOLVER.$ fname (
81+ dh, jobu, jobvt, m, n,
82+ A, lda, S, U, ldu, Vᴴ, ldv,
83+ rwork, convert (rocSOLVER. rocblas_workmode, ' I' ),
84+ dev_info
85+ )
8286 AMDGPU. unsafe_free! (rwork)
8387
8488 info = @allowscalar dev_info[1 ]
9195
9296# Wrapper for SVD via Jacobi
9397for (fname, elty, relty) in
94- ((:rocsolver_sgesvdj , :Float32 , :Float32 ),
95- (:rocsolver_dgesvdj , :Float64 , :Float64 ),
96- (:rocsolver_cgesvdj , :ComplexF32 , :Float32 ),
97- (:rocsolver_zgesvdj , :ComplexF64 , :Float64 ))
98+ (
99+ (:rocsolver_sgesvdj , :Float32 , :Float32 ),
100+ (:rocsolver_dgesvdj , :Float64 , :Float64 ),
101+ (:rocsolver_cgesvdj , :ComplexF32 , :Float32 ),
102+ (:rocsolver_zgesvdj , :ComplexF64 , :Float64 ),
103+ )
98104 @eval begin
99- # ! format: off
100- function gesvdj! (A:: StridedROCMatrix{$elty} ,
101- S:: StridedROCVector{$relty} = similar (A, $ relty, min (size (A)... )),
102- U:: StridedROCMatrix{$elty} = similar (A, $ elty, size (A, 1 ), min (size (A)... )),
103- Vᴴ:: StridedROCMatrix{$elty} = similar (A, $ elty, min (size (A)... ), size (A, 2 ));
104- tol:: $relty = eps ($ relty),
105- max_sweeps:: Int = 100 ,
106- )
107- # ! format: on
105+ function gesvdj! (
106+ A:: StridedROCMatrix{$elty} ,
107+ S:: StridedROCVector{$relty} = similar (A, $ relty, min (size (A)... )),
108+ U:: StridedROCMatrix{$elty} = similar (A, $ elty, size (A, 1 ), min (size (A)... )),
109+ Vᴴ:: StridedROCMatrix{$elty} = similar (A, $ elty, min (size (A)... ), size (A, 2 ));
110+ tol:: $relty = eps ($ relty),
111+ max_sweeps:: Int = 100 ,
112+ )
108113 chkstride1 (A, U, Vᴴ, S)
109114 m, n = size (A)
110115 minmn = min (m, n)
@@ -149,21 +154,22 @@ for (fname, elty, relty) in
149154 lda = max (1 , stride (A, 2 ))
150155 ldu = max (1 , stride (U, 2 ))
151156 ldv = max (1 , stride (Vᴴ, 2 ))
152- dev_info = ROCVector {Cint} (undef, 1 )
157+ dev_info = ROCVector {Cint} (undef, 1 )
153158 dev_residual = ROCVector {$relty} (undef, 1 )
154159 dev_n_sweeps = ROCVector {Cint} (undef, 1 )
155160
156161 dh = rocBLAS. handle ()
157- rocSOLVER.$ fname (dh, jobu, jobvt, m, n, A, lda, tol,
158- dev_residual, max_sweeps, dev_n_sweeps,
159- S, U, ldu, Vᴴ, ldv, dev_info,
160- )
162+ rocSOLVER.$ fname (
163+ dh, jobu, jobvt, m, n, A, lda, tol,
164+ dev_residual, max_sweeps, dev_n_sweeps,
165+ S, U, ldu, Vᴴ, ldv, dev_info,
166+ )
161167
162168 info = @allowscalar dev_info[1 ]
163169 rocSOLVER. chkargsok (BlasInt (info))
164170
165- AMDGPU. unsafe_free! (dev_residual)
166- AMDGPU. unsafe_free! (dev_n_sweeps)
171+ AMDGPU. unsafe_free! (dev_residual)
172+ AMDGPU. unsafe_free! (dev_n_sweeps)
167173 return (S, U, Vᴴ)
168174 end
169175 end
@@ -476,15 +482,19 @@ end
476482# end
477483
478484for (heevd, heev, heevx, heevj, elty, relty) in
479- ((:(rocSOLVER. rocsolver_ssyevd), :(rocSOLVER. rocsolver_ssyev), :(rocSOLVER. rocsolver_ssyevx), :(rocSOLVER. rocsolver_ssyevj), :Float32 , :Float32 ),
480- (:(rocSOLVER. rocsolver_dsyevd), :(rocSOLVER. rocsolver_dsyev), :(rocSOLVER. rocsolver_dsyevx), :(rocSOLVER. rocsolver_dsyevj), :Float64 , :Float64 ),
481- (:(rocSOLVER. rocsolver_cheevd), :(rocSOLVER. rocsolver_cheev), :(rocSOLVER. rocsolver_cheevx), :(rocSOLVER. rocsolver_cheevj), :ComplexF32 , :Float32 ),
482- (:(rocSOLVER. rocsolver_zheevd), :(rocSOLVER. rocsolver_zheev), :(rocSOLVER. rocsolver_zheevx), :(rocSOLVER. rocsolver_zheevj), :ComplexF64 , :Float64 ))
485+ (
486+ (:(rocSOLVER. rocsolver_ssyevd), :(rocSOLVER. rocsolver_ssyev), :(rocSOLVER. rocsolver_ssyevx), :(rocSOLVER. rocsolver_ssyevj), :Float32 , :Float32 ),
487+ (:(rocSOLVER. rocsolver_dsyevd), :(rocSOLVER. rocsolver_dsyev), :(rocSOLVER. rocsolver_dsyevx), :(rocSOLVER. rocsolver_dsyevj), :Float64 , :Float64 ),
488+ (:(rocSOLVER. rocsolver_cheevd), :(rocSOLVER. rocsolver_cheev), :(rocSOLVER. rocsolver_cheevx), :(rocSOLVER. rocsolver_cheevj), :ComplexF32 , :Float32 ),
489+ (:(rocSOLVER. rocsolver_zheevd), :(rocSOLVER. rocsolver_zheev), :(rocSOLVER. rocsolver_zheevx), :(rocSOLVER. rocsolver_zheevj), :ComplexF64 , :Float64 ),
490+ )
483491 @eval begin
484- function heevd! (A:: StridedROCMatrix{$elty} ,
485- W:: StridedROCVector{$relty} ,
486- V:: StridedROCMatrix{$elty} ;
487- uplo:: Char = ' U' )
492+ function heevd! (
493+ A:: StridedROCMatrix{$elty} ,
494+ W:: StridedROCVector{$relty} ,
495+ V:: StridedROCMatrix{$elty} ;
496+ uplo:: Char = ' U'
497+ )
488498 chkuplo (uplo)
489499 n = checksquare (A)
490500 lda = max (1 , stride (A, 2 ))
@@ -509,10 +519,12 @@ for (heevd, heev, heevx, heevj, elty, relty) in
509519 end
510520 return W, V
511521 end
512- function heev! (A:: StridedROCMatrix{$elty} ,
513- W:: StridedROCVector{$relty} ,
514- V:: StridedROCMatrix{$elty} ;
515- uplo:: Char = ' U' )
522+ function heev! (
523+ A:: StridedROCMatrix{$elty} ,
524+ W:: StridedROCVector{$relty} ,
525+ V:: StridedROCMatrix{$elty} ;
526+ uplo:: Char = ' U'
527+ )
516528 chkuplo (uplo)
517529 n = checksquare (A)
518530 lda = max (1 , stride (A, 2 ))
@@ -537,11 +549,13 @@ for (heevd, heev, heevx, heevj, elty, relty) in
537549 end
538550 return W, V
539551 end
540- function heevx! (A:: StridedROCMatrix{$elty} ,
541- W:: StridedROCVector{$relty} ,
542- V:: StridedROCMatrix{$elty} ;
543- uplo:: Char = ' U' ,
544- kwargs... )
552+ function heevx! (
553+ A:: StridedROCMatrix{$elty} ,
554+ W:: StridedROCVector{$relty} ,
555+ V:: StridedROCMatrix{$elty} ;
556+ uplo:: Char = ' U' ,
557+ kwargs...
558+ )
545559 chkuplo (uplo)
546560 n = checksquare (A)
547561 lda = max (1 , stride (A, 2 ))
@@ -567,27 +581,29 @@ for (heevd, heev, heevx, heevj, elty, relty) in
567581 size (V) == (n, n) || throw (DimensionMismatch (" size mismatch between A and V" ))
568582 jobz = rocSOLVER. rocblas_evect_original
569583 end
570- dh = rocBLAS. handle ()
571- abstol = - one ($ relty)
572- nev = ROCVector {Cint} (undef, 1 )
573- ldv = max (1 , stride (V, 2 ))
574- ifail = ROCVector {Cint} (undef, n)
584+ dh = rocBLAS. handle ()
585+ abstol = - one ($ relty)
586+ nev = ROCVector {Cint} (undef, 1 )
587+ ldv = max (1 , stride (V, 2 ))
588+ ifail = ROCVector {Cint} (undef, n)
575589 dev_info = ROCVector {Cint} (undef, 1 )
576590 roc_uplo = convert (rocSOLVER. rocblas_fill, uplo)
577591 $ heevx (dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info)
578592
579593 info = @allowscalar dev_info[1 ]
580594 chkargsok (BlasInt (info))
581- m = @allowscalar nev[1 ]
595+ m = @allowscalar nev[1 ]
582596 return W, V, m
583597 end
584- function heevj! (A:: StridedROCMatrix{$elty} ,
585- W:: StridedROCVector{$relty} ,
586- V:: StridedROCMatrix{$elty} ;
587- uplo:: Char = ' U' ,
588- tol:: $relty = eps ($ relty),
589- max_sweeps:: Int = 100 ,
590- sort:: Char = ' N' )
598+ function heevj! (
599+ A:: StridedROCMatrix{$elty} ,
600+ W:: StridedROCVector{$relty} ,
601+ V:: StridedROCMatrix{$elty} ;
602+ uplo:: Char = ' U' ,
603+ tol:: $relty = eps ($ relty),
604+ max_sweeps:: Int = 100 ,
605+ sort:: Char = ' N'
606+ )
591607 chkuplo (uplo)
592608 n = checksquare (A)
593609 lda = max (1 , stride (A, 2 ))
0 commit comments