|
1 | 1 | module MatrixAlgebraKitGenericLinearAlgebraExt |
2 | 2 |
|
3 | 3 | using MatrixAlgebraKit |
4 | | -using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge |
5 | | -using MatrixAlgebraKit: left_orth_alg |
| 4 | +using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge |
6 | 5 | using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! |
7 | 6 | using LinearAlgebra: I, Diagonal, lmul! |
8 | 7 |
|
@@ -57,81 +56,65 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) |
57 | 56 | return eigvals!(Hermitian(A); sortby = real) |
58 | 57 | end |
59 | 58 |
|
60 | | -function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} |
61 | | - return GLA_HouseholderQR(; kwargs...) |
62 | | -end |
63 | | - |
64 | | -function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) |
65 | | - check_input(qr_full!, A, QR, alg) |
66 | | - Q, R = QR |
67 | | - return _gla_householder_qr!(A, Q, R; alg.kwargs...) |
68 | | -end |
69 | | - |
70 | | -function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) |
71 | | - check_input(qr_compact!, A, QR, alg) |
72 | | - Q, R = QR |
73 | | - return _gla_householder_qr!(A, Q, R; alg.kwargs...) |
74 | | -end |
75 | | - |
76 | | -function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR) |
77 | | - check_input(qr_null!, A, N, alg) |
78 | | - return _gla_householder_qr_null!(A, N; alg.kwargs...) |
79 | | -end |
80 | | - |
81 | | -function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false) |
82 | | - pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) |
83 | | - (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) |
| 59 | +function MatrixAlgebraKit.householder_qr!( |
| 60 | + driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; |
| 61 | + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 |
| 62 | + ) |
| 63 | + blocksize == 1 || |
| 64 | + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) |
| 65 | + pivoted && |
| 66 | + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) |
84 | 67 |
|
85 | 68 | m, n = size(A) |
86 | | - k = min(m, n) |
| 69 | + minmn = min(m, n) |
| 70 | + computeR = length(R) > 0 |
| 71 | + |
| 72 | + # compute QR |
87 | 73 | Q̃, R̃ = qr!(A) |
88 | 74 | lmul!(Q̃, MatrixAlgebraKit.one!(Q)) |
89 | 75 |
|
90 | 76 | if positive |
91 | | - @inbounds for j in 1:k |
| 77 | + @inbounds for j in 1:minmn |
92 | 78 | s = sign_safe(R̃[j, j]) |
93 | 79 | @simd for i in 1:m |
94 | 80 | Q[i, j] *= s |
95 | 81 | end |
96 | 82 | end |
97 | 83 | end |
98 | 84 |
|
99 | | - computeR = length(R) > 0 |
100 | 85 | if computeR |
101 | 86 | if positive |
102 | 87 | @inbounds for j in n:-1:1 |
103 | | - @simd for i in 1:min(k, j) |
| 88 | + @simd for i in 1:min(minmn, j) |
104 | 89 | R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) |
105 | 90 | end |
106 | | - @simd for i in (min(k, j) + 1):size(R, 1) |
| 91 | + @simd for i in (min(minmn, j) + 1):size(R, 1) |
107 | 92 | R[i, j] = zero(eltype(R)) |
108 | 93 | end |
109 | 94 | end |
110 | 95 | else |
111 | | - R[1:k, :] .= R̃ |
112 | | - MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :])) |
| 96 | + R[1:minmn, :] .= R̃ |
| 97 | + MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :])) |
113 | 98 | end |
114 | 99 | end |
115 | 100 | return Q, R |
116 | 101 | end |
117 | 102 |
|
118 | | -function _gla_householder_qr_null!( |
119 | | - A::AbstractMatrix, N::AbstractMatrix; |
120 | | - positive = true, blocksize = 1, pivoted = false |
| 103 | +function MatrixAlgebraKit.householder_qr_null!( |
| 104 | + driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix; |
| 105 | + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 |
121 | 106 | ) |
122 | | - pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) |
123 | | - (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) |
| 107 | + blocksize == 1 || |
| 108 | + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) |
| 109 | + pivoted && |
| 110 | + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) |
| 111 | + |
124 | 112 | m, n = size(A) |
125 | 113 | minmn = min(m, n) |
126 | | - fill!(N, zero(eltype(N))) |
| 114 | + zero!(N) |
127 | 115 | one!(view(N, (minmn + 1):m, 1:(m - minmn))) |
128 | 116 | Q̃, = qr!(A) |
129 | | - lmul!(Q̃, N) |
130 | | - return N |
131 | | -end |
132 | | - |
133 | | -function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} |
134 | | - return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...)) |
| 117 | + return lmul!(Q̃, N) |
135 | 118 | end |
136 | 119 |
|
137 | 120 | MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg) |
|
0 commit comments