-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathMatrixAlgebraKitGenericLinearAlgebraExt.jl
More file actions
119 lines (103 loc) · 3.91 KB
/
MatrixAlgebraKitGenericLinearAlgebraExt.jl
File metadata and controls
119 lines (103 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
module MatrixAlgebraKitGenericLinearAlgebraExt
using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using MatrixAlgebraKit: GLA, Driver
import MatrixAlgebraKit: gesvd!, heev!
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!
# Scalar element types that this extension routes to the GenericLinearAlgebra backend.
const GlaFloat = Union{Complex{BigFloat}, BigFloat}
# Strided vectors and matrices whose element type is a `GlaFloat`; used in this extension
# to dispatch default driver/algorithm selection.
const GlaStridedVecOrMatrix{T <: GlaFloat} = StridedVecOrMat{T}
# QR-iteration algorithms on strided BigFloat / Complex{BigFloat} arrays default to the
# GenericLinearAlgebra (`GLA`) driver.
function MatrixAlgebraKit.default_driver(
        ::Type{<:QRIteration}, ::Type{TA}
    ) where {TA <: GlaStridedVecOrMatrix}
    return GLA()
end
# The GLA driver reports full-SVD support only for the `:qr_iteration` method symbol.
function MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol)
    return f === :qr_iteration
end
# Default SVD algorithm for strided BigFloat / Complex{BigFloat} inputs: `QRIteration`,
# driven by `GLA()` unless the caller supplies another `driver`. Remaining keyword
# arguments are forwarded unchanged to the `QRIteration` constructor.
function MatrixAlgebraKit.default_svd_algorithm(
        ::Type{T}; driver::Driver = GLA(), kwargs...
    ) where {T <: GlaStridedVecOrMatrix}
    alg = QRIteration(; driver = driver, kwargs...)
    return alg
end
# SVD kernel for the GLA driver.
#
# `A` is handed to the mutating `svdvals!`/`svd!` and should be treated as destroyed.
# Results are written into the preallocated `S`, `U` and `Vᴴ`; a zero-length output is
# treated as "not requested". If neither `U` nor `Vᴴ` is requested, only singular values
# are computed via `svdvals!` (and copied into `S`). Otherwise `svd!` is called, in `full`
# mode whenever `U` has more than `min(m, n)` columns or `Vᴴ` has more than `min(m, n)`
# rows. Extra keyword arguments are accepted and ignored.
# Returns the tuple `(S, U, Vᴴ)`.
function gesvd!(
        ::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix;
        kwargs...
    )
    want_u = !isempty(U)
    want_v = !isempty(Vᴴ)
    if !(want_u || want_v)
        copyto!(S, svdvals!(A))
        return S, U, Vᴴ
    end
    k = minimum(size(A))
    # a requested factor wider/taller than the compact size k forces the full SVD
    use_full = (want_u && size(U, 2) > k) || (want_v && size(Vᴴ, 1) > k)
    F = svd!(A; full = use_full)
    isempty(S) || copyto!(S, F.S)
    want_u && copyto!(U, F.U)
    want_v && copyto!(Vᴴ, F.Vt)
    return S, U, Vᴴ
end
# Default Hermitian eigensolver for strided BigFloat / Complex{BigFloat} inputs:
# `QRIteration`, driven by `GLA()` unless the caller supplies another `driver`. Remaining
# keyword arguments are forwarded unchanged to the `QRIteration` constructor.
function MatrixAlgebraKit.default_eigh_algorithm(
        ::Type{T}; driver::Driver = GLA(), kwargs...
    ) where {T <: GlaStridedVecOrMatrix}
    alg = QRIteration(; driver = driver, kwargs...)
    return alg
end
# Hermitian eigensolver kernel for the GLA driver.
#
# `A` is wrapped as `Hermitian(A)` (default `uplo = :U`, so the upper triangle is used)
# and passed to the mutating `eigen!`/`eigvals!`; treat `A` as destroyed. Eigenvalues are
# sorted with `sortby = real` and copied into `Dd`. Eigenvectors are computed and copied
# into `V` only when `V` is non-empty. Extra keyword arguments are accepted and ignored.
# Returns the tuple `(Dd, V)`.
function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
    H = Hermitian(A)
    if isempty(V)
        copyto!(Dd, eigvals!(H; sortby = real))
    else
        vals, vecs = eigen!(H; sortby = real)
        copyto!(Dd, vals)
        copyto!(V, vecs)
    end
    return Dd, V
end
# Householder QR kernel for the GLA driver: factor `A = Q * R`.
#
# `A` is overwritten by `qr!`. `Q` is filled by applying the factorization's orthogonal/
# unitary factor to `one!(Q)` (presumably the rectangular identity), so it receives the
# leading `size(Q, 2)` columns of that factor. `Q === A` (in-place Q) is supported.
# `R` is filled only when it is non-empty; rows beyond `min(m, n)` are set to zero.
#
# With `positive = true`, column `j` of `Q` is multiplied by `sign_safe(R̃[j, j])` and row
# `i` of `R` by its conjugate, which leaves `Q * R` unchanged when these phases have unit
# modulus (assumed for `sign_safe`; confirm) and makes the diagonal of `R` non-negative.
#
# Throws `ArgumentError` for blocked (`blocksize > 1`) or pivoted requests, which this
# driver does not provide. Returns the tuple `(Q, R)`.
function MatrixAlgebraKit.householder_qr!(
        driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
        positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
    )
    blocksize <= 1 ||
        throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
    pivoted &&
        throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
    m, n = size(A)
    minmn = min(m, n)
    computeR = length(R) > 0

    # compute QR; the in-place factorization keeps the Householder reflectors of `Q̃` in
    # `A`'s memory, so `A` must not be modified before they have been applied
    Q̃, R̃ = qr!(A)
    if Q === A
        # `one!(Q)` would wipe the reflectors stored in `A` before `lmul!` uses them;
        # build Q in scratch space and copy it back
        copyto!(Q, lmul!(Q̃, MatrixAlgebraKit.one!(similar(Q))))
    else
        lmul!(Q̃, MatrixAlgebraKit.one!(Q))
    end

    if positive
        # phases of R̃'s diagonal, computed once: for BigFloat every `sign_safe` allocates
        signs = [sign_safe(R̃[j, j]) for j in 1:minmn]
        @inbounds for j in 1:minmn
            s = signs[j]
            @simd for i in 1:m
                Q[i, j] *= s
            end
        end
    end

    if computeR
        if positive
            # row i of R̃ is scaled by conj(signs[i]); everything below the (possibly
            # truncated) diagonal is set to an exact zero
            phases = conj.(signs)
            @inbounds for j in n:-1:1
                k = min(minmn, j)
                @simd for i in 1:k
                    R[i, j] = R̃[i, j] * phases[i]
                end
                @simd for i in (k + 1):size(R, 1)
                    R[i, j] = zero(eltype(R))
                end
            end
        else
            R[1:minmn, :] .= R̃
            MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
        end
    end
    return Q, R
end
# Left null-space kernel for the GLA driver.
#
# With `k = min(m, n)`, `N` is set to columns `(k + 1):m` of the identity (placed in its
# first `m - k` columns) and then multiplied from the left by the orthogonal/unitary
# factor of `qr!(A)`, i.e. it receives the trailing `m - k` columns of the full Q.
# `N` is presumably sized `m × (m - k)` — confirm against callers. `A` is overwritten by
# `qr!`. `positive` is accepted but unused here. Throws `ArgumentError` for blocked
# (`blocksize > 1`) or pivoted requests. Returns `N`.
function MatrixAlgebraKit.householder_qr_null!(
        driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
        positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
    )
    if blocksize > 1
        throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
    end
    if pivoted
        throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
    end
    nrows = size(A, 1)
    k = minimum(size(A))
    zero!(N)
    one!(view(N, (k + 1):nrows, 1:(nrows - k)))
    Qfactor, = qr!(A)
    return lmul!(Qfactor, N)
end
# Left orthogonalization with a GLA Householder-QR algorithm goes through the QR-based
# `LeftOrthViaQR` wrapper.
# NOTE(review): `GLA_HouseholderQR` is not imported explicitly in this module; it relies on
# being exported by MatrixAlgebraKit. Confirm it still exists alongside the driver-based
# `householder_qr!(::GLA, ...)` design, otherwise this method fails at load time.
function MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR)
    return MatrixAlgebraKit.LeftOrthViaQR(alg)
end
end