-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathMatrixAlgebraKitGenericLinearAlgebraExt.jl
More file actions
122 lines (99 loc) · 4.13 KB
/
MatrixAlgebraKitGenericLinearAlgebraExt.jl
File metadata and controls
122 lines (99 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
module MatrixAlgebraKitGenericLinearAlgebraExt
using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!
# Default SVD algorithm for strided matrices whose element type is `BigFloat` or
# `Complex{BigFloat}`: `GLA_QRIteration`.
#
# The keyword arguments are forwarded to the algorithm object, in the same way as
# `default_eigh_algorithm` does. The SVD implementations in this module look up
# `:fixgauge` in `alg.kwargs`, so dropping the keywords here would silently ignore a
# user-supplied `fixgauge` setting.
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
    return GLA_QRIteration(; kwargs...)
end
# Output initialization for the SVD variants with `GLA_QRIteration`.
# Only `nothing` placeholders are returned: the implementations in this module return the
# arrays produced by the factorization instead of writing into a provided output argument.
function MatrixAlgebraKit.initialize_output(
        ::typeof(svd_compact!), A::AbstractMatrix, ::GLA_QRIteration
    )
    return (nothing, nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(
        ::typeof(svd_full!), A::AbstractMatrix, ::GLA_QRIteration
    )
    return (nothing, nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(
        ::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration
    )
    return nothing
end
# Compact SVD of `A` computed with the in-place `svd!`.
# The `USVᴴ` argument is not used; the returned factors come from the factorization object.
# Returns `(U, S, Vᴴ)` with `S` wrapped as a `Diagonal`.
# Gauge fixing is applied to `U` and `Vᴴ` unless disabled through the `fixgauge` keyword
# stored in `alg.kwargs` (falling back to `default_fixgauge()`).
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
    fact = svd!(A)
    U = fact.U
    S = Diagonal(fact.S)
    Vᴴ = fact.Vt
    if get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
        gaugefix!(svd_compact!, U, Vᴴ)
    end
    return U, S, Vᴴ
end
# Full SVD of `A` computed with the in-place `svd!(A; full = true)`.
# The `USVᴴ` argument is not used. Returns `(U, S, Vᴴ)` where `S` is a freshly allocated
# `size(U, 2) × size(Vᴴ, 1)` array that is zero except for the singular values on its
# diagonal.
# Gauge fixing is applied to `U` and `Vᴴ` unless disabled through the `fixgauge` keyword
# stored in `alg.kwargs` (falling back to `default_fixgauge()`).
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
    fact = svd!(A; full = true)
    U = fact.U
    Vᴴ = fact.Vt
    # rectangular container for the singular values, matching the full factors
    Sdims = (size(U, 2), size(Vᴴ, 1))
    S = zero!(similar(fact.S, Sdims))
    diagview(S) .= fact.S
    if get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
        gaugefix!(svd_full!, U, Vᴴ)
    end
    return U, S, Vᴴ
end
# Singular values of `A` via the in-place `svdvals!`; the `S` argument is not used.
MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) = svdvals!(A)
# Default Hermitian eigenvalue algorithm for strided matrices whose element type is
# `BigFloat` or `Complex{BigFloat}`: `GLA_QRIteration`, constructed with the given keywords.
function MatrixAlgebraKit.default_eigh_algorithm(
        ::Type{T}; kwargs...
    ) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
    alg = GLA_QRIteration(; kwargs...)
    return alg
end
# Output initialization for the Hermitian eigenvalue variants with `GLA_QRIteration`.
# Only `nothing` placeholders are returned: the implementations in this module return the
# results of the factorization instead of writing into a provided output argument.
function MatrixAlgebraKit.initialize_output(
        ::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration
    )
    return (nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(
        ::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration
    )
    return nothing
end
# Hermitian eigenvalue decomposition of `A` via the in-place `eigen!` on `Hermitian(A)`,
# with eigenvalues sorted by `real`. The `DV` argument is not used.
# Returns `(D, V)`: the eigenvalues wrapped as a `Diagonal` and the eigenvector matrix.
# The type assertions require real eigenvalues of type `real(eltype(A))` and eigenvectors
# with the element type of `A`.
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
    T = eltype(A)
    vals, vecs = eigen!(Hermitian(A); sortby = real)
    D = Diagonal(vals::AbstractVector{real(T)})
    V = vecs::AbstractMatrix{T}
    return D, V
end
# Eigenvalues of `Hermitian(A)` via the in-place `eigvals!`, sorted by `real`.
# The `D` argument is not used.
MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) =
    eigvals!(Hermitian(A); sortby = real)
# Householder QR decomposition of `A` through `qr!`, writing the results into `Q` and `R`.
#
# Keywords:
# - `positive`: when `true`, the leading `min(m, n)` columns of `Q` are multiplied by the
#   phases `sign_safe(R̃[j, j])` of the diagonal of the computed triangular factor, and the
#   rows of `R` by the conjugate phases.
# - `pivoted`: must be `false`; a pivoted decomposition throws an `ArgumentError`.
# - `blocksize`: must be `<= 1`; a blocked decomposition throws an `ArgumentError`.
#
# `R` is only filled when it is non-empty. Returns `(Q, R)`.
function MatrixAlgebraKit.householder_qr!(
        driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
        positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
    )
    if blocksize > 1
        throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
    end
    if pivoted
        throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
    end

    nrows, ncols = size(A)
    k = min(nrows, ncols)
    wantR = !isempty(R)

    # factorize, then materialize the orthogonal factor by applying it to an identity
    Qfac, Rfac = qr!(A)
    lmul!(Qfac, one!(Q))

    if positive
        # absorb the phases of the diagonal of the triangular factor into the columns of Q
        @inbounds for col in 1:k
            phase = sign_safe(Rfac[col, col])
            @simd for row in 1:nrows
                Q[row, col] *= phase
            end
        end
    end

    wantR || return (Q, R)

    if !positive
        # plain copy of the triangular factor; any remaining rows of R are zeroed
        view(R, 1:k, :) .= Rfac
        zero!(@view(R[(k + 1):end, :]))
        return (Q, R)
    end

    @inbounds for col in ncols:-1:1
        lastrow = min(k, col)
        # upper-triangular part: compensate the phases that were absorbed into Q
        @simd for row in 1:lastrow
            R[row, col] = Rfac[row, col] * conj(sign_safe(Rfac[row, row]))
        end
        # everything below is set to zero
        @simd for row in (lastrow + 1):size(R, 1)
            R[row, col] = zero(eltype(R))
        end
    end
    return (Q, R)
end
# Fill `N` using the QR decomposition of `A`: `N` is first set to zero with an identity
# block in rows `(min(m, n) + 1):m` and columns `1:(m - min(m, n))`, and is then multiplied
# from the left, in place, by the orthogonal factor obtained from `qr!(A)`.
#
# Keywords:
# - `positive`: accepted but not used by this method.
# - `pivoted`: must be `false`; a pivoted decomposition throws an `ArgumentError`.
# - `blocksize`: must be `<= 1`; a blocked decomposition throws an `ArgumentError`.
#
# Returns the result of `lmul!` on `N`.
function MatrixAlgebraKit.householder_qr_null!(
        driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
        positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
    )
    if blocksize > 1
        throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
    end
    if pivoted
        throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
    end

    nrows, ncols = size(A)
    k = min(nrows, ncols)

    # N := [0; I] before applying the orthogonal factor
    zero!(N)
    one!(view(N, (k + 1):nrows, 1:(nrows - k)))

    Qfac = first(qr!(A))
    return lmul!(Qfac, N)
end
# `left_orth` with a `GLA_HouseholderQR` algorithm is wrapped in `LeftOrthViaQR`.
function MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR)
    return MatrixAlgebraKit.LeftOrthViaQR(alg)
end
end