|
1 | 1 | module YArocSOLVER |
2 | 2 |
|
3 | 3 | using LinearAlgebra |
4 | | -using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing |
| 4 | +using LinearAlgebra: BlasInt, BlasReal, BlasFloat, checksquare, chkstride1, require_one_based_indexing |
5 | 5 | using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo |
6 | 6 |
|
7 | 7 | using AMDGPU |
@@ -475,42 +475,146 @@ end |
475 | 475 | # return X, info |
476 | 476 | # end |
477 | 477 |
|
478 | | -# for (jname, bname, fname, elty, relty) in |
479 | | -# ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32), |
480 | | -# (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64), |
481 | | -# (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32), |
482 | | -# (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64)) |
483 | | -# @eval begin |
484 | | -# function $jname(jobz::Char, |
485 | | -# uplo::Char, |
486 | | -# A::StridedROCMatrix{$elty}) |
487 | | -# chkuplo(uplo) |
488 | | -# n = checksquare(A) |
489 | | -# lda = max(1, stride(A, 2)) |
490 | | -# W = CuArray{$relty}(undef, n) |
491 | | -# dh = rocBLAS.handle() |
| 478 | +for (heevd, heev, heevx, heevj, elty, relty) in |
| 479 | + ((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32), |
| 480 | + (:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64), |
| 481 | + (:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32), |
| 482 | + (:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64)) |
| 483 | + @eval begin |
| 484 | + function heevd!(A::StridedROCMatrix{$elty}, |
| 485 | + W::StridedROCVector{$relty}, |
| 486 | + V::StridedROCMatrix{$elty}; |
| 487 | + uplo::Char='U') |
| 488 | + chkuplo(uplo) |
| 489 | + n = checksquare(A) |
| 490 | + lda = max(1, stride(A, 2)) |
| 491 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 492 | + if length(V) == 0 |
| 493 | + jobz = rocSOLVER.rocblas_evect_none |
| 494 | + else |
| 495 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 496 | + jobz = rocSOLVER.rocblas_evect_original |
| 497 | + end |
| 498 | + dh = rocBLAS.handle() |
| 499 | + work = ROCVector{$relty}(undef, n) |
| 500 | + dev_info = ROCVector{Cint}(undef, 1) |
| 501 | + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) |
| 502 | + $heevd(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info) |
492 | 503 |
|
493 | | -# function bufferSize() |
494 | | -# out = Ref{Cint}(0) |
495 | | -# $bname(dh, jobz, uplo, n, A, lda, W, out) |
496 | | -# return out[] * sizeof($elty) |
497 | | -# end |
| 504 | + info = @allowscalar dev_info[1] |
| 505 | + chkargsok(BlasInt(info)) |
498 | 506 |
|
499 | | -# with_workspace(dh.workspace_gpu, bufferSize) do buffer |
500 | | -# return $fname(dh, jobz, uplo, n, A, lda, W, |
501 | | -# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) |
502 | | -# end |
| 507 | + if jobz == rocSOLVER.rocblas_evect_original && V !== A |
| 508 | + copy!(V, A) |
| 509 | + end |
| 510 | + return W, V |
| 511 | + end |
| 512 | + function heev!(A::StridedROCMatrix{$elty}, |
| 513 | + W::StridedROCVector{$relty}, |
| 514 | + V::StridedROCMatrix{$elty}; |
| 515 | + uplo::Char='U') |
| 516 | + chkuplo(uplo) |
| 517 | + n = checksquare(A) |
| 518 | + lda = max(1, stride(A, 2)) |
| 519 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 520 | + if length(V) == 0 |
| 521 | + jobz = rocSOLVER.rocblas_evect_none |
| 522 | + else |
| 523 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 524 | + jobz = rocSOLVER.rocblas_evect_original |
| 525 | + end |
| 526 | + dh = rocBLAS.handle() |
| 527 | + work = ROCVector{$relty}(undef, n) |
| 528 | + dev_info = ROCVector{Cint}(undef, 1) |
| 529 | + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) |
| 530 | + $heev(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info) |
503 | 531 |
|
504 | | -# info = @allowscalar dh.info[1] |
505 | | -# chkargsok(BlasInt(info)) |
| 532 | + info = @allowscalar dev_info[1] |
| 533 | + chkargsok(BlasInt(info)) |
506 | 534 |
|
507 | | -# if jobz == 'N' |
508 | | -# return W |
509 | | -# elseif jobz == 'V' |
510 | | -# return W, A |
511 | | -# end |
512 | | -# end |
513 | | -# end |
514 | | -# end |
| 535 | + if jobz == rocSOLVER.rocblas_evect_original && V !== A |
| 536 | + copy!(V, A) |
| 537 | + end |
| 538 | + return W, V |
| 539 | + end |
| 540 | + function heevx!(A::StridedROCMatrix{$elty}, |
| 541 | + W::StridedROCVector{$relty}, |
| 542 | + V::StridedROCMatrix{$elty}; |
| 543 | + uplo::Char='U', |
| 544 | + kwargs...) |
| 545 | + chkuplo(uplo) |
| 546 | + n = checksquare(A) |
| 547 | + lda = max(1, stride(A, 2)) |
| 548 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 549 | + if haskey(kwargs, :irange) |
| 550 | + il = first(kwargs[:irange]) |
| 551 | + iu = last(kwargs[:irange]) |
| 552 | + vl = vu = zero($relty) |
| 553 | + range = rocSOLVER.rocblas_erange_index |
| 554 | + elseif haskey(kwargs, :vl) || haskey(kwargs, :vu) |
| 555 | + vl = convert($relty, get(kwargs, :vl, -Inf)) |
| 556 | + vu = convert($relty, get(kwargs, :vu, +Inf)) |
| 557 | + il = iu = 0 |
| 558 | + range = rocSOLVER.rocblas_erange_value |
| 559 | + else |
| 560 | + il = iu = 0 |
| 561 | + vl = vu = zero($relty) |
| 562 | + range = rocSOLVER.rocblas_erange_all |
| 563 | + end |
| 564 | + if length(V) == 0 |
| 565 | + jobz = rocSOLVER.rocblas_evect_none |
| 566 | + else |
| 567 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 568 | + jobz = rocSOLVER.rocblas_evect_original |
| 569 | + end |
| 570 | + dh = rocBLAS.handle() |
| 571 | + abstol = -one($relty) |
| 572 | + nev = ROCVector{Cint}(undef, 1) |
| 573 | + ldv = max(1, stride(V, 2)) |
| 574 | + ifail = ROCVector{Cint}(undef, n) |
| 575 | + dev_info = ROCVector{Cint}(undef, 1) |
| 576 | + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) |
| 577 | + $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info) |
| 578 | + |
| 579 | + info = @allowscalar dev_info[1] |
| 580 | + chkargsok(BlasInt(info)) |
| 581 | + m = @allowscalar nev[1] |
| 582 | + return W, V, m |
| 583 | + end |
| 584 | + function heevj!(A::StridedROCMatrix{$elty}, |
| 585 | + W::StridedROCVector{$relty}, |
| 586 | + V::StridedROCMatrix{$elty}; |
| 587 | + uplo::Char='U', |
| 588 | + tol::$relty=eps($relty), |
| 589 | + max_sweeps::Int=100, |
| 590 | + sort::Char='N') |
| 591 | + chkuplo(uplo) |
| 592 | + n = checksquare(A) |
| 593 | + lda = max(1, stride(A, 2)) |
| 594 | + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) |
| 595 | + if length(V) == 0 |
| 596 | + jobz = rocSOLVER.rocblas_evect_none |
| 597 | + else |
| 598 | + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) |
| 599 | + jobz = rocSOLVER.rocblas_evect_original |
| 600 | + end |
| 601 | + dh = rocBLAS.handle() |
| 602 | + dev_info = ROCVector{Cint}(undef, 1) |
| 603 | + residual = ROCVector{$relty}(undef, 1) |
| 604 | + n_sweeps = ROCVector{Cint}(undef, 1) |
| 605 | + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) |
| 606 | + roc_sort = sort == 'N' ? rocSOLVER.rocblas_esort_none : rocSOLVER.rocblas_esort_ascending |
| 607 | + $heevj(dh, roc_sort, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) |
| 608 | + |
| 609 | + info = @allowscalar dev_info[1] |
| 610 | + chkargsok(BlasInt(info)) |
| 611 | + |
| 612 | + if jobz == rocSOLVER.rocblas_evect_original && V !== A |
| 613 | + copy!(V, A) |
| 614 | + end |
| 615 | + return W, V |
| 616 | + end |
| 617 | + end |
| 618 | +end |
515 | 619 |
|
516 | 620 | end |
0 commit comments