|
1 | 1 | module YACUSOLVER |
2 | 2 |
|
3 | 3 | using LinearAlgebra |
4 | | -using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing |
| 4 | +using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing |
5 | 5 | using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo |
6 | 6 |
|
7 | 7 | using CUDA |
@@ -679,43 +679,93 @@ end |
679 | 679 | # return X, info |
680 | 680 | # end |
681 | 681 |
|
682 | | -# for (jname, bname, fname, elty, relty) in |
683 | | -# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32), |
684 | | -# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64), |
685 | | -# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32), |
686 | | -# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64)) |
687 | | -# @eval begin |
688 | | -# function $jname(jobz::Char, |
689 | | -# uplo::Char, |
690 | | -# A::StridedCuMatrix{$elty}) |
691 | | -# chkuplo(uplo) |
692 | | -# n = checksquare(A) |
693 | | -# lda = max(1, stride(A, 2)) |
694 | | -# W = CuArray{$relty}(undef, n) |
695 | | -# dh = dense_handle() |
# heevj!(A, W, V; uplo='U', tol=eps(relty), max_sweeps=100)
#
# Run the cuSOLVER syevj/heevj routine selected by the table below on the square
# matrix `A`, passing `W` as the eigenvalue output buffer.
#
# * `A`   - square device matrix handed to the solver (used as its in/out buffer).
# * `W`   - device vector of length `n = size(A, 1)`.
# * `V`   - either empty (solver is called with jobz = 'N') or an `n x n` device
#           matrix (jobz = 'V'); in the latter case the contents of `A` after the
#           solver call are copied into `V` unless `V === A`.
# * `uplo`, `tol`, `max_sweeps` - forwarded to cuSOLVER (triangle selector and the
#           tolerance / sweep limit stored in the syevj parameter object).
#
# Throws `DimensionMismatch` when `W` or a non-empty `V` does not match `A`.
# Returns the tuple `(W, V)`.
for (bname, fname, elty, relty) in
    ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj),
      :Float32, :Float32),
     (:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj),
      :Float64, :Float64),
     (:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj),
      :ComplexF32, :Float32),
     (:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj),
      :ComplexF64, :Float64))
    @eval begin
        function heevj!(A::StridedCuMatrix{$elty},
                        W::StridedCuVector{$relty},
                        V::StridedCuMatrix{$elty};
                        uplo::Char='U',
                        tol::$relty=eps($relty),
                        max_sweeps::Int=100)
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            dh = CUSOLVER.dense_handle()
            length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))

            # Single assignment: `jobz` is captured by the closures below, and a
            # variable assigned in several branches would be boxed.
            jobz = isempty(V) ? 'N' : 'V'
            if jobz == 'V' && size(V) != (n, n)
                throw(DimensionMismatch("size mismatch between A and V"))
            end

            params = Ref{CUSOLVER.syevjInfo_t}(C_NULL)
            CUSOLVER.cusolverDnCreateSyevjInfo(params)
            try
                CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol)
                CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)

                # Workspace size is reported in elements; convert to bytes.
                function bufferSize()
                    out = Ref{Cint}(0)
                    $bname(dh, jobz, uplo, n, A, lda, W, out, params[])
                    return out[] * sizeof($elty)
                end

                CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer
                    $fname(dh, jobz, uplo, n, A, lda, W, buffer,
                           sizeof(buffer) ÷ sizeof($elty), dh.info, params[])
                end

                info = @allowscalar dh.info[1]
                chkargsok(BlasInt(info))
            finally
                # The parameter object is a native cuSOLVER resource; release it
                # on every exit path, including errors raised above.
                CUSOLVER.cusolverDnDestroySyevjInfo(params[])
            end

            if jobz == 'V' && V !== A
                copy!(V, A)
            end
            return W, V
        end
    end
end
707 | 729 |
|
708 | | -# info = @allowscalar dh.info[1] |
709 | | -# chkargsok(BlasInt(info)) |
"""
    heevd!(A, W, V; uplo='U')

Call `cusolverDnXsyevd` on the square device matrix `A`, passing `W` as the
eigenvalue output buffer.

# Arguments
- `A`: square device matrix handed to the solver as its in/out buffer.
- `W`: device vector of length `n = size(A, 1)` whose element type must be `real(T)`.
- `V`: either empty (the solver is called with `jobz = 'N'`) or an `n x n` device
  matrix (`jobz = 'V'`); in the latter case the contents of `A` after the solver
  call are copied into `V` unless `V === A`.

# Keywords
- `uplo`: triangle selector forwarded to cuSOLVER after `chkuplo` validation.

# Throws
- `ArgumentError` if the element type of `W` is not `real(T)`.
- `DimensionMismatch` if `W` or a non-empty `V` does not match the size of `A`.

# Returns
The tuple `(W, V)`.
"""
function heevd!(A::StridedCuMatrix{T},
                W::StridedCuVector{Tr},
                V::StridedCuMatrix{T};
                uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal}
    # T and Tr are independent type parameters, so a mismatched pair such as
    # ComplexF64 / Float32 would otherwise reach the native call unchecked.
    Tr === real(T) || throw(ArgumentError(
        "eigenvalue vector W must have element type $(real(T)), got $Tr"))
    chkuplo(uplo)
    n = checksquare(A)
    lda = max(1, stride(A, 2))
    dh = CUSOLVER.dense_handle()
    length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))

    # Single assignment: `jobz` is captured by the closures below, and a variable
    # assigned in several branches would be boxed.
    jobz = isempty(V) ? 'N' : 'V'
    if jobz == 'V' && size(V) != (n, n)
        throw(DimensionMismatch("size mismatch between A and V"))
    end

    params = CUSOLVER.CuSolverParameters()

    # Query the device and host workspace sizes (in bytes) required by the solver.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda,
                                             Tr, W, T, out_gpu, out_cpu)
        return out_gpu[], out_cpu[]
    end

    CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
                             bufferSize()...) do buffer_gpu, buffer_cpu
        return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W,
                                         T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
                                         sizeof(buffer_cpu), dh.info)
    end

    info = @allowscalar dh.info[1]
    chkargsok(BlasInt(info))

    if jobz == 'V' && V !== A
        copy!(V, A)
    end
    return W, V
end
719 | 769 |
|
720 | 770 | # device code is unreachable by coverage right now |
721 | 771 | # COV_EXCL_START |
|
0 commit comments