From 80795e08d6d3ca55ba576de01a64621f3f724a59 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 1 Sep 2025 12:15:58 +0200 Subject: [PATCH 01/11] Add DiagonalAlgorithm --- src/MatrixAlgebraKit.jl | 3 ++- src/interface/decompositions.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 6e28cf969..9bf19dd94 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -35,7 +35,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LQViaTransposedQR, CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer, - ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection + ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection, + DiagonalAlgorithm export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered VERSION >= v"1.11.0-DEV.469" && diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 722f90111..62a064017 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -113,6 +113,17 @@ const LAPACK_SVDAlgorithm = Union{LAPACK_QRIteration, LAPACK_DivideAndConquer, LAPACK_Jacobi} +# ========================= +# DIAGONAL ALGORITHMS +# ========================= +""" + DiagonalAlgorithm(; kwargs...) + +Algorithm type to denote a native Julia implementation of the decompositions making use of +the diagonal structure of the input and outputs. +""" +@algdef DiagonalAlgorithm + # ========================= # CUSOLVER ALGORITHMS # ========================= From a6456f8fae61df53173b160ce0b57229ff450dcd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 1 Sep 2025 13:15:35 +0200 Subject: [PATCH 02/11] add Diagonal QR implementation and tests --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/qr.jl | 103 ++++++++++++++++++++++++++++++++------ src/interface/qr.jl | 3 ++ test/qr.jl | 50 +++++++++++++++++- 4 files changed, 141 insertions(+), 17 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 9bf19dd94..4d97b8284 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -5,7 +5,7 @@ using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! using LinearAlgebra: sylvester using LinearAlgebra: isposdef, ishermitian -using LinearAlgebra: Diagonal, diag, diagind +using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 638d65d8f..2c6843c1f 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -1,13 +1,10 @@ # Inputs # ------ -function copy_input(::typeof(qr_full), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) -end -function copy_input(::typeof(qr_compact), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) -end -function copy_input(::typeof(qr_null), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) +for f in (:qr_full, :qr_compact, :qr_null) + @eval function copy_input(::typeof($f), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) + end + @eval copy_input(::typeof($f), A::Diagonal) = copy(A) end function check_input(::typeof(qr_full!), A::AbstractMatrix, QR, ::AbstractAlgorithm) @@ -40,6 +37,28 @@ function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::AbstractAlgorit return nothing end +function check_input(::typeof(qr_full!), A::AbstractMatrix, (Q, R), alg::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert Q isa Diagonal && R isa Diagonal + @check_size(Q, (m, n)) + @check_scalar(Q, A) + isempty(R) || @check_size(R, (m, n)) + @check_scalar(R, A) + return nothing +end +function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR, alg::DiagonalAlgorithm) + return check_input(qr_full!, A, QR, alg) +end +function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert N isa AbstractMatrix + @check_size(N, (m, 0)) + @check_scalar(N, A) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(qr_full!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -62,6 +81,12 @@ function initialize_output(::typeof(qr_null!), A::AbstractMatrix, ::AbstractAlgo return N end +for f! in (:qr_full!, :qr_compact!) + @eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm) + return similar(A), A + end +end + # Implementation # -------------- # actual implementation @@ -83,6 +108,26 @@ function qr_null!(A::AbstractMatrix, N, alg::LAPACK_HouseholderQR) return N end +function qr_full!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm) + check_input(qr_full!, A, QR, alg) + Q, R = QR + _diagonal_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function qr_compact!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm) + check_input(qr_compact!, A, QR, alg) + Q, R = QR + _diagonal_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function qr_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm) + check_input(qr_null!, A, N, alg) + _diagonal_qr_null!(A, N; alg.kwargs...) + return N +end + +# LAPACK logic +# ------------ function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive=false, pivoted=false, @@ -167,23 +212,51 @@ function _lapack_qr_null!(A::AbstractMatrix, N::AbstractMatrix; return N end +# Diagonal logic +# -------------- +function _diagonal_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; + positive::Bool=false) + Ad = diagview(A) + Qd = diagview(Q) + Rd = diagview(R) + if positive + @inbounds @simd for i in eachindex(Ad) + s = sign_safe(Ad[i]) + Qd[i] = s + Rd[i] = conj(s) * Ad[i] + end + else + A === R || copy!(Rd, Ad) + one!(Q) + end + return Q, R +end + +_diagonal_qr_null!(A::AbstractMatrix, N::AbstractMatrix) = N + ### GPU logic # placed here to avoid code duplication since much of the logic is replicable across # CUDA and AMDGPU ### -function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) +function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, + alg::Union{CUSOLVER_HouseholderQR, + ROCSOLVER_HouseholderQR}) check_input(qr_full!, A, QR, alg) Q, R = QR _gpu_qr!(A, Q, R; alg.kwargs...) return Q, R end -function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) +function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, + alg::Union{CUSOLVER_HouseholderQR, + ROCSOLVER_HouseholderQR}) check_input(qr_compact!, A, QR, alg) Q, R = QR _gpu_qr!(A, Q, R; alg.kwargs...) return Q, R end -function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) +function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, + alg::Union{CUSOLVER_HouseholderQR, + ROCSOLVER_HouseholderQR}) check_input(qr_null!, A, N, alg) _gpu_qr_null!(A, N; alg.kwargs...) return N @@ -191,11 +264,13 @@ end _gpu_geqrf!(A::AbstractMatrix) = throw(MethodError(_gpu_geqrf!, (A,))) _gpu_ungqr!(A::AbstractMatrix, τ::AbstractVector) = throw(MethodError(_gpu_ungqr!, (A, τ))) -_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::AbstractMatrix, τ::AbstractVector, C) = throw(MethodError(_gpu_unmqr!, (side, trans, A, τ, C))) - +function _gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::AbstractMatrix, + τ::AbstractVector, C) + throw(MethodError(_gpu_unmqr!, (side, trans, A, τ, C))) +end function _gpu_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive=false, blocksize=1) + positive=false, blocksize=1) blocksize > 1 && throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition")) m, n = size(A) diff --git a/src/interface/qr.jl b/src/interface/qr.jl index 62d870808..b6a38d126 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -75,6 +75,9 @@ end function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_HouseholderQR(; kwargs...) end +function default_qr_algorithm(::Type{T}; kwargs...) where {T<:Diagonal} + return DiagonalAlgorithm(; kwargs...) +end for f in (:qr_full!, :qr_compact!, :qr_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/qr.jl b/test/qr.jl index 83f3ec6dd..630ea26cb 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: diag, I +using LinearAlgebra: diag, I, Diagonal @testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -108,7 +108,7 @@ end @test Q isa Matrix{T} && size(Q) == (m, m) @test R isa Matrix{T} && size(R) == (m, n) @test Q * R ≈ A - @test Q' * Q ≈ I + @test isunitary(Q) Ac = similar(A) Q2 = similar(Q) @@ -174,3 +174,49 @@ end @test Q == Q2 end end + +@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in (Float32, Float64, + ComplexF32, ComplexF64) + rng = StableRNG(123) + atol = eps(real(T))^(3 / 4) + for m in (54, 0) + Ad = randn(rng, T, m) + A = Diagonal(Ad) + + # compact + Q, R = @constinferred qr_compact(A) + @test Q isa Diagonal{T} && size(Q) == (m, m) + @test R isa Diagonal{T} && size(R) == (m, m) + @test Q * R ≈ A + @test isunitary(Q) + + # compact and positive + Qp, Rp = @constinferred qr_compact(A; positive=true) + @test Qp isa Diagonal{T} && size(Qp) == (m, m) + @test Rp isa Diagonal{T} && size(Rp) == (m, m) + @test Qp * Rp ≈ A + @test isunitary(Qp) + @test all(≥(zero(real(T))), real(diag(Rp))) && + all(≈(zero(real(T)); atol), imag(diag(Rp))) + + # full + Q, R = @constinferred qr_full(A) + @test Q isa Diagonal{T} && size(Q) == (m, m) + @test R isa Diagonal{T} && size(R) == (m, m) + @test Q * R ≈ A + @test isunitary(Q) + + # full and positive + Qp, Rp = @constinferred qr_full(A; positive=true) + @test Qp isa Diagonal{T} && size(Qp) == (m, m) + @test Rp isa Diagonal{T} && size(Rp) == (m, m) + @test Qp * Rp ≈ A + @test isunitary(Qp) + @test all(≥(zero(real(T))), real(diag(Rp))) && + all(≈(zero(real(T)); atol), imag(diag(Rp))) + + # null + N = @constinferred qr_null(A) + @test N isa AbstractMatrix{T} && size(N) == (m, 0) + end +end From 3fd4f52f0d228fe7d00d5dd01b9ce022ac569278 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 1 Sep 2025 13:45:20 +0200 Subject: [PATCH 03/11] Add Diagonal LQ implementation and tests --- src/implementations/lq.jl | 107 +++++++++++++++++++++++++++++++------- src/interface/lq.jl | 3 ++ test/lq.jl | 48 ++++++++++++++++- 3 files changed, 137 insertions(+), 21 deletions(-) diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index 98617fbf5..9c2cdcb62 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -1,13 +1,10 @@ # Inputs # ------ -function copy_input(::typeof(lq_full), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) -end -function copy_input(::typeof(lq_compact), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) -end -function copy_input(::typeof(lq_null), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) +for f in (:lq_full, :lq_compact, :lq_null) + @eval function copy_input(::typeof($f), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) + end + @eval copy_input(::typeof($f), A::Diagonal) = copy(A) end function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm) @@ -40,6 +37,28 @@ function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgo return nothing end +function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert Q isa Diagonal && L isa Diagonal + isempty(L) || @check_size(L, (m, n)) + @check_scalar(L, A) + @check_size(Q, (n, n)) + @check_scalar(Q, A) + return nothing +end +function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) + return check_input(lq_full!, A, LQ, alg) +end +function check_input(::typeof(lq_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert N isa AbstractMatrix + @check_size(N, (0, m)) + @check_scalar(N, A) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -62,44 +81,69 @@ function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgo return Nᴴ end +for f! in (:lq_full!, :lq_compact!) + @eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm) + return A, similar(A) + end +end + # Implementation # -------------- -# actual implementation function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) check_input(lq_full!, A, LQ, alg) L, Q = LQ _lapack_lq!(A, L, Q; alg.kwargs...) return L, Q end -function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) - check_input(lq_full!, A, LQ, alg) - L, Q = LQ - lq_via_qr!(A, L, Q, alg.qr_alg) - return L, Q -end function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) check_input(lq_compact!, A, LQ, alg) L, Q = LQ _lapack_lq!(A, L, Q; alg.kwargs...) return L, Q end +function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ) + check_input(lq_null!, A, Nᴴ, alg) + _lapack_lq_null!(A, Nᴴ; alg.kwargs...) + return Nᴴ +end + +function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) + check_input(lq_full!, A, LQ, alg) + L, Q = LQ + lq_via_qr!(A, L, Q, alg.qr_alg) + return L, Q +end function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) check_input(lq_compact!, A, LQ, alg) L, Q = LQ lq_via_qr!(A, L, Q, alg.qr_alg) return L, Q end -function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ) - check_input(lq_null!, A, Nᴴ, alg) - _lapack_lq_null!(A, Nᴴ; alg.kwargs...) - return Nᴴ -end function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR) check_input(lq_null!, A, Nᴴ, alg) lq_null_via_qr!(A, Nᴴ, alg.qr_alg) return Nᴴ end +function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) + check_input(lq_full!, A, LQ, alg) + L, Q = LQ + _diagonal_lq!(A, L, Q; alg.kwargs...) + return L, Q +end +function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) + check_input(lq_compact!, A, LQ, alg) + L, Q = LQ + _diagonal_lq!(A, L, Q; alg.kwargs...) + return L, Q +end +function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm) + check_input(lq_null!, A, N, alg) + return _diagonal_lq_null!(A, N; alg.kwargs...) +end + +# LAPACK logic +# ------------ function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive=false, pivoted=false, @@ -177,6 +221,7 @@ function _lapack_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix; end # LQ via transposition and QR +# --------------------------- function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, qr_alg::AbstractAlgorithm) m, n = size(A) @@ -203,3 +248,25 @@ function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractA !isempty(N) && adjoint!(N, Nt) return N end + +# Diagonal logic +# -------------- +function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; + positive::Bool=false) + Ad = diagview(A) + Ld = diagview(L) + Qd = diagview(Q) + if positive + @inbounds @simd for i in eachindex(Ad) + s = sign_safe(Ad[i]) + Qd[i] = s + Ld[i] = conj(s) * Ad[i] + end + else + A === L || copy!(Ld, Ad) + one!(Q) + end + return L, Q +end + +_diagonal_lq_null!(A::AbstractMatrix, N::AbstractMatrix) = N diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 338302ef4..9ebfd6bf5 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -75,6 +75,9 @@ end function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_HouseholderLQ(; kwargs...) end +function default_lq_algorithm(::Type{T}; kwargs...) where {T<:Diagonal} + return DiagonalAlgorithm(; kwargs...) +end for f in (:lq_full!, :lq_compact!, :lq_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/lq.jl b/test/lq.jl index 581f1782b..1a4281371 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: diag, I +using LinearAlgebra: diag, I, Diagonal using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderQR @testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -205,3 +205,49 @@ end @test Q == Q2 end end + +@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in (Float32, Float64, + ComplexF32, ComplexF64) + rng = StableRNG(123) + atol = eps(real(T))^(3 / 4) + for m in (54, 0) + Ad = randn(rng, T, m) + A = Diagonal(Ad) + + # compact + L, Q = @constinferred lq_compact(A) + @test Q isa Diagonal{T} && size(Q) == (m, m) + @test L isa Diagonal{T} && size(L) == (m, m) + @test L * Q ≈ A + @test isunitary(Q) + + # compact and positive + Lp, Qp = @constinferred lq_compact(A; positive=true) + @test Qp isa Diagonal{T} && size(Qp) == (m, m) + @test Lp isa Diagonal{T} && size(Lp) == (m, m) + @test Lp * Qp ≈ A + @test isunitary(Qp) + @test all(≥(zero(real(T))), real(diag(Lp))) && + all(≈(zero(real(T)); atol), imag(diag(Lp))) + + # full + L, Q = @constinferred lq_full(A) + @test Q isa Diagonal{T} && size(Q) == (m, m) + @test L isa Diagonal{T} && size(L) == (m, m) + @test L * Q ≈ A + @test isunitary(Q) + + # full and positive + Lp, Qp = @constinferred lq_full(A; positive=true) + @test Qp isa Diagonal{T} && size(Qp) == (m, m) + @test Lp isa Diagonal{T} && size(Lp) == (m, m) + @test Lp * Qp ≈ A + @test isunitary(Qp) + @test all(≥(zero(real(T))), real(diag(Lp))) && + all(≈(zero(real(T)); atol), imag(diag(Lp))) + + # null + N = @constinferred lq_null(A) + @test N isa AbstractMatrix{T} && size(N) == (0, m) + end +end From 32f6978ed0a95fd6796e87a9faf5db013dfdb295 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Sep 2025 09:49:34 -0400 Subject: [PATCH 04/11] Add Diagonal eig implementation and tests --- src/implementations/eig.jl | 52 ++++++++++++++++++++++++++++++++++++-- src/interface/eig.jl | 3 +++ test/eig.jl | 31 +++++++++++++++++++---- 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 515e73867..2fd1c53e8 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -3,11 +3,13 @@ function copy_input(::typeof(eig_full), A::AbstractMatrix) return copy!(similar(A, float(eltype(A))), A) end -function copy_input(::typeof(eig_vals), A::AbstractMatrix) +function copy_input(::typeof(eig_vals), A) return copy_input(eig_full, A) end copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A) +copy_input(::typeof(eig_full), A::Diagonal) = copy(A) + function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) @@ -28,6 +30,28 @@ function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgori return nothing end +function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + D, V = DV + @assert D isa Diagonal && V isa Diagonal + @check_size(D, (m, m)) + @check_size(V, (m, m)) + # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable + @check_scalar(D, A) + @check_scalar(V, A) + return nothing +end +function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert D isa AbstractVector + @check_size(D, (n,)) + # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable + @check_scalar(D, A) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -47,9 +71,15 @@ function initialize_output(::typeof(eig_trunc!), A::AbstractMatrix, alg::Truncat return initialize_output(eig_full!, A, alg.alg) end +function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm) + return A, similar(A) +end +function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm) + return diagview(A) +end + # Implementation # -------------- -# actual implementation function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV @@ -83,6 +113,24 @@ function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) return truncate!(eig_trunc!, (D, V), alg.trunc) end +# Diagonal logic +# -------------- +function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal,Diagonal}, alg::DiagonalAlgorithm) + check_input(eig_full!, A, (D, V), alg) + D === A || copy!(D, A) + one!(V) + return D, V +end + +function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm) + check_input(eig_vals!, A, D, alg) + Ad = diagview(A) + D === Ad || copy!(D, Ad) + return D +end + +# GPU logic +# --------- _gpu_geev!(A::AbstractMatrix, D, V) = throw(MethodError(_gpu_geev!, (A, D, V))) function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 90fde0e01..2796f39ce 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -82,6 +82,9 @@ default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algori function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_Expert(; kwargs...) end +function default_eig_algorithm(::Type{T}; kwargs...) where {T<:Diagonal} + return DiagonalAlgorithm(; kwargs...) +end for f in (:eig_full!, :eig_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/eig.jl b/test/eig.jl index cdaec9dce..d4d8dcf27 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -5,7 +5,9 @@ using StableRNGs using LinearAlgebra: Diagonal using MatrixAlgebraKit: TruncatedAlgorithm, diagview -@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "eig_full! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple) @@ -30,7 +32,7 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview end end -@testset "eig_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "eig_trunc! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for alg in (LAPACK_Simple(), LAPACK_Expert()) @@ -58,9 +60,7 @@ end end end -@testset "eig_trunc! specify truncation algorithm T = $T" for T in - (Float32, Float64, ComplexF32, - ComplexF64) +@testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 V = randn(rng, T, m, m) @@ -71,3 +71,24 @@ end @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2)) end + +@testset "eig for Diagonal{$T}" for T in BLASFloats + rng = StableRNG(123) + m = 54 + Ad = randn(rng, T, m) + A = Diagonal(Ad) + + D, V = @constinferred eig_full(A) + @test D isa Diagonal{T} && size(D) == size(A) + @test V isa Diagonal{T} && size(V) == size(A) + @test A * V ≈ V * D + + D2 = @constinferred eig_vals(A) + @test D2 isa AbstractVector{T} && length(D2) == m + @test diagview(D) ≈ D2 + + A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) + D2, V2 = @constinferred eig_trunc(A2; alg) + @test diagview(D2) ≈ diagview(A2)[1:2] +end From e35c55e6493e1a223918f352ca9c7a1e1fd022ec Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Sep 2025 14:35:01 -0400 Subject: [PATCH 05/11] Add Diagonal eigh implementation and tests --- src/implementations/eigh.jl | 49 +++++++++++++++++++++++++++++++++++++ src/interface/eigh.jl | 3 +++ test/eigh.jl | 33 ++++++++++++++++++++----- 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 9b261d05f..860fa485a 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -8,6 +8,8 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix) end copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A) +copy_input(::typeof(eigh_full), A::Diagonal) = copy(A) + function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) @@ -21,6 +23,27 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgo end function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm) m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected")) + @assert D isa AbstractVector + @check_size(D, (n,)) + @check_scalar(D, A, real) + return nothing +end + +function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + D, V = DV + @assert D isa Diagonal && V isa Diagonal + @check_size(D, (m, m)) + @check_scalar(D, A, real) + @check_size(V, (m, m)) + @check_scalar(V, A) + return nothing +end +function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) @assert D isa AbstractVector @check_size(D, (n,)) @check_scalar(D, A, real) @@ -45,6 +68,13 @@ function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix, return initialize_output(eigh_full!, A, alg.alg) end +function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm) + return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A) +end +function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm) + return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) +end + # Implementation # -------------- function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) @@ -85,6 +115,25 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) return truncate!(eigh_trunc!, (D, V), alg.trunc) end +# Diagonal logic +# -------------- +function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) + check_input(eigh_full!, A, DV, alg) + D, V = DV + D === A || (diagview(D) .= real.(diagview(A))) + one!(V) + return D, V +end + +function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm) + check_input(eigh_vals!, A, D, alg) + Ad = diagview(A) + D === Ad || (D .= real.(Ad)) + return D +end + +# GPU logic +# --------- _gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V))) _gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V))) _gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V))) diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index a650ca448..155680ea7 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -93,6 +93,9 @@ end function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end +function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:Diagonal} + return DiagonalAlgorithm(; kwargs...) +end for f in (:eigh_full!, :eigh_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/eigh.jl b/test/eigh.jl index 26e6ce8a8..59f4050ee 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -5,7 +5,9 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I using MatrixAlgebraKit: TruncatedAlgorithm, diagview -@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "eigh_full! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for alg in (LAPACK_MultipleRelativelyRobustRepresentations(), @@ -29,7 +31,7 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview end end -@testset "eigh_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "eigh_trunc! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for alg in (LAPACK_QRIteration(), @@ -62,10 +64,7 @@ end end end -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in - (Float32, Float64, - ComplexF32, - ComplexF64) +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 V = qr_compact(randn(rng, T, m, m))[1] @@ -77,3 +76,25 @@ end @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end + +@testset "eigh for Diagonal{$T}" for T in BLASFloats + rng = StableRNG(123) + m = 54 + Ad = randn(rng, T, m) + Ad .+= conj.(Ad) + A = Diagonal(Ad) + + D, V = @constinferred eigh_full(A) + @test D isa Diagonal{real(T)} && size(D) == size(A) + @test V isa Diagonal{T} && size(V) == size(A) + @test A * V ≈ V * D + + D2 = @constinferred eigh_vals(A) + @test D2 isa AbstractVector{real(T)} && length(D2) == m + @test diagview(D) ≈ D2 + + A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) + D2, V2 = @constinferred eigh_trunc(A2; alg) + @test diagview(D2) ≈ diagview(A2)[1:2] +end From 3264f47c94c524eec29d81431c3c0cbd72be386f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Sep 2025 15:27:51 -0400 Subject: [PATCH 06/11] Add Diagonal svd implementation and tests --- src/implementations/svd.jl | 75 +++++++++++++++++++++++++++++++++++++- src/interface/svd.jl | 3 ++ test/svd.jl | 50 ++++++++++++++++++++----- 3 files changed, 117 insertions(+), 11 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 15d9137e7..1874d6e2e 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -7,6 +7,8 @@ copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) +copy_input(::typeof(svd_full), A::Diagonal) = copy(A) + # TODO: many of these checks are happening again in the LAPACK routines function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::AbstractAlgorithm) m, n = size(A) @@ -42,6 +44,32 @@ function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgori return nothing end +function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + U, S, Vᴴ = USVᴴ + @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix + @check_size(U, (m, m)) + @check_scalar(U, A) + @check_size(S, (m, n)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (n, n)) + @check_scalar(Vᴴ, A) + return nothing +end +function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, + alg::DiagonalAlgorithm) + return check_input(svd_full!, A, USVᴴ, alg) +end +function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert S isa AbstractVector + @check_size(S, (m,)) + @check_scalar(S, A, real) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(svd_full!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -66,6 +94,18 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat return initialize_output(svd_compact!, A, alg.alg) end +function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm) + TA = eltype(A) + TUV = Base.promote_op(sign_safe, TA) + return similar(A, TUV, size(A)), similar(A, real(TA)), similar(A, TUV, size(A)) +end +function initialize_output(::typeof(svd_compact!), A::Diagonal, alg::DiagonalAlgorithm) + return initialize_output(svd_full!, A, alg) +end +function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm) + return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) +end + function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int) for j in 1:max(m, n) if j <= min(m, n) @@ -111,7 +151,6 @@ function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int) return (U, S, Vᴴ) end - # Implementation # -------------- function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) @@ -203,7 +242,39 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm) return truncate!(svd_trunc!, USVᴴ′, alg.trunc) end -### GPU logic +# Diagonal logic +# -------------- +function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm) + check_input(svd_full!, A, USVᴴ, alg) + Ad = diagview(A) + U, S, Vᴴ = USVᴴ + Sd = diagview(S) + Sd .= abs.(Ad) + p = sortperm(Sd; rev=true) + permute!(Sd, p) + T = eltype(Vᴴ) + zero!(U) + zero!(Vᴴ) + @inbounds for (i, pi) in enumerate(p) + s = Ad[pi] + U[pi, i] = sign_safe(s) + Vᴴ[i, pi] = one(T) + end + return U, S, Vᴴ +end +function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm) + return svd_full!(A, USVᴴ, alg) +end +function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm) + check_input(svd_vals!, A, S, alg) + Ad = diagview(A) + S .= abs.(Ad) + sort!(S; rev=true) + return S +end + +# GPU logic +# --------- # placed here to avoid code duplication since much of the logic is replicable across # CUDA and AMDGPU ### diff --git a/src/interface/svd.jl b/src/interface/svd.jl index fd4eb5a5c..659c0dd57 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -97,6 +97,9 @@ end function default_svd_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_DivideAndConquer(; kwargs...) end +function default_svd_algorithm(::Type{T}; kwargs...) where {T<:Diagonal} + return DiagonalAlgorithm(; kwargs...) +end for f in (:svd_full!, :svd_compact!, :svd_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/svd.jl b/test/svd.jl index 99dd940f1..5e7d65ecf 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -5,7 +5,9 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview, isisometry -@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "svd_compact! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 @testset "size ($m, $n)" for n in (37, m, 63, 0) @@ -54,7 +56,7 @@ using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview, isiso end end -@testset "svd_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "svd_full! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 @testset "size ($m, $n)" for n in (37, m, 63, 0) @@ -88,7 +90,7 @@ end end end -@testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +@testset "svd_trunc! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 if LinearAlgebra.LAPACK.version() < v"3.12.0" @@ -122,9 +124,7 @@ end end end -@testset "svd_trunc! mix maxrank and tol for T = $T" for T in - (Float32, Float64, ComplexF32, - ComplexF64) +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in BLASFloats rng = StableRNG(123) if LinearAlgebra.LAPACK.version() < v"3.12.0" algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) @@ -152,9 +152,7 @@ end end end -@testset "svd_trunc! specify truncation algorithm T = $T" for T in - (Float32, Float64, ComplexF32, - ComplexF64) +@testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 U = qr_compact(randn(rng, T, m, m))[1] @@ -166,3 +164,37 @@ end @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2)) end + +@testset "svd for Diagonal{$T}" for T in BLASFloats + rng = StableRNG(123) + for m in (54, 0) + Ad = randn(T, m) + A = Diagonal(Ad) + + U, S, Vᴴ = @constinferred svd_compact(A) + @test U isa AbstractMatrix{T} && size(U) == size(A) + @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) + @test S isa Diagonal{real(T)} && size(S) == size(A) + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(≥(0), diagview(S)) + @test A ≈ U * S * Vᴴ + + U, S, Vᴴ = @constinferred svd_full(A) + @test U isa AbstractMatrix{T} && size(U) == size(A) + @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) + @test S isa Diagonal{real(T)} && size(S) == size(A) + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(≥(0), diagview(S)) + @test A ≈ U * S * Vᴴ + + S2 = @constinferred svd_vals(A) + @test S2 isa AbstractVector{real(T)} && length(S2) == m + @test S2 ≈ diagview(S) + + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) + U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg) + @test diagview(S3) ≈ S2[1:min(m, 2)] + end +end From 675ce94d172fef84a7004302d54b9bcb2984e109 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Sep 2025 18:07:20 -0400 Subject: [PATCH 07/11] Make JET happy --- src/implementations/lq.jl | 2 +- src/implementations/qr.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index 9c2cdcb62..bc6320ea5 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -269,4 +269,4 @@ function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; return L, Q end -_diagonal_lq_null!(A::AbstractMatrix, N::AbstractMatrix) = N +_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool=false) = N diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 2c6843c1f..582003c02 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -232,7 +232,7 @@ function _diagonal_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; return Q, R end -_diagonal_qr_null!(A::AbstractMatrix, N::AbstractMatrix) = N +_diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool=false) = N ### GPU logic # placed here to avoid code duplication since much of the logic is replicable across From 4add25370e04447cba7a08e40f5cf433be8ebed6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Sep 2025 10:07:51 -0400 Subject: [PATCH 08/11] Add hermitian/symmetric checks --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/eigh.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 4d97b8284..a8b094fb8 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -4,7 +4,7 @@ using LinearAlgebra: LinearAlgebra using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! using LinearAlgebra: sylvester -using LinearAlgebra: isposdef, ishermitian +using LinearAlgebra: isposdef, ishermitian, issymmetric using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 860fa485a..7501d643a 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -33,6 +33,7 @@ end function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) m, n = size(A) @assert m == n && isdiag(A) + @assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A) D, V = DV @assert D isa Diagonal && V isa Diagonal @check_size(D, (m, m)) @@ -44,6 +45,7 @@ end function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) m, n = size(A) @assert m == n && isdiag(A) + @assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A) @assert D isa AbstractVector @check_size(D, (n,)) @check_scalar(D, A, real) From f95d1b3c67540207e3ff7cab2f4f8561cb1e7d20 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Sep 2025 10:14:45 -0400 Subject: [PATCH 09/11] GPU-friendly QR/LQ --- src/implementations/lq.jl | 12 +++++------- src/implementations/qr.jl | 12 +++++------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index bc6320ea5..391af0376 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -83,7 +83,7 @@ end for f! in (:lq_full!, :lq_compact!) @eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm) - return A, similar(A) + return similar(A), A end end @@ -253,17 +253,15 @@ end # -------------- function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive::Bool=false) + # note: Ad and Qd might share memory here so order of operations is important Ad = diagview(A) Ld = diagview(L) Qd = diagview(Q) if positive - @inbounds @simd for i in eachindex(Ad) - s = sign_safe(Ad[i]) - Qd[i] = s - Ld[i] = conj(s) * Ad[i] - end + @. Ld = abs(Ad) + @. Qd = sign_safe(Ad) else - A === L || copy!(Ld, Ad) + Ld .= Ad one!(Q) end return L, Q diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 582003c02..8a06ffddf 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -83,7 +83,7 @@ end for f! in (:qr_full!, :qr_compact!) @eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm) - return similar(A), A + return A, similar(A) end end @@ -216,17 +216,15 @@ end # -------------- function _diagonal_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool=false) + # note: Ad and Qd might share memory here so order of operations is important Ad = diagview(A) Qd = diagview(Q) Rd = diagview(R) if positive - @inbounds @simd for i in eachindex(Ad) - s = sign_safe(Ad[i]) - Qd[i] = s - Rd[i] = conj(s) * Ad[i] - end + @. Rd = abs(Ad) + @. Qd = sign_safe(Ad) else - A === R || copy!(Rd, Ad) + Rd .= Ad one!(Q) end return Q, R From 18ae70e7cc027fdf548ac0e7ba5eb17ffcbea5e3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Sep 2025 10:38:16 -0400 Subject: [PATCH 10/11] GPU-friendly SVD + correct gaugefix --- src/implementations/svd.jl | 55 ++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 1874d6e2e..ab03ea94c 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -248,18 +248,25 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm) check_input(svd_full!, A, USVᴴ, alg) Ad = diagview(A) U, S, Vᴴ = USVᴴ - Sd = diagview(S) - Sd .= abs.(Ad) - p = sortperm(Sd; rev=true) - permute!(Sd, p) - T = eltype(Vᴴ) + p = sortperm(Ad; by=abs, rev=true) zero!(U) zero!(Vᴴ) - @inbounds for (i, pi) in enumerate(p) - s = Ad[pi] - U[pi, i] = sign_safe(s) - Vᴴ[i, pi] = one(T) + n = size(A, 1) + + pV = (1:n) .+ (p .- 1) .* n + Vᴴ[pV] .= sign_safe.(view(Ad, p)) + + Sd = diagview(S) + if Ad === Sd + @. Sd = abs(Ad) + permute!(Sd, p) + else + Sd .= abs.(view(Ad, p)) end + + p .+= (0:(n - 1)) .* n + U[p] .= Ref(one(eltype(U))) + return U, S, Vᴴ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm) @@ -284,12 +291,13 @@ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration, CUSOLVER_Randomized} const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm,ROCSOLVER_SVDAlgorithm} const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} const GPU_Randomized = Union{CUSOLVER_Randomized} -function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized) +function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, + alg::CUSOLVER_Randomized) m, n = size(A) minmn = min(m, n) U, S, Vᴴ = USVᴴ @@ -303,7 +311,8 @@ function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOL return nothing end -function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}) +function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, + alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}) m, n = size(A) minmn = min(m, n) U = similar(A, (m, m)) @@ -312,10 +321,22 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat return (U, S, Vᴴ) end -_gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) -_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) -_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) -_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) +function _gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, + Vᴴ::AbstractMatrix) + throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) +end +function _gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, + Vᴴ::AbstractMatrix; kwargs...) + throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) +end +function _gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, + Vᴴ::AbstractMatrix; kwargs...) + throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) +end +function _gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, + Vᴴ::AbstractMatrix; kwargs...) + throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) +end # GPU SVD implementation function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_full!, A, USVᴴ, alg) @@ -369,7 +390,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl throw(ArgumentError("Unsupported SVD algorithm")) end # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) return USVᴴ end _argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x))) From e1f95f0c524d181a40be0a12076924123b94d1d1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Sep 2025 10:52:42 -0400 Subject: [PATCH 11/11] Bump v0.3.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 505a8821c..d759b7a42 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"