Skip to content

Commit c454434

Browse files
committed
slight reimplementation
1 parent 6111f9c commit c454434

3 files changed

Lines changed: 80 additions & 62 deletions

File tree

src/implementations/lq.jl

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -94,53 +94,46 @@ end
9494
# -----------
9595
function lq_full!(A, LQ, alg::Householder)
9696
check_input(lq_full!, A, LQ, alg)
97-
return householder_lq!(alg.driver, A, LQ...; alg.kwargs...)
97+
return householder_lq!(A, LQ...; alg.kwargs...)
9898
end
9999
function lq_compact!(A, LQ, alg::Householder)
100100
check_input(lq_compact!, A, LQ, alg)
101-
return householder_lq!(alg.driver, A, LQ...; alg.kwargs...)
101+
return householder_lq!(A, LQ...; alg.kwargs...)
102102
end
103103
function lq_null!(A, Nᴴ, alg::Householder)
104104
check_input(lq_null!, A, Nᴴ, alg)
105-
return householder_lq_null!(alg.driver, A, Nᴴ; alg.kwargs...)
105+
return householder_lq_null!(A, Nᴴ; alg.kwargs...)
106106
end
107107

108-
householder_lq!(::DefaultDriver, A, L, Q; kwargs...) =
109-
householder_lq!(default_householder_driver(A), A, L, Q; kwargs...)
110-
householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) =
111-
householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...)
112-
113108
# dispatch helpers
114109
for f in (:gelqt!, :gemlqt!, :gelqf!, :unglq!, :unmlq!)
115110
@eval begin
116111
$f(::LAPACK, args...) = YALAPACK.$f(args...)
117112
end
118113
end
119114

120-
function householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...)
121-
qr_alg = driver === GLA() ? GLA_HouseholderQR(; kwargs...) : Householder(driver; kwargs...)
122-
return lq_via_qr!(A, L, Q, qr_alg)
123-
end
124-
function householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...)
125-
qr_alg = driver === GLA() ? GLA_HouseholderQR(; kwargs...) : Householder(driver; kwargs...)
126-
return lq_null_via_qr!(A, Nᴴ, qr_alg)
127-
end
128-
115+
@inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) =
116+
householder_lq!(driver, A, L, Q; kwargs...)
117+
householder_lq!(::DefaultDriver, A, L, Q; kwargs...) =
118+
householder_lq!(default_householder_driver(A), A, L, Q; kwargs...)
119+
householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
120+
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
129121
function householder_lq!(
130122
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
131123
positive = true, pivoted = false,
132124
blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
133125
)
126+
# error messages for disallowing driver - setting combinations
127+
pivoted && (blocksize > 1) &&
128+
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
129+
134130
m, n = size(A)
135131
minmn = min(m, n)
136132
computeL = length(L) > 0
137133
inplaceQ = Q === A
138134

139-
pivoted && (blocksize > 1) &&
140-
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
141135
(inplaceQ && (computeL || positive || blocksize > 1 || n < m)) &&
142136
throw(ArgumentError("inplace Q only supported if matrix is wide (`m <= n`), L is not required, and using the unblocked algorithm (`blocksize = 1`) with `positive = false`"))
143-
144137
if blocksize > 1
145138
mb = min(minmn, blocksize)
146139
if computeL # first use L as space for T
@@ -181,28 +174,17 @@ function householder_lq!(
181174
end
182175
return L, Q
183176
end
184-
function householder_lq_null!(
185-
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
186-
positive = true, pivoted = false, blocksize = YALAPACK.default_qr_blocksize(A)
187-
)
188-
m, n = size(A)
189-
minmn = min(m, n)
190-
zero!(Nᴴ)
191-
one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n))
192-
if blocksize > 1
193-
mb = min(minmn, blocksize)
194-
A, T = gelqt!(driver, A, similar(A, mb, minmn))
195-
Nᴴ = gemlqt!(driver, 'R', 'N', A, T, Nᴴ)
196-
else
197-
A, τ = gelqf!(driver, A)
198-
Nᴴ = unmlq!(driver, 'R', 'N', A, τ, Nᴴ)
199-
end
200-
return Nᴴ
201-
end
202177
function householder_lq!(
203178
::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
204-
positive::Bool = true # always true regardless of setting
179+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
205180
)
181+
# error messages for disallowing driver - setting combinations
182+
blocksize == 1 ||
183+
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
184+
pivoted &&
185+
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
186+
# positive = true regardless of setting
187+
206188
m, n = size(A)
207189
minmn = min(m, n)
208190
@inbounds for i in 1:minmn
@@ -234,7 +216,46 @@ function householder_lq!(
234216
end
235217
return L, Q
236218
end
237-
function householder_lq_null!(::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix; positive::Bool = true)
219+
220+
@inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) =
221+
householder_lq_null!(driver, A, Nᴴ; kwargs...)
222+
householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) =
223+
householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...)
224+
householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) =
225+
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
226+
function householder_lq_null!(
227+
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
228+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = pivoted ? 1 : YALAPACK.default_qr_blocksize(A)
229+
)
230+
# error messages for disallowing driver - setting combinations
231+
pivoted && (blocksize > 1) &&
232+
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
233+
234+
m, n = size(A)
235+
minmn = min(m, n)
236+
zero!(Nᴴ)
237+
one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n))
238+
239+
if blocksize > 1
240+
mb = min(minmn, blocksize)
241+
A, T = gelqt!(driver, A, similar(A, mb, minmn))
242+
Nᴴ = gemlqt!(driver, 'R', 'N', A, T, Nᴴ)
243+
else
244+
A, τ = gelqf!(driver, A)
245+
Nᴴ = unmlq!(driver, 'R', 'N', A, τ, Nᴴ)
246+
end
247+
return Nᴴ
248+
end
249+
function householder_lq_null!(
250+
::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
251+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
252+
)
253+
# error messages for disallowing driver - setting combinations
254+
blocksize == 1 ||
255+
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
256+
pivoted &&
257+
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
258+
238259
m, n = size(A)
239260
minmn = min(m, n)
240261
@inbounds for i in 1:minmn
@@ -280,11 +301,10 @@ end
280301
function lq_via_qr!(
281302
A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, qr_alg::AbstractAlgorithm
282303
)
283-
m, n = size(A)
284-
minmn = min(m, n)
285304
At = adjoint!(similar(A'), A)::AbstractMatrix
286305
Qt = (A === Q) ? At : similar(Q')
287306
Lt = similar(L')
307+
n = size(A, 2)
288308
if size(Q) == (n, n)
289309
Qt, Lt = qr_full!(At, (Qt, Lt), qr_alg)
290310
else
@@ -296,8 +316,6 @@ function lq_via_qr!(
296316
end
297317

298318
function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractAlgorithm)
299-
m, n = size(A)
300-
minmn = min(m, n)
301319
At = adjoint!(similar(A'), A)::AbstractMatrix
302320
Nt = similar(N')
303321
Nt = qr_null!(At, Nt, qr_alg)
@@ -351,15 +369,15 @@ for drivertype in (:LAPACK, :Native)
351369
@eval begin
352370
Base.@deprecate(
353371
lq_full!(A, LQ, alg::$algtype),
354-
lq_full!(A, LQ, Householder($drivertype(), alg.kwargs))
372+
lq_full!(A, LQ, Householder(; driver = $drivertype(), alg.kwargs...))
355373
)
356374
Base.@deprecate(
357375
lq_compact!(A, LQ, alg::$algtype),
358-
lq_compact!(A, LQ, Householder($drivertype(), alg.kwargs))
376+
lq_compact!(A, LQ, Householder(; driver = $drivertype(), alg.kwargs...))
359377
)
360378
Base.@deprecate(
361379
lq_null!(A, Nᴴ, alg::$algtype),
362-
lq_null!(A, Nᴴ, Householder($drivertype(), alg.kwargs))
380+
lq_null!(A, Nᴴ, Householder(; driver = $drivertype(), alg.kwargs...))
363381
)
364382
end
365383
end

src/implementations/qr.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,17 @@ end
9494
# -----------
9595
function qr_full!(A, QR, alg::Householder)
9696
check_input(qr_full!, A, QR, alg)
97-
return householder_qr!(alg.driver, A, QR...; alg.kwargs...)
97+
return householder_qr!(A, QR...; alg.kwargs...)
9898
end
9999
function qr_compact!(A, QR, alg::Householder)
100100
check_input(qr_compact!, A, QR, alg)
101-
return householder_qr!(alg.driver, A, QR...; alg.kwargs...)
101+
return householder_qr!(A, QR...; alg.kwargs...)
102102
end
103103
function qr_null!(A, N, alg::Householder)
104104
check_input(qr_null!, A, N, alg)
105-
return householder_qr_null!(alg.driver, A, N; alg.kwargs...)
105+
return householder_qr_null!(A, N; alg.kwargs...)
106106
end
107107

108-
householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
109-
householder_qr!(default_householder_driver(A), A, Q, R; kwargs...)
110-
householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
111-
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
112108

113109
# dispatch helpers
114110
for f in (:geqrt!, :gemqrt!, :geqp3!, :geqrf!, :ungqr!, :unmqr!)
@@ -118,6 +114,10 @@ for f in (:geqrt!, :gemqrt!, :geqp3!, :geqrf!, :ungqr!, :unmqr!)
118114
end
119115
end
120116

117+
@inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) =
118+
householder_qr!(driver, A, Q, R; kwargs...)
119+
householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
120+
householder_qr!(default_householder_driver(A), A, Q, R; kwargs...)
121121
function householder_qr!(
122122
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
123123
positive::Bool = true, pivoted::Bool = false,
@@ -243,6 +243,10 @@ function householder_qr!(
243243
return Q, R
244244
end
245245

246+
@inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) =
247+
householder_qr_null!(driver, A, N; kwargs...)
248+
householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
249+
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
246250
function householder_qr_null!(
247251
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
248252
positive::Bool = true, pivoted::Bool = false,
@@ -351,15 +355,15 @@ for drivertype in (:LAPACK, :CUSOLVER, :ROCSOLVER, :Native, :GLA)
351355
@eval begin
352356
Base.@deprecate(
353357
qr_full!(A, QR, alg::$algtype),
354-
qr_full!(A, QR, Householder($drivertype(), alg.kwargs))
358+
qr_full!(A, QR, Householder(; driver = $drivertype(), alg.kwargs...))
355359
)
356360
Base.@deprecate(
357361
qr_compact!(A, QR, alg::$algtype),
358-
qr_compact!(A, QR, Householder($drivertype(), alg.kwargs))
362+
qr_compact!(A, QR, Householder(; driver = $drivertype(), alg.kwargs...))
359363
)
360364
Base.@deprecate(
361365
qr_null!(A, N, alg::$algtype),
362-
qr_null!(A, N, Householder($drivertype(), alg.kwargs))
366+
qr_null!(A, N, Householder(; driver = $drivertype(), alg.kwargs...))
363367
)
364368
end
365369
end

src/interface/decompositions.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@ The optional `driver` symbol can be used to choose between different implementat
7878
7979
Depending on the driver, various other keywords may be (un)available to customize the implementation.
8080
"""
81-
struct Householder{D <: Driver, KW} <: AbstractAlgorithm
82-
driver::D
83-
kwargs::KW
84-
end
85-
Householder(driver::Driver = DefaultDriver(); kwargs...) = Householder(driver, kwargs)
81+
@algdef Householder
8682

8783
default_householder_driver(A) = Native()
8884
default_householder_driver(::YALAPACK.MaybeBlasMat) = LAPACK()

0 commit comments

Comments
 (0)