|
1 | 1 | module YACUSOLVER |
2 | 2 |
|
3 | 3 | using LinearAlgebra |
4 | | -using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing |
| 4 | +using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing |
5 | 5 | using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo |
6 | 6 |
|
7 | 7 | using CUDA |
@@ -612,42 +612,92 @@ end |
612 | 612 | # return X, info |
613 | 613 | # end |
614 | 614 |
|
615 | | -# for (jname, bname, fname, elty, relty) in |
616 | | -# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32), |
617 | | -# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64), |
618 | | -# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32), |
619 | | -# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64)) |
620 | | -# @eval begin |
621 | | -# function $jname(jobz::Char, |
622 | | -# uplo::Char, |
623 | | -# A::StridedCuMatrix{$elty}) |
624 | | -# chkuplo(uplo) |
625 | | -# n = checksquare(A) |
626 | | -# lda = max(1, stride(A, 2)) |
627 | | -# W = CuArray{$relty}(undef, n) |
628 | | -# dh = dense_handle() |
| 615 | +for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32), |
| 616 | + (:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64), |
| 617 | + (:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32), |
| 618 | + (:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64)) |
| 619 | + @eval begin |
| 620 | + function heevj!(A::StridedCuMatrix{$elty}, |
| 621 | + W::StridedCuVector{$relty}, |
| 622 | + V::StridedCuMatrix{$elty}; |
| 623 | + uplo::Char='U', |
| 624 | + tol::$relty=eps($relty), |
| 625 | + max_sweeps::Int=100 |
| 626 | + ) |
| 627 | + chkuplo(uplo) |
| 628 | + n = checksquare(A) |
| 629 | + lda = max(1, stride(A, 2)) |
| 630 | + dh = CUSOLVER.dense_handle() |
| 631 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 632 | + if length(V) == 0 |
| 633 | + jobz = 'N' |
| 634 | + else |
| 635 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 636 | + jobz = 'V' |
| 637 | + end |
| 638 | + params = Ref{CUSOLVER.syevjInfo_t}(C_NULL) |
| 639 | + CUSOLVER.cusolverDnCreateSyevjInfo(params) |
| 640 | + CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol) |
| 641 | + CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) |
| 642 | + function bufferSize() |
| 643 | + out = Ref{Cint}(0) |
| 644 | + $bname(dh, jobz, uplo, n, A, lda, W, out, params[]) |
| 645 | + return out[] * sizeof($elty) |
| 646 | + end |
| 647 | + CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer |
| 648 | + $fname(dh, jobz, uplo, n, A, lda, W, buffer, |
| 649 | + sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) |
| 650 | + end |
629 | 651 |
|
630 | | -# function bufferSize() |
631 | | -# out = Ref{Cint}(0) |
632 | | -# $bname(dh, jobz, uplo, n, A, lda, W, out) |
633 | | -# return out[] * sizeof($elty) |
634 | | -# end |
| 652 | + info = @allowscalar dh.info[1] |
| 653 | + chkargsok(BlasInt(info)) |
635 | 654 |
|
636 | | -# with_workspace(dh.workspace_gpu, bufferSize) do buffer |
637 | | -# return $fname(dh, jobz, uplo, n, A, lda, W, |
638 | | -# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) |
639 | | -# end |
| 655 | + if jobz == 'V' && V !== A |
| 656 | + copy!(V, A) |
| 657 | + end |
| 658 | + return W, V |
| 659 | + end |
| 660 | + end |
| 661 | +end |
640 | 662 |
|
641 | | -# info = @allowscalar dh.info[1] |
642 | | -# chkargsok(BlasInt(info)) |
| 663 | +function heevd!(A::StridedCuMatrix{T}, |
| 664 | + W::StridedCuVector{Tr}, |
| 665 | + V::StridedCuMatrix{T}; |
| 666 | + uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal} |
| 667 | + chkuplo(uplo) |
| 668 | + n = checksquare(A) |
| 669 | + lda = max(1, stride(A, 2)) |
| 670 | + dh = CUSOLVER.dense_handle() |
| 671 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 672 | + if length(V) == 0 |
| 673 | + jobz = 'N' |
| 674 | + else |
| 675 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 676 | + jobz = 'V' |
| 677 | + end |
643 | 678 |
|
644 | | -# if jobz == 'N' |
645 | | -# return W |
646 | | -# elseif jobz == 'V' |
647 | | -# return W, A |
648 | | -# end |
649 | | -# end |
650 | | -# end |
651 | | -# end |
| 679 | + params = CUSOLVER.CuSolverParameters() |
| 680 | + function bufferSize() |
| 681 | + out_cpu = Ref{Csize_t}(0) |
| 682 | + out_gpu = Ref{Csize_t}(0) |
| 683 | + CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu) |
| 684 | + return out_gpu[], out_cpu[] |
| 685 | + end |
| 686 | + |
| 687 | + CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu, |
| 688 | + bufferSize()...) do buffer_gpu, buffer_cpu |
| 689 | + return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W, |
| 690 | + T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, |
| 691 | + sizeof(buffer_cpu), dh.info) |
| 692 | + end |
| 693 | + |
| 694 | + info = @allowscalar dh.info[1] |
| 695 | + chkargsok(BlasInt(info)) |
| 696 | + |
| 697 | + if jobz == 'V' && V !== A |
| 698 | + copy!(V, A) |
| 699 | + end |
| 700 | + return W, V |
| 701 | +end |
652 | 702 |
|
653 | 703 | end |
0 commit comments