Skip to content

Commit bc9ae9b

Browse files
lkdvoskshyatt
authored andcommitted
update GLA QR implementation
1 parent fbeab4c commit bc9ae9b

2 files changed

Lines changed: 29 additions & 46 deletions

File tree

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge
5-
using MatrixAlgebraKit: left_orth_alg
4+
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
65
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
76
using LinearAlgebra: I, Diagonal, lmul!
87

@@ -57,81 +56,65 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
5756
return eigvals!(Hermitian(A); sortby = real)
5857
end
5958

60-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
61-
return GLA_HouseholderQR(; kwargs...)
62-
end
63-
64-
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
65-
check_input(qr_full!, A, QR, alg)
66-
Q, R = QR
67-
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
68-
end
69-
70-
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
71-
check_input(qr_compact!, A, QR, alg)
72-
Q, R = QR
73-
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
74-
end
75-
76-
function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR)
77-
check_input(qr_null!, A, N, alg)
78-
return _gla_householder_qr_null!(A, N; alg.kwargs...)
79-
end
80-
81-
function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false)
82-
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
83-
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
59+
function MatrixAlgebraKit.householder_qr!(
60+
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
61+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
62+
)
63+
blocksize == 1 ||
64+
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
65+
pivoted &&
66+
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
8467

8568
m, n = size(A)
86-
k = min(m, n)
69+
minmn = min(m, n)
70+
computeR = length(R) > 0
71+
72+
# compute QR
8773
Q̃, R̃ = qr!(A)
8874
lmul!(Q̃, MatrixAlgebraKit.one!(Q))
8975

9076
if positive
91-
@inbounds for j in 1:k
77+
@inbounds for j in 1:minmn
9278
s = sign_safe(R̃[j, j])
9379
@simd for i in 1:m
9480
Q[i, j] *= s
9581
end
9682
end
9783
end
9884

99-
computeR = length(R) > 0
10085
if computeR
10186
if positive
10287
@inbounds for j in n:-1:1
103-
@simd for i in 1:min(k, j)
88+
@simd for i in 1:min(minmn, j)
10489
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
10590
end
106-
@simd for i in (min(k, j) + 1):size(R, 1)
91+
@simd for i in (min(minmn, j) + 1):size(R, 1)
10792
R[i, j] = zero(eltype(R))
10893
end
10994
end
11095
else
111-
R[1:k, :] .=
112-
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
96+
R[1:minmn, :] .=
97+
MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
11398
end
11499
end
115100
return Q, R
116101
end
117102

118-
function _gla_householder_qr_null!(
119-
A::AbstractMatrix, N::AbstractMatrix;
120-
positive = true, blocksize = 1, pivoted = false
103+
function MatrixAlgebraKit.householder_qr_null!(
104+
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
105+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
121106
)
122-
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
123-
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
107+
blocksize == 1 ||
108+
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
109+
pivoted &&
110+
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
111+
124112
m, n = size(A)
125113
minmn = min(m, n)
126-
fill!(N, zero(eltype(N)))
114+
zero!(N)
127115
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
128116
Q̃, = qr!(A)
129-
lmul!(Q̃, N)
130-
return N
131-
end
132-
133-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
134-
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
117+
return lmul!(Q̃, N)
135118
end
136119

137120
MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg)

src/implementations/qr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ _diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = true) = N
331331

332332
# Deprecations
333333
# ------------
334-
for drivertype in (:LAPACK, :CUSOLVER, :ROCSOLVER, :Native)
334+
for drivertype in (:LAPACK, :CUSOLVER, :ROCSOLVER, :Native, :GLA)
335335
algtype = Symbol(drivertype, :_HouseholderQR)
336336
@eval begin
337337
Base.@deprecate(

0 commit comments

Comments
 (0)