Skip to content

Commit 1cd48c0

Browse files
committed
Rework algorithm selection logic
1 parent ec55db2 commit 1cd48c0

3 files changed

Lines changed: 34 additions & 21 deletions

File tree

src/algorithms.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,26 @@ function select_null_truncation(trunc)
292292
elseif trunc isa TruncationStrategy
293293
return trunc
294294
else
295-
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
295+
throw(ArgumentError("Unknown truncation strategy: $trunc"))
296+
end
297+
end
298+
299+
@doc """
300+
MatrixAlgebraKit.select_sketching(A, sketch)
301+
302+
Construct a [`SketchingStrategy`](@ref) for `A` from the given `NamedTuple` of keywords or input strategy `sketch`.
303+
""" select_sketching
304+
305+
@inline select_sketching(A, sketch) = select_sketching(typeof(A), sketch)
306+
@inline function select_sketching(::Type{A}, sketch) where {A}
307+
if isnothing(sketch)
308+
return nothing
309+
elseif sketch isa SketchingStrategy
310+
return sketch
311+
elseif sketch isa NamedTuple
312+
return select_algorithm(left_sketch!, A; sketch...)
313+
else
314+
throw(ArgumentError("Unknown sketching strategy: $sketch"))
296315
end
297316
end
298317

@@ -331,7 +350,7 @@ function truncate end
331350
Generic wrapper type for algorithms that consist of first using `alg`, followed by a
332351
truncation through `trunc`.
333352
"""
334-
struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
353+
struct TruncatedAlgorithm{A <: AbstractAlgorithm, T <: TruncationStrategy} <: AbstractAlgorithm
335354
alg::A
336355
trunc::T
337356
end
@@ -356,10 +375,9 @@ TruncatedAlgorithm(alg::SketchedAlgorithm) = TruncatedAlgorithm(alg.alg, alg.tru
356375
does_truncate(::TruncatedAlgorithm) = true
357376
does_truncate(::SketchedAlgorithm) = true
358377

359-
truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy) =
360-
TruncatedAlgorithm(alg, trunc)
361-
truncated_algorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy) =
362-
SketchedAlgorithm(sketch, alg, DefaultDriver())
378+
truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy, sketch = nothing) =
379+
isnothing(sketch) ? TruncatedAlgorithm(alg, trunc) : SketchedAlgorithm(; alg, sketch, trunc)
380+
363381

364382
# Utility macros
365383
# --------------

src/implementations/orthnull.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ function right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm})
116116
return Nᴴ
117117
end
118118

119-
# randomized algorithms don't currently work for smallest values:
120-
left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) =
119+
# randomized (sketched) algorithms don't currently work for smallest values:
120+
left_null!(A, N, alg::LeftNullViaSVD{<:SketchedAlgorithm}) =
121121
throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet"))
122-
right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) =
122+
right_null!(A, Nᴴ, alg::RightNullViaSVD{<:SketchedAlgorithm}) =
123123
throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet"))

src/interface/svd.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,23 +179,18 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
179179
end
180180

181181
for f in (:svd_trunc!, :svd_trunc_no_error!)
182-
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
183-
if alg isa TruncatedAlgorithm
182+
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, sketch = nothing, kwargs...)
183+
if alg isa TruncatedAlgorithm || alg isa SketchedAlgorithm
184184
isnothing(trunc) ||
185-
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
186-
return alg
187-
elseif alg isa SketchedAlgorithm
188-
isnothing(trunc) ||
189-
throw(ArgumentError("`trunc` can't be specified when `alg` is a `SketchedAlgorithm`"))
185+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm` or `SketchedAlgorithm`"))
186+
isnothing(sketch) ||
187+
throw(ArgumentError("`sketch` can't be specified when `alg` is a `TruncatedAlgorithm` or `SketchedAlgorithm`"))
190188
return alg
191189
else
192190
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
193191
trunc = select_truncation(trunc)
194-
if trunc isa TruncationStrategy
195-
return truncated_algorithm(alg_svd, trunc)
196-
else
197-
throw(ArgumentError("invalid truncation $trunc"))
198-
end
192+
sketch = select_sketching(A, sketch)
193+
return truncated_algorithm(alg_svd, trunc, sketch)
199194
end
200195
end
201196
end

0 commit comments

Comments
 (0)