Skip to content

Commit c255dad

Browse files
Jutho and Katharine Hyatt authored
WIP: native algorithms (#90)
* native_qr * add lq and tests * add defaults and more tests * Test both native and GLA * fix lq --------- Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>
1 parent 91a8a69 commit c255dad

9 files changed

Lines changed: 339 additions & 8 deletions

File tree

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export left_polar!, right_polar!
3131
export left_orth, right_orth, left_null, right_null
3232
export left_orth!, right_orth!, left_null!, right_null!
3333

34+
export Native_HouseholderQR, Native_HouseholderLQ
3435
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3536
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3637
LAPACK_DivideAndConquer, LAPACK_Jacobi
@@ -72,6 +73,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
7273
end
7374

7475
include("common/defaults.jl")
76+
include("common/householder.jl")
7577
include("common/initialization.jl")
7678
include("common/pullbacks.jl")
7779
include("common/safemethods.jl")

src/common/householder.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Any range (unit or strided) of integer indices
const IndexRange{T <: Integer} = Base.AbstractRange{T}

# Elementary Householder reflection H = I - β * v * v', acting on the entries selected by
# the index range `r`:
#   β : scalar prefactor of the rank-one update (β == 0 is treated as the identity)
#   v : reflector vector; v[l] is paired with the l-th index in `r`
#   r : indices of the vector/matrix dimension on which the reflection acts
struct Householder{T, V <: AbstractVector, R <: IndexRange}
    β::T
    v::V
    r::R
end

# The adjoint reflection shares `v` and `r` (no copy) and only conjugates β
function Base.adjoint(H::Householder)
    return Householder(conj(H.β), H.v, H.r)
end
10+
11+
# Householder reflector h that zeros the elements x[r] (except for x[k]) upon lmul!(h, x).
# Returns `(h, ν)` where ν is the third value returned by `_householder!`.
# `x` itself is not modified: the reflector vector is built from the copy `x[r]`.
# Throws an error if `k` is not an element of `r`.
function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
    # position of the pivot index k within the range r
    i = findfirst(==(k), r)
    isnothing(i) && error("k = $k should be in the range r = $r")
    β, v, ν = _householder!(x[r], i)
    return Householder(β, v, r), ν
end
17+
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
# Returns `(h, ν)` where ν is the third value returned by `_householder!`.
# `A` itself is not modified: the reflector vector is built from the copy `A[r, col]`.
# Throws an error if `k` is not an element of `r`.
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
    # position of the pivot index k within the range r
    i = findfirst(==(k), r)
    isnothing(i) && error("k = $k should be in the range r = $r")
    β, v, ν = _householder!(A[r, col], i)
    return Householder(β, v, r), ν
end
24+
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
# Returns `(h, ν)` where ν is the third value returned by `_householder!`.
# `A` itself is not modified: `A[row, r]` is a fresh copy, which is conjugated in place
# before the reflector vector is built from it.
# Throws an error if `k` is not an element of `r`.
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
    # position of the pivot index k within the range r
    i = findfirst(==(k), r)
    isnothing(i) && error("k = $k should be in the range r = $r")
    β, v, ν = _householder!(conj!(A[row, r]), i)
    return Householder(β, v, r), ν
end
31+
32+
# generate Householder vector based on vector v, such that applying the reflection
# to v yields a vector with single non-zero element on position i, whose value is
# positive and thus equal to norm(v)
#
# Returns `(β, v, ν)`:
#   β : prefactor of the reflection I - β * v * v'
#   v : the input vector, overwritten in place with the reflector vector (v[i] == 1)
#   ν : the norm of the original v, computed as sqrt(sum(abs2, v))
# If all entries other than v[i] are zero and v[i] equals ν (i.e. v[i] is real and
# non-negative), no reflection is needed: β == 0 is returned and v is left untouched.
# Throws a `BoundsError` if `i` is not a valid index of `v` (e.g. when `v` is empty).
function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
    # everything below runs under @inbounds, so the pivot index must be validated first
    checkbounds(v, i)
    β::T = zero(T)
    @inbounds begin
        # σ = squared norm of all entries except the pivot
        σ = abs2(zero(T))
        @simd for k in 1:(i - 1)
            σ += abs2(v[k])
        end
        @simd for k in (i + 1):length(v)
            σ += abs2(v[k])
        end
        vi = v[i]
        ν = sqrt(abs2(vi) + σ)

        if σ == 0 && vi == ν
            β = zero(vi)
        else
            # in both branches vi becomes (vi - ν), the unnormalised pivot entry
            if real(vi) < 0
                vi = vi - ν
            else
                # algebraically equal to vi - ν, since
                # (vi - ν) * (conj(vi) + ν) == (vi - conj(vi)) * ν - σ;
                # this form avoids the cancellation in vi - ν when real(vi) >= 0
                vi = ((vi - conj(vi)) * ν - σ) / (conj(vi) + ν)
            end
            # normalise the reflector such that v[i] == 1
            @simd for k in 1:(i - 1)
                v[k] /= vi
            end
            v[i] = 1
            @simd for k in (i + 1):length(v)
                v[k] /= vi
            end
            β = -conj(vi) / (ν)
        end
    end
    return β, v, ν
end
68+
69+
# Apply the reflection H = I - β * v * v' to the entries x[r] of the vector x, in place.
# Returns x. A reflector with β == 0 is treated as the identity and leaves x untouched.
# Throws a `DimensionMismatch` if v and r differ in length, and a `BoundsError` if r
# contains indices outside of x.
function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
    v = H.v
    r = H.r
    β = H.β
    β == 0 && return x
    # the kernel below runs under @inbounds, so sizes and indices are validated up front
    length(v) == length(r) ||
        throw(DimensionMismatch("reflector vector has length $(length(v)), range is $r"))
    checkbounds(x, r)
    isempty(r) && return x
    @inbounds begin
        # μ = β * (v' * x[r]); initialised with a zero of the product element type
        μ = conj(zero(v[1])) * zero(x[first(r)])
        i = 1
        @simd for j in r
            μ += conj(v[i]) * x[j]
            i += 1
        end
        μ *= β
        # x[r] -= μ * v
        i = 1
        @simd for j in r
            x[j] -= μ * v[i]
            i += 1
        end
    end
    return x
end
90+
# Apply the reflection H = I - β * v * v' from the left to the rows r of the matrix A, in
# place, restricted to the columns listed in `cols` (default: all columns).
# Returns A. A reflector with β == 0 is treated as the identity and leaves A untouched.
# Throws a `DimensionMismatch` if v and r differ in length, and a `BoundsError` if r or
# `cols` contain indices outside of A.
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
    v = H.v
    r = H.r
    β = H.β
    β == 0 && return A
    # the kernel below runs under @inbounds, so sizes and indices are validated up front
    length(v) == length(r) ||
        throw(DimensionMismatch("reflector vector has length $(length(v)), range is $r"))
    (all(in(axes(A, 1)), r) && all(in(axes(A, 2)), cols)) ||
        throw(BoundsError(A, (r, cols)))
    isempty(r) && return A
    @inbounds begin
        for k in cols
            # μ = β * (v' * A[r, k]); initialised with a zero of the product element type
            μ = conj(zero(v[1])) * zero(A[first(r), k])
            i = 1
            @simd for j in r
                μ += conj(v[i]) * A[j, k]
                i += 1
            end
            μ *= β
            # A[r, k] -= μ * v
            i = 1
            @simd for j in r
                A[j, k] -= μ * v[i]
                i += 1
            end
        end
    end
    return A
end
113+
# Apply the reflection H = I - β * v * v' from the right to the columns r of the matrix A,
# in place, restricted to the rows listed in `rows` (default: all rows).
# Returns A. A reflector with β == 0 is treated as the identity and leaves A untouched.
# Throws an error if r is not contained in the column axis of A, a `BoundsError` if `rows`
# is not contained in the row axis, and a `DimensionMismatch` if v and r differ in length.
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
    v = H.v
    r = H.r
    β = H.β
    β == 0 && return A
    # validate before allocating the work buffer: the kernel below runs under @inbounds
    all(in(axes(A, 2)), r) || error("Householder range r = $r not compatible with matrix A of size $(size(A))")
    all(in(axes(A, 1)), rows) || throw(BoundsError(A, (rows, r)))
    length(v) == length(r) ||
        throw(DimensionMismatch("reflector vector has length $(length(v)), range is $r"))
    w = similar(A, length(rows))
    fill!(w, zero(eltype(w)))
    @inbounds begin
        # w = A[rows, r] * v, accumulated column by column
        l = 1
        for k in r
            j = 1
            @simd for i in rows
                w[j] += A[i, k] * v[l]
                j += 1
            end
            l += 1
        end
        # A[rows, r] -= β * w * v'
        l = 1
        for k in r
            j = 1
            @simd for i in rows
                A[i, k] -= β * w[j] * conj(v[l])
                j += 1
            end
            l += 1
        end
    end
    return A
end

src/implementations/lq.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,85 @@ function _diagonal_lq!(
270270
end
271271

272272
# Null-space stage of the diagonal LQ path: nothing is computed here, the supplied `N` is
# handed back unchanged and `positive` is ignored.
function _diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool = false)
    return N
end
273+
274+
# Native logic
275+
# -------------
276+
# Full LQ decomposition through the native Householder implementation.
# `LQ` is the `(L, Q)` output pair; both are overwritten and returned. `A` is used as
# workspace, hence `Q` may not be the very same array as `A`.
function lq_full!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ)
    check_input(lq_full!, A, LQ, alg)
    L, Q = LQ
    if A === Q
        throw(ArgumentError("inplace Q not supported with native LQ implementation"))
    end
    _native_lq!(A, L, Q; alg.kwargs...)
    return (L, Q)
end
284+
# Compact LQ decomposition through the native Householder implementation.
# `LQ` is the `(L, Q)` output pair; both are overwritten and returned. `A` is used as
# workspace, hence `Q` may not be the very same array as `A`.
function lq_compact!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ)
    check_input(lq_compact!, A, LQ, alg)
    L, Q = LQ
    if A === Q
        throw(ArgumentError("inplace Q not supported with native LQ implementation"))
    end
    _native_lq!(A, L, Q; alg.kwargs...)
    return (L, Q)
end
292+
# LQ-based null space through the native Householder implementation.
# `N` is overwritten with the result and returned; `A` is used as workspace.
function lq_null!(A::AbstractMatrix, N, alg::Native_HouseholderLQ)
    check_input(lq_null!, A, N, alg)
    kwargs = alg.kwargs
    _native_lq_null!(A, N; kwargs...)
    return N
end
297+
298+
# Householder LQ kernel: overwrites `L` and `Q` such that (original) A = L * Q.
# `A` is destroyed: on exit its rows hold the reflector vectors.
# The same routine serves the full and the compact decomposition; the sizes of `L` and
# `Q` decide which one is produced.
function _native_lq!(
        A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
        positive::Bool = true # always true regardless of setting
    )
    nrows, ncols = size(A)
    nreflectors = min(nrows, ncols)
    Lcols = size(L, 2)
    @inbounds for row in 1:nreflectors
        # entries left of the diagonal are already final
        for col in 1:(row - 1)
            L[row, col] = A[row, col]
        end
        # reflector from the (conjugated, in place) remainder of the current row
        tail = row:ncols
        β, v, ν = _householder!(conj!(view(A, row, tail)), 1)
        L[row, row] = ν
        for col in (row + 1):Lcols
            L[row, col] = 0
        end
        # update the rows below
        rmul!(A, Householder(conj(β), v, tail); rows = (row + 1):nrows)
        # A[row, row] == 1; store β instead
        A[row, row] = β
    end
    # copy remaining rows for nrows > ncols
    @inbounds for col in 1:Lcols, row in (nreflectors + 1):nrows
        L[row, col] = A[row, col]
    end
    # build Q by applying the stored reflectors, last one first
    one!(Q)
    @inbounds for row in nreflectors:-1:1
        β = A[row, row]
        A[row, row] = 1 # restore the implicit unit entry of the reflector vector
        rmul!(Q, Householder(β, view(A, row, row:ncols), row:ncols))
    end
    return L, Q
end
333+
334+
# Householder LQ null-space kernel: runs the same reflector sweep as `_native_lq!` without
# storing L, then overwrites `Nᴴ` by applying the stored reflectors to a matrix that is
# zero except for a `one!` block in its trailing columns.
# `A` is destroyed; `positive` is accepted but not used.
function _native_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix; positive::Bool = true)
    nrows, ncols = size(A)
    nreflectors = min(nrows, ncols)
    @inbounds for row in 1:nreflectors
        tail = row:ncols
        β, v, _ = _householder!(conj!(view(A, row, tail)), 1)
        rmul!(A, Householder(conj(β), v, tail); rows = (row + 1):nrows)
        # A[row, row] == 1; store β instead
        A[row, row] = β
    end
    # build Nᴴ
    fill!(Nᴴ, zero(eltype(Nᴴ)))
    one!(view(Nᴴ, 1:(ncols - nreflectors), (nreflectors + 1):ncols))
    @inbounds for row in nreflectors:-1:1
        β = A[row, row]
        A[row, row] = 1 # restore the implicit unit entry of the reflector vector
        rmul!(Nᴴ, Householder(β, view(A, row, row:ncols), row:ncols))
    end
    return Nᴴ
end

src/implementations/qr.jl

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ end
233233

234234
# Null-space stage of the diagonal QR path: nothing is computed here, the supplied `N` is
# handed back unchanged and `positive` is ignored.
function _diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = false)
    return N
end
235235

236-
### GPU logic
237-
# placed here to avoid code duplication since much of the logic is replicable across
238-
# CUDA and AMDGPU
239-
###
236+
# GPU logic
237+
# --------------
238+
# placed here to avoid code duplication since much of the logic is replicable across CUDA and AMDGPU
240239
function MatrixAlgebraKit.qr_full!(
241240
A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}
242241
)
@@ -325,3 +324,85 @@ function _gpu_qr_null!(
325324
N = _gpu_unmqr!('L', 'N', A, τ, N)
326325
return N
327326
end
327+
328+
# Native logic
329+
# --------------
330+
# Full QR decomposition through the native Householder implementation.
# `QR` is the `(Q, R)` output pair; both are overwritten and returned. `A` is used as
# workspace, hence `Q` may not be the very same array as `A`.
function qr_full!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
    check_input(qr_full!, A, QR, alg)
    Q, R = QR
    if A === Q
        throw(ArgumentError("inplace Q not supported with native QR implementation"))
    end
    _native_qr!(A, Q, R; alg.kwargs...)
    return (Q, R)
end
338+
# Compact QR decomposition through the native Householder implementation.
# `QR` is the `(Q, R)` output pair; both are overwritten and returned. `A` is used as
# workspace, hence `Q` may not be the very same array as `A`.
function qr_compact!(A::AbstractMatrix, QR, alg::Native_HouseholderQR)
    check_input(qr_compact!, A, QR, alg)
    Q, R = QR
    if A === Q
        throw(ArgumentError("inplace Q not supported with native QR implementation"))
    end
    _native_qr!(A, Q, R; alg.kwargs...)
    return (Q, R)
end
346+
# QR-based null space through the native Householder implementation.
# `N` is overwritten with the result and returned; `A` is used as workspace.
function qr_null!(A::AbstractMatrix, N, alg::Native_HouseholderQR)
    check_input(qr_null!, A, N, alg)
    kwargs = alg.kwargs
    _native_qr_null!(A, N; kwargs...)
    return N
end
351+
352+
# Householder QR kernel: overwrites `Q` and `R` such that (original) A = Q * R.
# `A` is destroyed: on exit its columns hold the reflector vectors.
# The same routine serves the full and the compact decomposition; the sizes of `Q` and
# `R` decide which one is produced.
function _native_qr!(
        A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
        positive::Bool = true # always true regardless of setting
    )
    nrows, ncols = size(A)
    nreflectors = min(nrows, ncols)
    Rrows = size(R, 1)
    @inbounds for col in 1:nreflectors
        # entries above the diagonal are already final
        for row in 1:(col - 1)
            R[row, col] = A[row, col]
        end
        # reflector from the remainder of the current column
        tail = col:nrows
        β, v, ν = _householder!(view(A, tail, col), 1)
        R[col, col] = ν
        for row in (col + 1):Rrows
            R[row, col] = 0
        end
        # update the columns to the right
        lmul!(Householder(β, v, tail), A; cols = (col + 1):ncols)
        # A[col, col] == 1; store β instead
        A[col, col] = β
    end
    # copy remaining columns if nrows < ncols
    @inbounds for col in (nreflectors + 1):ncols, row in 1:Rrows
        R[row, col] = A[row, col]
    end
    # build Q by applying the adjoint reflectors, last one first
    one!(Q)
    @inbounds for col in nreflectors:-1:1
        β = A[col, col]
        A[col, col] = 1 # restore the implicit unit entry of the reflector vector
        lmul!(Householder(conj(β), view(A, col:nrows, col), col:nrows), Q)
    end
    return Q, R
end
387+
388+
# Householder QR null-space kernel: runs the same reflector sweep as `_native_qr!` without
# storing R, then overwrites `N` by applying the adjoint reflectors to a matrix that is
# zero except for a `one!` block in its trailing rows.
# `A` is destroyed; `positive` is accepted but not used.
function _native_qr_null!(A::AbstractMatrix, N::AbstractMatrix; positive::Bool = true)
    nrows, ncols = size(A)
    nreflectors = min(nrows, ncols)
    @inbounds for col in 1:nreflectors
        tail = col:nrows
        β, v, _ = _householder!(view(A, tail, col), 1)
        lmul!(Householder(β, v, tail), A; cols = (col + 1):ncols)
        # A[col, col] == 1; store β instead
        A[col, col] = β
    end
    # build N
    fill!(N, zero(eltype(N)))
    one!(view(N, (nreflectors + 1):nrows, 1:(nrows - nreflectors)))
    @inbounds for col in nreflectors:-1:1
        β = A[col, col]
        A[col, col] = 1 # restore the implicit unit entry of the reflector vector
        lmul!(Householder(conj(β), view(A, col:nrows, col), col:nrows), N)
    end
    return N
end

src/interface/decompositions.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,24 @@
99

1010
# QR, LQ, QL, RQ Decomposition
1111
# ----------------------------
"""
    Native_HouseholderQR()

Algorithm type selecting the native Julia implementation of the QR decomposition of a
matrix, based on Householder reflectors. By construction, the diagonal elements of `R`
are non-negative; the implementation accepts a `positive` keyword, but its value has no
effect on the result.
"""
@algdef Native_HouseholderQR
20+
"""
    Native_HouseholderLQ()

Algorithm type selecting the native Julia implementation of the LQ decomposition of a
matrix, based on Householder reflectors. By construction, the diagonal elements of `L`
are non-negative; the implementation accepts a `positive` keyword, but its value has no
effect on the result.
"""
@algdef Native_HouseholderLQ
29+
1230
"""
1331
LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false)
1432

src/interface/lq.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
7272
# Last-resort method: no default LQ algorithm is known for type `T`
default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,)))
75+
# Matrix types without a more specific method use the native Householder implementation
default_lq_algorithm(::Type{<:AbstractMatrix}; kwargs...) = Native_HouseholderLQ(; kwargs...)
7578
# Matrix types matching `YALAPACK.MaybeBlasMat` use the LAPACK Householder implementation
function default_lq_algorithm(::Type{<:YALAPACK.MaybeBlasMat}; kwargs...)
    return LAPACK_HouseholderLQ(; kwargs...)
end

src/interface/qr.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
7272
# Last-resort method: no default QR algorithm is known for type `T`
default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,)))
75+
# Matrix types without a more specific method use the native Householder implementation
default_qr_algorithm(::Type{<:AbstractMatrix}; kwargs...) = Native_HouseholderQR(; kwargs...)
7578
# Matrix types matching `YALAPACK.MaybeBlasMat` use the LAPACK Householder implementation
function default_qr_algorithm(::Type{<:YALAPACK.MaybeBlasMat}; kwargs...)
    return LAPACK_HouseholderQR(; kwargs...)
end

0 commit comments

Comments
 (0)