From 49cdee088d8985aaa77eca5527b612d224391eca Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 06:40:37 -0400 Subject: [PATCH 01/10] Support eigh for CUDA --- .../MatrixAlgebraKitCUDAExt.jl | 9 +- ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 118 +++++++++++++----- src/MatrixAlgebraKit.jl | 2 +- src/implementations/eigh.jl | 48 ++++++- src/implementations/svd.jl | 2 - src/interface/decompositions.jl | 38 ++++++ test/cuda/eigh.jl | 78 ++++++++++++ test/runtests.jl | 3 + 8 files changed, 255 insertions(+), 43 deletions(-) create mode 100644 test/cuda/eigh.jl diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index c42e935b3..1fafb269d 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR -using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm +using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! +import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd! using CUDA using LinearAlgebra using LinearAlgebra: BlasFloat @@ -26,6 +27,9 @@ end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} return CUSOLVER_Simple(; kwargs...) end +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} + return CUSOLVER_DivideAndConquer(; kwargs...) +end _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) @@ -37,4 +41,7 @@ _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ:: _gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...) _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) +_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) +_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevd!(A, Dd, V; kwargs...) + end diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 427eedb1a..a1a52620e 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -1,7 +1,7 @@ module YACUSOLVER using LinearAlgebra -using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing +using LinearAlgebra: BlasInt, BlasFloat, BlasReal, checksquare, chkstride1, require_one_based_indexing using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo using CUDA @@ -679,43 +679,93 @@ end # return X, info # end -# for (jname, bname, fname, elty, relty) in -# ((:syevd!, :cusolverDnSsyevd_bufferSize, :cusolverDnSsyevd, :Float32, :Float32), -# (:syevd!, :cusolverDnDsyevd_bufferSize, :cusolverDnDsyevd, :Float64, :Float64), -# (:heevd!, :cusolverDnCheevd_bufferSize, :cusolverDnCheevd, :ComplexF32, :Float32), -# (:heevd!, :cusolverDnZheevd_bufferSize, :cusolverDnZheevd, :ComplexF64, :Float64)) -# @eval begin -# function $jname(jobz::Char, -# uplo::Char, -# A::StridedCuMatrix{$elty}) -# chkuplo(uplo) -# n = checksquare(A) -# lda = max(1, stride(A, 2)) -# W = CuArray{$relty}(undef, n) -# dh = dense_handle() +for (bname, fname, elty, relty) in ((:(CUSOLVER.cusolverDnSsyevj_bufferSize), :(CUSOLVER.cusolverDnSsyevj), :Float32, :Float32), + (:(CUSOLVER.cusolverDnDsyevj_bufferSize), :(CUSOLVER.cusolverDnDsyevj), :Float64, :Float64), + (:(CUSOLVER.cusolverDnCheevj_bufferSize), :(CUSOLVER.cusolverDnCheevj), :ComplexF32, :Float32), + (:(CUSOLVER.cusolverDnZheevj_bufferSize), :(CUSOLVER.cusolverDnZheevj), :ComplexF64, :Float64)) + @eval begin + function heevj!(A::StridedCuMatrix{$elty}, + W::StridedCuVector{$relty}, + V::StridedCuMatrix{$elty}; + uplo::Char='U', + tol::$relty=eps($relty), + max_sweeps::Int=100 + ) + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + dh = CUSOLVER.dense_handle() + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'V' + end + params = Ref{CUSOLVER.syevjInfo_t}(C_NULL) + CUSOLVER.cusolverDnCreateSyevjInfo(params) + CUSOLVER.cusolverDnXsyevjSetTolerance(params[], tol) + CUSOLVER.cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) + function bufferSize() + out = Ref{Cint}(0) + $bname(dh, jobz, uplo, n, A, lda, W, out, params[]) + return out[] * sizeof($elty) + end + CUDA.with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, uplo, n, A, lda, W, buffer, + sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) + end -# function bufferSize() -# out = Ref{Cint}(0) -# $bname(dh, jobz, uplo, n, A, lda, W, out) -# return out[] * sizeof($elty) -# end + info = @allowscalar dh.info[1] + chkargsok(BlasInt(info)) -# with_workspace(dh.workspace_gpu, bufferSize) do buffer -# return $fname(dh, jobz, uplo, n, A, lda, W, -# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) -# end + if jobz == 'V' && V !== A + copy!(V, A) + end + return W, V + end + end +end -# info = @allowscalar dh.info[1] -# chkargsok(BlasInt(info)) +function heevd!(A::StridedCuMatrix{T}, + W::StridedCuVector{Tr}, + V::StridedCuMatrix{T}; + uplo::Char='U') where {T<:BlasFloat, Tr<:BlasReal} + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + dh = CUSOLVER.dense_handle() + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'V' + end -# if jobz == 'N' -# return W -# elseif jobz == 'V' -# return W, A -# end -# end -# end -# end + params = CUSOLVER.CuSolverParameters() + function bufferSize() + out_cpu = Ref{Csize_t}(0) + out_gpu = Ref{Csize_t}(0) + CUSOLVER.cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, Tr, W, T, out_gpu, out_cpu) + return out_gpu[], out_cpu[] + end + + CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu, + bufferSize()...) do buffer_gpu, buffer_cpu + return CUSOLVER.cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, Tr, W, + T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, + sizeof(buffer_cpu), dh.info) + end + + info = @allowscalar dh.info[1] + chkargsok(BlasInt(info)) + + if jobz == 'V' && V !== A + copy!(V, A) + end + return W, V +end # device code is unreachable by coverage right now # COV_EXCL_START diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 035caf11e..7a8f90d4f 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -34,7 +34,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_DivideAndConquer, LAPACK_Jacobi, LQViaTransposedQR, CUSOLVER_Simple, - CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, + CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer, ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 1e6a47f40..9b261d05f 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -61,11 +61,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) YALAPACK.heevx!(A, Dd, V; alg.kwargs...) end # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:size(V, 2) - v = view(V, :, j) - s = conj(sign(argmax(abs, v))) - v .*= s - end + V = gaugefix!(V) return D, V end @@ -88,3 +84,45 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) return truncate!(eigh_trunc!, (D, V), alg.trunc) end + +_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V))) +_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V))) +_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V))) +_gpu_heevx!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevx!, (A, Dd, V))) + +function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) + check_input(eigh_full!, A, DV, alg) + D, V = DV + Dd = D.diag + if alg isa GPU_Jacobi + _gpu_heevj!(A, Dd, V; alg.kwargs...) + elseif alg isa GPU_DivideAndConquer + _gpu_heevd!(A, Dd, V; alg.kwargs...) + elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple + _gpu_heev!(A, Dd, V; alg.kwargs...) + elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert + _gpu_heevx!(A, Dd, V; alg.kwargs...) + else + throw(ArgumentError("Unsupported eigh algorithm")) + end + # TODO: make this controllable using a `gaugefix` keyword argument + V = gaugefix!(V) + return D, V +end + +function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm) + check_input(eigh_vals!, A, D, alg) + V = similar(A, (size(A, 1), 0)) + if alg isa GPU_Jacobi + _gpu_heevj!(A, D, V; alg.kwargs...) + elseif alg isa GPU_DivideAndConquer + _gpu_heevd!(A, D, V; alg.kwargs...) + elseif alg isa GPU_QRIteration + _gpu_heev!(A, D, V; alg.kwargs...) + elseif alg isa GPU_Bisection + _gpu_heevx!(A, D, V; alg.kwargs...) + else + throw(ArgumentError("Unsupported eigh algorithm")) + end + return D +end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 83e5f2206..15d9137e7 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -215,9 +215,7 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} -const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_Randomized = Union{CUSOLVER_Randomized} function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 1310944e0..722f90111 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -173,6 +173,16 @@ eigenvalue decomposition of a matrix. @algdef CUSOLVER_Simple const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} + +""" + CUSOLVER_DivideAndConquer() + +Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Divide and Conquer algorithm. +""" +@algdef CUSOLVER_DivideAndConquer + # ========================= # ROCSOLVER ALGORITHMS # ========================= @@ -202,5 +212,33 @@ a general matrix using the Jacobi algorithm. """ @algdef ROCSOLVER_Jacobi +""" + ROCSOLVER_Bisection() + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Bisection algorithm. +""" +@algdef ROCSOLVER_Bisection + +""" + ROCSOLVER_DivideAndConquer() + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Divide and Conquer algorithm. +""" +@algdef ROCSOLVER_DivideAndConquer + + const GPU_Simple = Union{CUSOLVER_Simple} const GPU_EigAlgorithm = Union{GPU_Simple} +const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} +const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} +const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} +const GPU_Bisection = Union{ROCSOLVER_Bisection} +const GPU_EighAlgorithm = Union{GPU_QRIteration, + GPU_Jacobi, + GPU_DivideAndConquer, + GPU_Bisection} + diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl new file mode 100644 index 000000000..c15bbb12e --- /dev/null +++ b/test/cuda/eigh.jl @@ -0,0 +1,78 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I +using MatrixAlgebraKit: TruncatedAlgorithm, diagview +using CUDA + +@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for alg in (CUSOLVER_DivideAndConquer(), + CUSOLVER_Jacobi(), + ) + A = CuArray(randn(rng, T, m, m)) + A = (A + A') / 2 + + D, V = @constinferred eigh_full(A; alg) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V), alg) + @test D2 === D + @test V2 === V + + D3 = @constinferred eigh_vals(A, alg) + @test parent(D) ≈ D3 + end +end + +#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for alg in (CUSOLVER_QRIteration(), + CUSOLVER_DivideAndConquer(), + ) + A = CuArray(randn(rng, T, m, m)) + A = A * A' + A = (A + A') / 2 + Ac = similar(A) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(T))) + + D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + @test length(diagview(D1)) == r + @test isisometry(V1) + @test A * V1 ≈ V1 * D1 + @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + + trunc = trunctol(s * D₀[r + 1]) + D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + @test length(diagview(D2)) == r + @test isisometry(V2) + @test A * V2 ≈ V2 * D2 + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + end +end + +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, + ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = qr_compact(CuArray(randn(rng, T, m, m)))[1] + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) + D2, V2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) +end=# diff --git a/test/runtests.jl b/test/runtests.jl index 9e6ed13ec..93f564b9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,9 @@ if CUDA.functional() @safetestset "CUDA General Eigenvalue Decomposition" begin include("cuda/eig.jl") end + @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin + include("cuda/eigh.jl") + end end using AMDGPU From c1753b86e28effa047c40ffd59f6b1f93c0cb9fa Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 08:28:28 -0400 Subject: [PATCH 02/10] Attempt at wrapping AMDGPU eigh --- .../MatrixAlgebraKitAMDGPUExt.jl | 12 +- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 166 ++++++++++++++---- src/yalapack.jl | 4 +- test/amd/eigh.jl | 80 +++++++++ test/runtests.jl | 3 + 5 files changed, 227 insertions(+), 38 deletions(-) create mode 100644 test/amd/eigh.jl diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 0f8c35139..7f51bde93 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -5,8 +5,9 @@ using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR -using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm -import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm +import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx! using AMDGPU using LinearAlgebra using LinearAlgebra: BlasFloat @@ -23,6 +24,9 @@ end function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} return ROCSOLVER_QRIteration(; kwargs...) end +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix} + return ROCSOLVER_DivideAndConquer(; kwargs...) +end _gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A) _gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ) @@ -32,4 +36,8 @@ _gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ: #_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) _gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) +_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) +_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...) +_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...) +_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...) end diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index 029bc0181..036d6c4f0 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -1,7 +1,7 @@ module YArocSOLVER using LinearAlgebra -using LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1, require_one_based_indexing +using LinearAlgebra: BlasInt, BlasReal, BlasFloat, checksquare, chkstride1, require_one_based_indexing using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo using AMDGPU @@ -475,42 +475,140 @@ end # return X, info # end -# for (jname, bname, fname, elty, relty) in -# ((:syevd!, :rocsolverDnSsyevd_bufferSize, :rocsolverDnSsyevd, :Float32, :Float32), -# (:syevd!, :rocsolverDnDsyevd_bufferSize, :rocsolverDnDsyevd, :Float64, :Float64), -# (:heevd!, :rocsolverDnCheevd_bufferSize, :rocsolverDnCheevd, :ComplexF32, :Float32), -# (:heevd!, :rocsolverDnZheevd_bufferSize, :rocsolverDnZheevd, :ComplexF64, :Float64)) -# @eval begin -# function $jname(jobz::Char, -# uplo::Char, -# A::StridedROCMatrix{$elty}) -# chkuplo(uplo) -# n = checksquare(A) -# lda = max(1, stride(A, 2)) -# W = CuArray{$relty}(undef, n) -# dh = rocBLAS.handle() +for (heevd, heev, heevx, heevj, elty, relty) in + ((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32), + (:(rocSOVLER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64), + (:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32), + (:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64)) + @eval begin + function heevd!(A::StridedROCMatrix{$elty}, + W::StridedROCVector{$relty}, + V::StridedROCMatrix{$elty}; + uplo::Char='U') + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'O' + end + dh = rocBLAS.handle() + work = ROCVector{$relty}(undef, n) + dev_info = ROCVector{Cint}(undef, 1) + $heevd(dh, jobz, uplo, n, A, lda, W, work, dev_info) -# function bufferSize() -# out = Ref{Cint}(0) -# $bname(dh, jobz, uplo, n, A, lda, W, out) -# return out[] * sizeof($elty) -# end + info = @allowscalar dev_info[1] + chkargsok(BlasInt(info)) -# with_workspace(dh.workspace_gpu, bufferSize) do buffer -# return $fname(dh, jobz, uplo, n, A, lda, W, -# buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) -# end + if jobz == 'O' && V !== A + copy!(V, A) + end + return W, V + end + function heev!(A::StridedROCMatrix{$elty}, + W::StridedROCVector{$relty}, + V::StridedROCMatrix{$elty}; + uplo::Char='U') + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'O' + end + dh = rocBLAS.handle() + work = ROCVector{$relty}(undef, n) + dev_info = ROCVector{Cint}(undef, 1) + $heev(dh, jobz, uplo, n, A, lda, W, work, dev_info) -# info = @allowscalar dh.info[1] -# chkargsok(BlasInt(info)) + info = @allowscalar dev_info[1] + chkargsok(BlasInt(info)) -# if jobz == 'N' -# return W -# elseif jobz == 'V' -# return W, A -# end -# end -# end -# end + if jobz == 'O' && V !== A + copy!(V, A) + end + return W, V + end + function heevx!(A::StridedROCMatrix{$elty}, + W::StridedROCVector{$relty}, + V::StridedROCMatrix{$elty}; + uplo::Char='U', + kwargs...) + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if haskey(kwargs, :irange) + il = first(kwargs[:irange]) + iu = last(kwargs[:irange]) + vl = vu = zero($relty) + range = 'I' + elseif haskey(kwargs, :vl) || haskey(kwargs, :vu) + vl = convert($relty, get(kwargs, :vl, -Inf)) + vu = convert($relty, get(kwargs, :vu, +Inf)) + il = iu = 0 + range = 'V' + else + il = iu = 0 + vl = vu = zero($relty) + range = 'A' + end + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'O' + end + dh = rocBLAS.handle() + abstol = -one($relty) + m = Ref{BlasInt}() + ldv = max(1, stride(V, 2)) + work = ROCVector{$relty}(undef, n) + ifail = ROCVector{BlasInt}(undef, n) + dev_info = ROCVector{Cint}(undef, 1) + $heevx(dh, jobz, range, uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info) + + info = @allowscalar dev_info[1] + chkargsok(BlasInt(info)) + return W, V, m[] + end + function heevj!(A::StridedROCMatrix{$elty}, + W::StridedROCVector{$relty}, + V::StridedROCMatrix{$elty}; + uplo::Char='U', + tol::$relty=eps($relty), + max_sweeps::Int=100) + chkuplo(uplo) + n = checksquare(A) + lda = max(1, stride(A, 2)) + length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) + if length(V) == 0 + jobz = 'N' + else + size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) + jobz = 'O' + end + dh = rocBLAS.handle() + dev_info = ROCVector{Cint}(undef, 1) + residual = ROCVector{$relty}(undef, 1) + n_sweeps = ROCVector{Cint}(undef, 1) + $heev(dh, jobz, uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info) + + info = @allowscalar dev_info[1] + chkargsok(BlasInt(info)) + + if jobz == 'O' && V !== A + copy!(V, A) + end + return W, V + end + end +end end diff --git a/src/yalapack.jl b/src/yalapack.jl index ec37b6745..3d7b0369a 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -921,8 +921,8 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in end chkuplofinite(A, uplo) if haskey(kwargs, :irange) - il = first(irange) - iu = last(irange) + il = first(kwargs[:irange]) + iu = last(kwargs[:irange]) vl = vu = zero($relty) range = 'I' elseif haskey(kwargs, :vl) || haskey(kwargs, :vu) diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl new file mode 100644 index 000000000..44be84952 --- /dev/null +++ b/test/amd/eigh.jl @@ -0,0 +1,80 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I +using MatrixAlgebraKit: TruncatedAlgorithm, diagview +using AMDGPU + +@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for alg in (ROCSOLVER_DivideAndConquer(), + ROCSOLVER_Jacobi(), + ROCSOLVER_Bisection(), + ROCSOLVER_QRIteration(), + ) + A = ROCArray(randn(rng, T, m, m)) + A = (A + A') / 2 + + D, V = @constinferred eigh_full(A; alg) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V), alg) + @test D2 === D + @test V2 === V + + D3 = @constinferred eigh_vals(A, alg) + @test parent(D) ≈ D3 + end +end + +#=@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + for alg in (CUSOLVER_QRIteration(), + CUSOLVER_DivideAndConquer(), + ) + A = ROCArray(randn(rng, T, m, m)) + A = A * A' + A = (A + A') / 2 + Ac = similar(A) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(T))) + + D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + @test length(diagview(D1)) == r + @test isisometry(V1) + @test A * V1 ≈ V1 * D1 + @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + + trunc = trunctol(s * D₀[r + 1]) + D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + @test length(diagview(D2)) == r + @test isisometry(V2) + @test A * V2 ≈ V2 * D2 + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + end +end + +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, + ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = qr_compact(ROCArray(randn(rng, T, m, m)))[1] + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) + D2, V2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) +end=# diff --git a/test/runtests.jl b/test/runtests.jl index 93f564b9e..d85e8a32c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,4 +82,7 @@ if AMDGPU.functional() @safetestset "AMDGPU SVD" begin include("amd/svd.jl") end + @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin + include("amd/eigh.jl") + end end From a854ce2dd45f734cf88a3ebc2e5f47f28475a3b7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 08:35:34 -0400 Subject: [PATCH 03/10] Export ROCSOLVER_DivideAndConquer --- src/MatrixAlgebraKit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 7a8f90d4f..6e28cf969 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -35,7 +35,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LQViaTransposedQR, CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer, - ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi + ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered VERSION >= v"1.11.0-DEV.469" && From 974f0499b3b7180ddbfe29cc1e3b8eff28deabc6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 08:50:54 -0400 Subject: [PATCH 04/10] Fix char types to rocblas --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 40 +++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index 036d6c4f0..c93aaaecb 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -490,20 +490,21 @@ for (heevd, heev, heevx, heevj, elty, relty) in lda = max(1, stride(A, 2)) length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) if length(V) == 0 - jobz = 'N' + jobz = rocSOLVER.rocblas_evect_none else size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) - jobz = 'O' + jobz = rocSOLVER.rocblas_evect_original end dh = rocBLAS.handle() work = ROCVector{$relty}(undef, n) dev_info = ROCVector{Cint}(undef, 1) - $heevd(dh, jobz, uplo, n, A, lda, W, work, dev_info) + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) + $heevd(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) - if jobz == 'O' && V !== A + if jobz == rocSOLVER.rocblas_evect_original && V !== A copy!(V, A) end return W, V @@ -517,20 +518,21 @@ for (heevd, heev, heevx, heevj, elty, relty) in lda = max(1, stride(A, 2)) length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) if length(V) == 0 - jobz = 'N' + jobz = rocSOLVER.rocblas_evect_none else size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) - jobz = 'O' + jobz = rocSOLVER.rocblas_evect_original end dh = rocBLAS.handle() work = ROCVector{$relty}(undef, n) dev_info = ROCVector{Cint}(undef, 1) - $heev(dh, jobz, uplo, n, A, lda, W, work, dev_info) + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) + $heev(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) - if jobz == 'O' && V !== A + if jobz == rocSOLVER.rocblas_evect_original && V !== A copy!(V, A) end return W, V @@ -548,22 +550,22 @@ for (heevd, heev, heevx, heevj, elty, relty) in il = first(kwargs[:irange]) iu = last(kwargs[:irange]) vl = vu = zero($relty) - range = 'I' + range = rocSOLVER.rocblas_erange_index elseif haskey(kwargs, :vl) || haskey(kwargs, :vu) vl = convert($relty, get(kwargs, :vl, -Inf)) vu = convert($relty, get(kwargs, :vu, +Inf)) il = iu = 0 - range = 'V' + range = rocSOLVER.rocblas_erange_value else il = iu = 0 vl = vu = zero($relty) - range = 'A' + range = rocSOLVER.rocblas_erange_all end if length(V) == 0 - jobz = 'N' + jobz = rocSOLVER.rocblas_evect_none else size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) - jobz = 'O' + jobz = rocSOLVER.rocblas_evect_original end dh = rocBLAS.handle() abstol = -one($relty) @@ -572,7 +574,8 @@ for (heevd, heev, heevx, heevj, elty, relty) in work = ROCVector{$relty}(undef, n) ifail = ROCVector{BlasInt}(undef, n) dev_info = ROCVector{Cint}(undef, 1) - $heevx(dh, jobz, range, uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info) + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) + $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) @@ -589,21 +592,22 @@ for (heevd, heev, heevx, heevj, elty, relty) in lda = max(1, stride(A, 2)) length(W) == n || throw(DimensionMismatch("size mismatch between A and W")) if length(V) == 0 - jobz = 'N' + jobz = rocSOLVER.rocblas_evect_none else size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) - jobz = 'O' + jobz = rocSOLVER.rocblas_evect_original end dh = rocBLAS.handle() dev_info = ROCVector{Cint}(undef, 1) residual = ROCVector{$relty}(undef, 1) n_sweeps = ROCVector{Cint}(undef, 1) - $heev(dh, jobz, uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info) + roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) + $heev(dh, jobz, roc_uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) - if jobz == 'O' && V !== A + if jobz == rocSOLVER.rocblas_evect_original && V !== A copy!(V, A) end return W, V From f2c6fb63c30c1b0c01d394ab399803187a347cef Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 08:56:03 -0400 Subject: [PATCH 05/10] Fix kwarg name --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index c93aaaecb..9c73bfcfb 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -602,7 +602,7 @@ for (heevd, heev, heevx, heevj, elty, relty) in residual = ROCVector{$relty}(undef, 1) n_sweeps = ROCVector{Cint}(undef, 1) roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) - $heev(dh, jobz, roc_uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info) + $heev(dh, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) From 4190b6694626dc1e265997dba84a01d08f650a3f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 08:59:01 -0400 Subject: [PATCH 06/10] Fix function call for heevj --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index 9c73bfcfb..efda8f7c2 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -602,7 +602,7 @@ for (heevd, heev, heevx, heevj, elty, relty) in residual = ROCVector{$relty}(undef, 1) n_sweeps = ROCVector{Cint}(undef, 1) roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) - $heev(dh, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) + $heevj(dh, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) From 013c518a99877b2beb82e1bd655c9eb4cc93efa2 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 09:08:04 -0400 Subject: [PATCH 07/10] Add sort option for heevj --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index efda8f7c2..b5dda1b10 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -586,7 +586,8 @@ for (heevd, heev, heevx, heevj, elty, relty) in V::StridedROCMatrix{$elty}; uplo::Char='U', tol::$relty=eps($relty), - max_sweeps::Int=100) + max_sweeps::Int=100, + sort::Char='N') chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) @@ -602,7 +603,8 @@ for (heevd, heev, heevx, heevj, elty, relty) in residual = ROCVector{$relty}(undef, 1) n_sweeps = ROCVector{Cint}(undef, 1) roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) - $heevj(dh, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) + roc_sort = sort == 'N' ? rocSOLVER.rocblas_esort_none : rocSOLVER.rocblas_esort_ascending + $heevj(dh, roc_sort, jobz, roc_uplo, n, A, lda, tol, residual, max_sweeps, n_sweeps, W, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) From eae424bae93caa57c0bb8943e866d79ee5d1befb Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 09:16:29 -0400 Subject: [PATCH 08/10] Fix int types for heevx --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index b5dda1b10..3f3f02ab2 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -569,10 +569,10 @@ for (heevd, heev, heevx, heevj, elty, relty) in end dh = rocBLAS.handle() abstol = -one($relty) - m = Ref{BlasInt}() + m = Ref{Cint}() ldv = max(1, stride(V, 2)) work = ROCVector{$relty}(undef, n) - ifail = ROCVector{BlasInt}(undef, n) + ifail = ROCVector{Cint}(undef, n) dev_info = ROCVector{Cint}(undef, 1) roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info) From ad3c1a54236fdf9c6570b5bfa2f938b9a504bf07 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 09:26:03 -0400 Subject: [PATCH 09/10] Fix nev for heevx --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index 3f3f02ab2..2c0ab1d29 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -567,19 +567,19 @@ for (heevd, heev, heevx, heevj, elty, relty) in size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V")) jobz = rocSOLVER.rocblas_evect_original end - dh = rocBLAS.handle() - abstol = -one($relty) - m = Ref{Cint}() - ldv = max(1, stride(V, 2)) - work = ROCVector{$relty}(undef, n) - ifail = ROCVector{Cint}(undef, n) + dh = rocBLAS.handle() + abstol = -one($relty) + nev = ROCVector{Cint}(undef, 1) + ldv = max(1, stride(V, 2)) + ifail = ROCVector{Cint}(undef, n) dev_info = ROCVector{Cint}(undef, 1) roc_uplo = convert(rocSOLVER.rocblas_fill, uplo) - $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info) + $heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info) info = @allowscalar dev_info[1] chkargsok(BlasInt(info)) - return W, V, m[] + m = @allowscalar nev[1] + return W, V, m end function heevj!(A::StridedROCMatrix{$elty}, W::StridedROCVector{$relty}, From 956ac1e39a14c0981652e1e615cc4c20dc4ef939 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 Aug 2025 09:30:02 -0400 Subject: [PATCH 10/10] Correct function name --- ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index 2c0ab1d29..3338308b1 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -477,7 +477,7 @@ end for (heevd, heev, heevx, heevj, elty, relty) in ((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32), - (:(rocSOVLER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64), + (:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64), (:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32), (:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64)) @eval begin