Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester
using LinearAlgebra: isposdef, ishermitian
using LinearAlgebra: Diagonal, diag, diagind
using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

Expand Down Expand Up @@ -35,7 +35,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LQViaTransposedQR,
CUSOLVER_Simple,
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
DiagonalAlgorithm
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

VERSION >= v"1.11.0-DEV.469" &&
Expand Down
52 changes: 50 additions & 2 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
function copy_input(::typeof(eig_full), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
function copy_input(::typeof(eig_vals), A::AbstractMatrix)
function copy_input(::typeof(eig_vals), A)
return copy_input(eig_full, A)
end
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)

copy_input(::typeof(eig_full), A::Diagonal) = copy(A)

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
Expand All @@ -28,6 +30,28 @@ function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgori
return nothing
end

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
D, V = DV
@assert D isa Diagonal && V isa Diagonal
@check_size(D, (m, m))
@check_size(V, (m, m))
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
@check_scalar(D, A)
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert D isa AbstractVector
@check_size(D, (n,))
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
@check_scalar(D, A)
return nothing
end

# Outputs
# -------
function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::AbstractAlgorithm)
Expand All @@ -47,9 +71,15 @@ function initialize_output(::typeof(eig_trunc!), A::AbstractMatrix, alg::Truncat
return initialize_output(eig_full!, A, alg.alg)
end

function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
return A, similar(A)
end
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
return diagview(A)
end

# Implementation
# --------------
# actual implementation
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
check_input(eig_full!, A, DV, alg)
D, V = DV
Expand Down Expand Up @@ -83,6 +113,24 @@ function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
return truncate!(eig_trunc!, (D, V), alg.trunc)
end

# Diagonal logic
# --------------
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal,Diagonal}, alg::DiagonalAlgorithm)
check_input(eig_full!, A, (D, V), alg)
D === A || copy!(D, A)
one!(V)
return D, V
end

function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
check_input(eig_vals!, A, D, alg)
Ad = diagview(A)
D === Ad || copy!(D, Ad)
return D
end

# GPU logic
# ---------
_gpu_geev!(A::AbstractMatrix, D, V) = throw(MethodError(_gpu_geev!, (A, D, V)))

function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
Expand Down
49 changes: 49 additions & 0 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
end
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)

copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
Expand All @@ -21,6 +23,27 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgo
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
@assert D isa AbstractVector
@check_size(D, (n,))
@check_scalar(D, A, real)
return nothing
end

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
D, V = DV
@assert D isa Diagonal && V isa Diagonal
@check_size(D, (m, m))
@check_scalar(D, A, real)
@check_size(V, (m, m))
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert D isa AbstractVector
@check_size(D, (n,))
@check_scalar(D, A, real)
Expand All @@ -45,6 +68,13 @@ function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix,
return initialize_output(eigh_full!, A, alg.alg)
end

function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
end
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
end

# Implementation
# --------------
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
Expand Down Expand Up @@ -85,6 +115,25 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
return truncate!(eigh_trunc!, (D, V), alg.trunc)
end

# Diagonal logic
# --------------
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
check_input(eigh_full!, A, DV, alg)
D, V = DV
D === A || (diagview(D) .= real.(diagview(A)))
Comment thread
lkdvos marked this conversation as resolved.
one!(V)
return D, V
end

function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
check_input(eigh_vals!, A, D, alg)
Ad = diagview(A)
D === Ad || (D .= real.(Ad))
return D
end

# GPU logic
# ---------
_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))
Expand Down
107 changes: 87 additions & 20 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Inputs
# ------
function copy_input(::typeof(lq_full), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
function copy_input(::typeof(lq_compact), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
function copy_input(::typeof(lq_null), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
for f in (:lq_full, :lq_compact, :lq_null)
@eval function copy_input(::typeof($f), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
Comment thread
lkdvos marked this conversation as resolved.
end
@eval copy_input(::typeof($f), A::Diagonal) = copy(A)
end

function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
Expand Down Expand Up @@ -40,6 +37,28 @@ function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgo
return nothing
end

function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert Q isa Diagonal && L isa Diagonal
isempty(L) || @check_size(L, (m, n))
@check_scalar(L, A)
@check_size(Q, (n, n))
@check_scalar(Q, A)
return nothing
end
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
return check_input(lq_full!, A, LQ, alg)
end
function check_input(::typeof(lq_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert N isa AbstractMatrix
@check_size(N, (0, m))
@check_scalar(N, A)
return nothing
end

# Outputs
# -------
function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::AbstractAlgorithm)
Expand All @@ -62,44 +81,69 @@ function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgo
return Nᴴ
end

for f! in (:lq_full!, :lq_compact!)
@eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm)
return A, similar(A)
end
end

# Implementation
# --------------
# actual implementation
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
_lapack_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
check_input(lq_null!, A, Nᴴ, alg)
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
return Nᴴ
end

function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
lq_via_qr!(A, L, Q, alg.qr_alg)
return L, Q
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
check_input(lq_null!, A, Nᴴ, alg)
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
return Nᴴ
end
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
check_input(lq_null!, A, Nᴴ, alg)
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
return Nᴴ
end

function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
check_input(lq_full!, A, LQ, alg)
L, Q = LQ
_diagonal_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
check_input(lq_compact!, A, LQ, alg)
L, Q = LQ
_diagonal_lq!(A, L, Q; alg.kwargs...)
return L, Q
end
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
check_input(lq_null!, A, N, alg)
return _diagonal_lq_null!(A, N; alg.kwargs...)
end

# LAPACK logic
# ------------
function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive=false,
pivoted=false,
Expand Down Expand Up @@ -177,6 +221,7 @@ function _lapack_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix;
end

# LQ via transposition and QR
# ---------------------------
function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
qr_alg::AbstractAlgorithm)
m, n = size(A)
Expand All @@ -203,3 +248,25 @@ function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractA
!isempty(N) && adjoint!(N, Nt)
return N
end

# Diagonal logic
# --------------
function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive::Bool=false)
Ad = diagview(A)
Ld = diagview(L)
Qd = diagview(Q)
if positive
@inbounds @simd for i in eachindex(Ad)
s = sign_safe(Ad[i])
Qd[i] = s
Ld[i] = conj(s) * Ad[i]
end
Comment thread
lkdvos marked this conversation as resolved.
Outdated
else
A === L || copy!(Ld, Ad)
one!(Q)
end
return L, Q
end

_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool=false) = N
Loading
Loading