Skip to content

Commit d8af318

Browse files
committed
rework algorithms for type stability
1 parent 2d40186 commit d8af318

5 files changed

Lines changed: 43 additions & 24 deletions

File tree

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ end
5858

5959
function MatrixAlgebraKit.householder_qr!(
6060
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
61-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
61+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
6262
)
63-
blocksize == 1 ||
63+
blocksize <= 1 ||
6464
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
6565
pivoted &&
6666
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
@@ -102,9 +102,9 @@ end
102102

103103
function MatrixAlgebraKit.householder_qr_null!(
104104
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
105-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
105+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
106106
)
107-
blocksize == 1 ||
107+
blocksize <= 1 ||
108108
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
109109
pivoted &&
110110
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

src/algorithms.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,19 @@ See also [`@algdef`](@ref).
1919
"""
2020
struct Algorithm{name, K} <: AbstractAlgorithm
2121
kwargs::K
22+
23+
# Ensure keywords are always in canonical order
24+
function Algorithm{Name}(kwargs::NamedTuple) where {Name}
25+
kwargs_sorted = _sortkeys(kwargs)
26+
return new{Name, typeof(kwargs_sorted)}(kwargs_sorted)
27+
end
2228
end
29+
Algorithm{Name}(; kwargs...) where {Name} = Algorithm{Name}(NamedTuple(kwargs))
30+
31+
# Utility generated function to canonicalize keys in type-stable way
32+
@generated _sortkeys(nt::NamedTuple{K}) where {K} =
33+
:(NamedTuple{$(Tuple(sort!(collect(K))))}(nt))
34+
2335
name(alg::Algorithm) = name(typeof(alg))
2436
name(::Type{<:Algorithm{N}}) where {N} = N
2537

@@ -299,11 +311,6 @@ macro algdef(name)
299311
return esc(
300312
quote
301313
const $name{K} = Algorithm{$(QuoteNode(name)), K}
302-
function $name(; kwargs...)
303-
# TODO: is this necessary/useful?
304-
kw = NamedTuple(kwargs) # normalize type
305-
return $name{typeof(kw)}(kw)
306-
end
307314
function Base.show(io::IO, alg::$name)
308315
return ($_show_alg)(io, alg)
309316
end

src/implementations/lq.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,10 @@ householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
120120
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
121121
function householder_lq!(
122122
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
123-
positive = true, pivoted = false,
124-
blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
123+
positive = true, pivoted = false, blocksize::Int = 0
125124
)
125+
blocksize = blocksize > 0 ? blocksize : ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
126+
126127
# error messages for disallowing driver - setting combinations
127128
pivoted && (blocksize > 1) &&
128129
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
@@ -176,10 +177,10 @@ function householder_lq!(
176177
end
177178
function householder_lq!(
178179
driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
179-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
180+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
180181
)
181182
# error messages for disallowing driver - setting combinations
182-
blocksize == 1 ||
183+
blocksize <= 1 ||
183184
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
184185
pivoted &&
185186
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
@@ -225,8 +226,10 @@ householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...
225226
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
226227
function householder_lq_null!(
227228
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
228-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = pivoted ? 1 : YALAPACK.default_qr_blocksize(A)
229+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
229230
)
231+
blocksize = blocksize > 0 ? blocksize : (pivoted ? 1 : YALAPACK.default_qr_blocksize(A))
232+
230233
# error messages for disallowing driver - setting combinations
231234
pivoted && (blocksize > 1) &&
232235
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
@@ -248,10 +251,10 @@ function householder_lq_null!(
248251
end
249252
function householder_lq_null!(
250253
driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
251-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
254+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
252255
)
253256
# error messages for disallowing driver - setting combinations
254-
blocksize == 1 ||
257+
blocksize <= 1 ||
255258
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
256259
pivoted &&
257260
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))

src/implementations/qr.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,10 @@ householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
121121
function householder_qr!(
122122
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
123123
positive::Bool = true, pivoted::Bool = false,
124-
blocksize::Int = ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
124+
blocksize::Int = 0
125125
)
126+
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
127+
126128
# error messages for disallowing driver - setting combinations
127129
(blocksize == 1 || driver === LAPACK()) ||
128130
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
@@ -202,10 +204,10 @@ function householder_qr!(
202204
end
203205
function householder_qr!(
204206
driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
205-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
207+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
206208
)
207209
# error messages for disallowing driver - setting combinations
208-
blocksize == 1 ||
210+
blocksize <= 1 ||
209211
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
210212
pivoted &&
211213
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
@@ -249,9 +251,9 @@ householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
249251
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
250252
function householder_qr_null!(
251253
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
252-
positive::Bool = true, pivoted::Bool = false,
253-
blocksize::Int = ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
254+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
254255
)
256+
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
255257
# error messages for disallowing driver - setting combinations
256258
(blocksize == 1 || driver === LAPACK()) ||
257259
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
@@ -277,10 +279,10 @@ function householder_qr_null!(
277279
end
278280
function householder_qr_null!(
279281
driver::Native, A::AbstractMatrix, N::AbstractMatrix;
280-
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
282+
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
281283
)
282284
# error messages for disallowing driver - setting combinations
283-
blocksize == 1 ||
285+
blocksize <= 1 ||
284286
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
285287
pivoted &&
286288
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

src/interface/decompositions.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,19 @@ The optional `driver` symbol can be used to choose between different implementat
7474
7575
- `positive::Bool = true` : Fix the gauge of the resulting factors by making the diagonal elements of `L` or `R` non-negative.
7676
- `pivoted::Bool = false` : Use column- or row-pivoting for low-rank input matrices.
77-
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`.
77+
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. Use the default if `blocksize ≤ 0`.
7878
7979
Depending on the driver, various other keywords may be (un)available to customize the implementation.
8080
"""
8181
@algdef Householder
8282

83+
function Householder(;
84+
driver::Driver = DefaultDriver(), blocksize::Int = 0,
85+
pivoted::Bool = false, positive::Bool = true
86+
)
87+
return Householder((; driver, blocksize, pivoted, positive))
88+
end
89+
8390
default_householder_driver(A) = default_householder_driver(typeof(A))
8491
default_householder_driver(::Type) = Native()
8592

0 commit comments

Comments
 (0)