Skip to content

Commit e2e068b

Browse files
committed
add SVD algorithms
change default algorithms
1 parent 2e4e25d commit e2e068b

3 files changed

Lines changed: 166 additions & 120 deletions

File tree

src/implementations/svd.jl

Lines changed: 109 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -105,133 +105,106 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm
105105
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
106106
end
107107

108-
# Implementation
109-
# --------------
110-
function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
111-
check_input(svd_full!, A, USVᴴ, alg)
112-
U, S, Vᴴ = USVᴴ
113-
fill!(S, zero(eltype(S)))
114-
m, n = size(A)
115-
minmn = min(m, n)
116-
if minmn == 0
117-
one!(U)
118-
zero!(S)
119-
one!(Vᴴ)
120-
return USVᴴ
121-
end
122-
123-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
124-
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
125-
126-
if alg isa LAPACK_QRIteration
127-
isempty(alg_kwargs) ||
128-
throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration"))
129-
YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
130-
elseif alg isa LAPACK_DivideAndConquer
131-
isempty(alg_kwargs) ||
132-
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
133-
YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ)
134-
elseif alg isa LAPACK_SafeDivideAndConquer
135-
isempty(alg_kwargs) ||
136-
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
137-
YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
138-
elseif alg isa LAPACK_Bisection
139-
throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
140-
elseif alg isa LAPACK_Jacobi
141-
throw(ArgumentError("LAPACK_Jacobi is not supported for full SVD"))
142-
else
143-
throw(ArgumentError("Unsupported SVD algorithm"))
144-
end
145-
146-
for i in 2:minmn
147-
S[i, i] = S[i, 1]
148-
S[i, 1] = zero(eltype(S))
149-
end
150-
151-
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)
108+
# ==========================
109+
# IMPLEMENTATIONS
110+
# ==========================
152111

153-
return USVᴴ
112+
for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdp!, :gesvdx!, :gesvdr!, :gesdvd!)
113+
@eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!"))
154114
end
155115

156-
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
157-
check_input(svd_compact!, A, USVᴴ, alg)
158-
U, S, Vᴴ = USVᴴ
159-
m, n = size(A)
160-
minmn = min(m, n)
161-
if minmn == 0
162-
one!(U)
163-
zero!(S)
164-
one!(Vᴴ)
165-
return USVᴴ
166-
end
167-
168-
do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
169-
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
170-
171-
if alg isa LAPACK_QRIteration
172-
isempty(alg_kwargs) ||
173-
throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration"))
174-
YALAPACK.gesvd!(A, diagview(S), U, Vᴴ)
175-
elseif alg isa LAPACK_DivideAndConquer
176-
isempty(alg_kwargs) ||
177-
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
178-
YALAPACK.gesdd!(A, diagview(S), U, Vᴴ)
179-
elseif alg isa LAPACK_SafeDivideAndConquer
180-
isempty(alg_kwargs) ||
181-
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
182-
YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ)
183-
elseif alg isa LAPACK_Bisection
184-
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
185-
elseif alg isa LAPACK_Jacobi
186-
isempty(alg_kwargs) ||
187-
throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi"))
188-
YALAPACK.gesvj!(A, diagview(S), U, Vᴴ)
189-
else
190-
throw(ArgumentError("Unsupported SVD algorithm"))
191-
end
192-
193-
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)
194-
195-
return USVᴴ
116+
# LAPACK
117+
for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdx!, :gesdvd!)
118+
@eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...)
196119
end
197120

198-
function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
199-
check_input(svd_vals!, A, S, alg)
200-
m, n = size(A)
201-
minmn = min(m, n)
202-
if minmn == 0
203-
zero!(S)
204-
return S
121+
for (f, f_lapack!, Alg) in (
122+
(:safe_divide_and_conquer, :gesdvd!, :SafeDivideAndConquer),
123+
(:divide_and_conquer, :gesdd!, :DivideAndConquer),
124+
(:qr_iteration, :gesvd!, :QRIteration),
125+
(:bisection, :gesvdx!, :Bisection),
126+
(:jacobi, :gesvdj!, :Jacobi),
127+
)
128+
f_svd! = Symbol(f, :_svd!)
129+
f_svd_full! = Symbol(f, :_svd_full!)
130+
f_svd_vals! = Symbol(f, :_svd_vals!)
131+
132+
# MatrixAlgebraKit wrappers
133+
@eval begin
134+
function svd_compact!(A, USVᴴ, alg::$Alg)
135+
check_input(svd_compact!, A, USVᴴ, alg)
136+
return $f_svd!(A, USVᴴ...; alg.kwargs...)
137+
end
138+
function svd_full!(A, USVᴴ, alg::$Alg)
139+
check_input(svd_full!, A, USVᴴ, alg)
140+
return $f_svd_full!(A, USVᴴ...; alg.kwargs...)
141+
end
142+
function svd_vals!(A, S, alg::$Alg)
143+
check_input(svd_vals!, A, S, alg)
144+
return $f_svd_vals!(A, S; alg.kwargs...)
145+
end
205146
end
206-
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
207-
208-
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
209147

210-
if alg isa LAPACK_QRIteration
211-
isempty(alg_kwargs) ||
212-
throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration"))
213-
YALAPACK.gesvd!(A, S, U, Vᴴ)
214-
elseif alg isa LAPACK_DivideAndConquer
215-
isempty(alg_kwargs) ||
216-
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
217-
YALAPACK.gesdd!(A, S, U, Vᴴ)
218-
elseif alg isa LAPACK_SafeDivideAndConquer
219-
isempty(alg_kwargs) ||
220-
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
221-
YALAPACK.gesdvd!(A, S, U, Vᴴ)
222-
elseif alg isa LAPACK_Bisection
223-
YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
224-
elseif alg isa LAPACK_Jacobi
225-
isempty(alg_kwargs) ||
226-
throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi"))
227-
YALAPACK.gesvj!(A, S, U, Vᴴ)
228-
else
229-
throw(ArgumentError("Unsupported SVD algorithm"))
148+
# driver
149+
@eval begin
150+
@inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) =
151+
$f_svd!(driver, A, U, S, Vᴴ; kwargs...)
152+
@inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) =
153+
$f_svd_full!(driver, A, U, S, Vᴴ; kwargs...)
154+
@inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) =
155+
$f_svd_vals!(driver, A, S; kwargs...)
156+
@inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) =
157+
$f_svd!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...)
158+
@inline $f_svd_full!(::DefaultDriver, A, S; kwargs...) =
159+
$f_svd_full!($(Symbol(:default_, f, :_driver)), A, S; kwargs...)
160+
@inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) =
161+
$f_svd_vals!($(Symbol(:default_, f, :_driver)), A, S; kwargs...)
230162
end
231163

232-
return S
164+
# Implementation
165+
@eval begin
166+
function $f_svd!(
167+
driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix;
168+
fixgauge::Bool = true, kwargs...
169+
)
170+
supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f"))
171+
isempty(A) && return one!(U), zero!(S), one!(Vᴴ)
172+
$f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...)
173+
fixgauge && gaugefix!(svd_compact!, U, Vᴴ)
174+
return U, S, Vᴴ
175+
end
176+
function $f_svd_full!(
177+
driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix;
178+
fixgauge::Bool = true, kwargs...
179+
)
180+
supports_svd_full(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f"))
181+
isempty(A) && return one!(U), zero!(S), one!(Vᴴ)
182+
zero!(S)
183+
minmn = min(size(A)...)
184+
$f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...)
185+
diagview(S) .= view(S, 1:minmn, 1)
186+
view(S, 2:minmn, 1) .= zero(eltype(S))
187+
fixgauge && gaugefix!(svd_full!, U, Vᴴ)
188+
return U, S, Vᴴ
189+
end
190+
function $f_svd_vals!(
191+
driver::Driver, A::AbstractMatrix, S::AbstractVector;
192+
fixgauge::Bool = true, kwargs...
193+
)
194+
supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f"))
195+
isempty(A) && return zero!(S)
196+
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
197+
$f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...)
198+
return S
199+
end
200+
end
233201
end
234202

203+
supports_svd(::Driver, ::Symbol) = false
204+
supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi)
205+
supports_svd_full(::Driver, ::Symbol) = false
206+
supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration)
207+
235208
function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
236209
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
237210
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
@@ -485,3 +458,23 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
485458

486459
return S
487460
end
461+
462+
# Deprecations
463+
# ------------
464+
for algtype in (:DivideAndConquer, :QRIteration, :Jacobi, :Bisection)
465+
algtype = Symbol(:LAPACK_, algtype)
466+
@eval begin
467+
Base.@deprecate(
468+
svd_compact!(A, USVᴴ, alg::$algtype),
469+
svd_compact!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...))
470+
)
471+
Base.@deprecate(
472+
svd_full!(A, USVᴴ, alg::$algtype),
473+
svd_full!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...))
474+
)
475+
Base.@deprecate(
476+
svd_vals!(A, S, alg::$algtype),
477+
svd_vals!(A, S, $algtype(; driver = LAPACK(), alg.kwargs...))
478+
)
479+
end
480+
end

src/interface/decompositions.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,61 @@ default_householder_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} =
9898
default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
9999
default_householder_driver(A)
100100

101+
"""
102+
DivideAndConquer(; [driver], kwargs...)
103+
104+
Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix,
105+
or the singular value decomposition of a general matrix using the divide-and-conquer algorithm.
106+
107+
The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref).
108+
"""
109+
@algdef DivideAndConquer
110+
111+
"""
112+
SafeDivideAndConquer(; [driver], kwargs...)
113+
114+
Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix,
115+
or the singular value decomposition of a general matrix using the divide-and-conquer algorithm,
116+
with an additional fallback to the standard QR iteration algorithm in case the former fails to converge.
117+
118+
The optional `driver` symbol can be used to choose between different implementations of this algorithm.
119+
The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref).
120+
121+
!!! warning
122+
This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy.
123+
However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the
124+
QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit.
125+
126+
See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref).
127+
"""
128+
@algdef SafeDivideAndConquer
129+
130+
"""
131+
QRIteration(; [driver], kwargs...)
132+
133+
Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix,
134+
or the singular value decomposition of a general matrix via QR iteration.
135+
136+
The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref).
137+
"""
138+
@algdef QRIteration
139+
@algdef Bisection
140+
@algdef Jacobi
141+
142+
for f in (:divide_and_conquer, :qr_iteration, :bisection, :jacobi)
143+
default_f_driver = Symbol(:default_, f, :_driver)
144+
@eval begin
145+
$default_f_driver(A) = $default_f_driver(typeof(A))
146+
$default_f_driver(::Type) = Native()
147+
148+
$default_f_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK()
149+
150+
# note: StridedVector fallback is needed for handling reshaped parent types
151+
$default_f_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK()
152+
$default_f_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = $default_f_driver(A)
153+
$default_f_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = $default_f_driver(A)
154+
end
155+
end
101156

102157
# General Eigenvalue Decomposition
103158
# -------------------------------

src/interface/svd.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,9 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an
158158
# Algorithm selection
159159
# -------------------
160160
default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...)
161-
function default_svd_algorithm(T::Type; kwargs...)
162-
throw(MethodError(default_svd_algorithm, (T,)))
163-
end
161+
default_svd_algorithm(T::Type; kwargs...) = throw(MethodError(default_svd_algorithm, (T,)))
164162
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
165-
return LAPACK_SafeDivideAndConquer(; kwargs...)
163+
return SafeDivideAndConquer(; kwargs...)
166164
end
167165
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
168166
return DiagonalAlgorithm(; kwargs...)

0 commit comments

Comments
 (0)