-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathMatrixAlgebraKitGenericLinearAlgebraExt.jl
More file actions
136 lines (111 loc) · 4.81 KB
/
MatrixAlgebraKitGenericLinearAlgebraExt.jl
File metadata and controls
136 lines (111 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
module MatrixAlgebraKitGenericLinearAlgebraExt
using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!
# Default SVD algorithm for strided `BigFloat` / `Complex{BigFloat}` matrices:
# `GLA_QRIteration`.
#
# Keyword arguments are forwarded into the algorithm, as the eigh/qr/lq defaults in this
# extension do. The SVD drivers read options such as `:fixgauge` from `alg.kwargs`.
# Without forwarding, a user-supplied `fixgauge = false` was silently discarded whenever
# the algorithm was picked by default.
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
    return GLA_QRIteration(; kwargs...)
end
# The `GLA_QRIteration` SVD methods in this extension ignore their preallocated output
# argument and return freshly computed factors, so nothing is allocated up front.
# - `svd_compact!` and `svd_full!` get a placeholder triple standing in for `(U, S, Vᴴ)`.
# - `svd_vals!` gets a single `nothing`.
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::AbstractMatrix, ::GLA_QRIteration)
    return (nothing, nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), A::AbstractMatrix, ::GLA_QRIteration)
    return (nothing, nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration)
    return nothing
end
# Compact SVD with `GLA_QRIteration`, delegating to GenericLinearAlgebra's `svd!`, which
# may overwrite `A`.
#
# The preallocated `USVᴴ` argument is not used. A fresh `(U, S, Vᴴ)` is returned, with the
# singular values wrapped in a `Diagonal`.
#
# `gaugefix!` is applied to `U` and `Vᴴ` when the `:fixgauge` entry of `alg.kwargs` is
# `true`. If that entry is absent, `default_fixgauge()` is used instead. A non-`Bool`
# setting throws a `TypeError`.
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
    decomposition = svd!(A)
    U = decomposition.U
    Σ = Diagonal(decomposition.S)
    Vᴴ = decomposition.Vt
    if get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
        gaugefix!(svd_compact!, U, Vᴴ)
    end
    return U, Σ, Vᴴ
end
# Full SVD with `GLA_QRIteration`, delegating to `svd!(A; full = true)`, which may
# overwrite `A`. The preallocated `USVᴴ` argument is not used.
#
# The singular values are written onto the diagonal of a zero-initialised rectangular
# matrix of size `(size(U, 2), size(Vᴴ, 1))`. That matrix is built with `similar` on the
# singular-value vector.
#
# `gaugefix!` is applied to `U` and `Vᴴ` when the `:fixgauge` entry of `alg.kwargs` is
# `true`. If that entry is absent, `default_fixgauge()` is used instead. A non-`Bool`
# setting throws a `TypeError`.
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
    decomposition = svd!(A; full = true)
    U = decomposition.U
    Vᴴ = decomposition.Vt
    rows, cols = size(U, 2), size(Vᴴ, 1)
    Σ = MatrixAlgebraKit.zero!(similar(decomposition.S, (rows, cols)))
    diagview(Σ) .= decomposition.S
    if get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
        gaugefix!(svd_full!, U, Vᴴ)
    end
    return U, Σ, Vᴴ
end
# Singular values only, via GenericLinearAlgebra's `svdvals!`, which may overwrite `A`.
# The preallocated `S` argument is not used; the result of `svdvals!` is returned as-is.
MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) = svdvals!(A)
# Default Hermitian eigensolver for strided `BigFloat` / `Complex{BigFloat}` matrices:
# `GLA_QRIteration`, with every keyword argument forwarded into the algorithm.
MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} =
    GLA_QRIteration(; kwargs...)
# No output buffers are preallocated for the `GLA_QRIteration` Hermitian eigensolvers in
# this extension; they return freshly computed results.
# - `eigh_full!` gets a placeholder pair standing in for `(D, V)`.
# - `eigh_vals!` gets a single `nothing`.
function MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration)
    return (nothing, nothing)
end
function MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration)
    return nothing
end
# Full Hermitian eigendecomposition via GenericLinearAlgebra's `eigen!`. It treats `A` as
# `Hermitian`, may overwrite it, and sorts eigenvalues by `real`. The preallocated `DV`
# argument is not used.
#
# Returns `(D, V)`:
# - `D` is a `Diagonal` of the eigenvalues.
# - `V` is the matrix of eigenvectors.
# The type assertions require real eigenvalues of `real(eltype(A))` and eigenvectors of
# `eltype(A)`; a mismatch throws a `TypeError`.
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
    evals, evecs = eigen!(Hermitian(A); sortby = real)
    D = Diagonal(evals::AbstractVector{real(eltype(A))})
    V = evecs::AbstractMatrix{eltype(A)}
    return D, V
end
# Eigenvalues only of `Hermitian(A)`, via GenericLinearAlgebra's `eigvals!` and sorted by
# `real`. `A` may be overwritten. The preallocated `D` argument is not used.
MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) =
    eigvals!(Hermitian(A); sortby = real)
# Default QR algorithm for strided matrices of `Float16`, `ComplexF16`, `BigFloat` or
# `Complex{BigFloat}`: `GLA_HouseholderQR`, with every keyword argument forwarded.
MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} =
    GLA_HouseholderQR(; kwargs...)
# Full QR with `GLA_HouseholderQR`.
# Steps:
# 1. Validate the `(Q, R)` outputs against `A` with `check_input`.
# 2. Fill them through `_gla_householder_qr!`, passing `alg.kwargs` as its keyword options.
# Returns whatever that helper returns.
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
    check_input(qr_full!, A, QR, alg)
    Qout, Rout = QR
    return _gla_householder_qr!(A, Qout, Rout; alg.kwargs...)
end
# Compact QR with `GLA_HouseholderQR`.
# Steps:
# 1. Validate the `(Q, R)` outputs against `A` with `check_input`.
# 2. Fill them through `_gla_householder_qr!`, passing `alg.kwargs` as its keyword options.
# Returns whatever that helper returns.
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
    check_input(qr_compact!, A, QR, alg)
    Qout, Rout = QR
    return _gla_householder_qr!(A, Qout, Rout; alg.kwargs...)
end
# QR-based null-space computation with `GLA_HouseholderQR`.
# Steps:
# 1. Validate `N` against `A` with `check_input`.
# 2. Fill `N` through `_gla_householder_qr_null!`, passing `alg.kwargs` as its keyword
#    options.
# Returns whatever that helper returns.
function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR)
    check_input(qr_null!, A, N, alg)
    result = _gla_householder_qr_null!(A, N; alg.kwargs...)
    return result
end
# Householder QR of `A`, written into the caller-supplied `Q` and `R`, returning `(Q, R)`.
# `A` is handed to `qr!` and may be overwritten.
# Shared by `qr_full!` and `qr_compact!`. `Q` and `R` arrive with whatever shapes those
# callers validated via `check_input`.
#
# Keywords (forwarded from `alg.kwargs`):
# - `positive = false`: when true, column `j` of `Q` (for `j ≤ k`) is scaled by
#   `sign_safe(R̃[j, j])`. Row `i` of `R` is scaled by its conjugate.
#   Assuming `sign_safe` returns a unit-modulus phase, the product `Q * R` is unchanged.
#   This presumably makes the diagonal of `R` real and non-negative; confirm against
#   `sign_safe`.
# - `blocksize = 1`: only `1` is supported; any other value throws `ArgumentError`.
# - `pivoted = false`: only `false` is supported; `true` throws `ArgumentError`.
#
# If `R` is empty (`length(R) == 0`), `R` is left untouched and only `Q` is computed.
function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksize = 1, pivoted = false)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
m, n = size(A)
# Number of Householder reflectors / diagonal entries of the triangular factor.
k = min(m, n)
# `Q̃` is the orthogonal factor returned by `qr!` and `R̃` its triangular factor.
# The `R[1:k, :] .= R̃` assignment below requires `R̃` to have `k` rows.
Q̃, R̃ = qr!(A)
# Materialise the leading `size(Q, 2)` columns of the orthogonal factor by applying `Q̃`
# from the left to `one!(Q)`. This assumes `one!` writes a (possibly rectangular)
# identity into `Q`; TODO confirm.
lmul!(Q̃, MatrixAlgebraKit.one!(Q))
if positive
# Scale column `j` of `Q` by the phase of the `j`-th diagonal entry of `R̃`.
# NOTE(review): `@inbounds` assumes `Q` has at least `m` rows and `k` columns.
@inbounds for j in 1:k
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
end
computeR = length(R) > 0
if computeR
if positive
# Columns are visited from last to first. For each column `j`:
# - rows `1:min(k, j)` receive `R̃` scaled by the conjugate phase of their diagonal;
# - the remaining rows down to `size(R, 1)` are zeroed.
# Every read of `R̃[i, i]` (needed by columns `j ≥ i`) happens no later than the write
# to `R[i, i]` at `j == i`. This ordering would therefore stay correct even if `R` and
# `R̃` shared storage; whether they can alias is not visible here.
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
@simd for i in (min(k, j) + 1):size(R, 1)
R[i, j] = zero(eltype(R))
end
end
else
# Copy `R̃` into the top `k` rows of `R` and zero any rows beyond `k`. Extra rows exist
# when `R` has more than `k` rows, presumably in the full-QR case.
R[1:k, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
end
end
return Q, R
end
# Fill `N` with the trailing columns of the orthogonal factor from a Householder QR of `A`
# (`A` may be overwritten by `qr!`), and return `N`.
#
# With `m = size(A, 1)` and `r = min(size(A)...)`:
# - `N` is zeroed;
# - the block `N[r+1:m, 1:m-r]` is set with `one!`;
# - the QR orthogonal factor is applied from the left.
# Assuming `one!` writes an identity, the result is columns `r+1:m` of that factor.
#
# Keywords mirror `_gla_householder_qr!`:
# - `positive` is accepted but has no effect here.
# - Only `blocksize = 1` and `pivoted = false` are supported; anything else throws
#   `ArgumentError`.
function _gla_householder_qr_null!(
        A::AbstractMatrix, N::AbstractMatrix;
        positive = false, blocksize = 1, pivoted = false
    )
    if pivoted
        throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
    end
    if !(blocksize == 1)
        throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
    end
    nrows, ncols = size(A)
    r = min(nrows, ncols)
    nullity = nrows - r
    fill!(N, zero(eltype(N)))
    one!(view(N, (r + 1):nrows, 1:nullity))
    Qfactor, = qr!(A)
    lmul!(Qfactor, N)
    return N
end
# Default LQ algorithm for strided matrices of `Float16`, `ComplexF16`, `BigFloat` or
# `Complex{BigFloat}`. It is a `GLA_HouseholderQR` (with every keyword argument forwarded)
# wrapped in `MatrixAlgebraKit.LQViaTransposedQR`. By its name, that wrapper presumably
# obtains the LQ factors from a QR of the transposed matrix; confirm in MatrixAlgebraKit.
MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} =
    MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
end