Skip to content

Commit d111cd9

Browse files
committed
refactor + deprecate CUSOLVER randomized SVD
1 parent 1cd48c0 commit d111cd9

5 files changed

Lines changed: 104 additions & 27 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
7+
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, TruncationByOrder, AbstractAlgorithm
8+
using MatrixAlgebraKit: GaussianSketching, SketchingStrategy, SketchedAlgorithm
89
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
910
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
1011
import MatrixAlgebraKit: heevj!, heevd!, geev!
11-
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
12+
import MatrixAlgebraKit: _sylvester, svd_rank
1213
using CUDA, CUDA.cuBLAS
1314
using CUDA: i32
1415
using LinearAlgebra
@@ -17,6 +18,7 @@ using LinearAlgebra: BlasFloat
1718
include("yacusolver.jl")
1819

1920
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
21+
MatrixAlgebraKit.default_driver(::Type{<:SketchedAlgorithm}, ::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
2022

2123
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2224
return QRIteration(; kwargs...)
@@ -50,8 +52,31 @@ end
5052
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
5153
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)
5254

53-
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
54-
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
55+
# Sketched SVD via cuSOLVER's gesvdr kernel
56+
function gesvdr!(
57+
::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix;
58+
sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm()
59+
)
60+
isempty(A) && return U, S, Vᴴ
61+
m, n = size(A); minmn = min(m, n)
62+
k = trunc.howmany
63+
1 k minmn ||
64+
throw(ArgumentError("trunc.howmany=$k must satisfy 1 ≤ k ≤ min(size(A))=$minmn"))
65+
p = sketch.howmany - k
66+
p 0 || throw(
67+
ArgumentError(
68+
"sketch.howmany=$(sketch.howmany) must be ≥ trunc.howmany=$k"
69+
)
70+
)
71+
p = min(p, minmn - k - 1)
72+
niters = sketch.numiter - 1
73+
74+
Uk = view(U, :, 1:k)
75+
Vᴴk = view(Vᴴ, 1:k, :)
76+
Sk = view(diagview(S), 1:k)
77+
YACUSOLVER.gesvdr!(A, Sk, Uk, Vᴴk; k, p, niters)
78+
return Uk, S, Vᴴk
79+
end
5580

5681
geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
5782
YACUSOLVER.Xgeev!(A, Dd, V)

src/algorithms.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,17 +356,25 @@ struct TruncatedAlgorithm{A <: AbstractAlgorithm, T <: TruncationStrategy} <: Ab
356356
end
357357

358358
"""
359-
SketchedAlgorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy, trunc::TruncationStrategy)
359+
SketchedAlgorithm(;
360+
alg::AbstractAlgorithm, sketch::SketchingStrategy,
361+
trunc::TruncationStrategy, driver::Driver = DefaultDriver()
362+
)
360363
361364
Generic wrapper type for self-truncating algorithms that produce an approximate low-rank
362365
factorization by first applying a sketching operation specified by `sketch`, then computing
363366
a small dense decomposition of the projected matrix using `alg`. The `driver` selects the
364-
backend (e.g. `DefaultDriver()`, `CUSOLVER()`).
365-
"""
366-
struct SketchedAlgorithm{A <: AbstractAlgorithm, S <: SketchingStrategy, T <: TruncationStrategy} <: AbstractAlgorithm
367-
alg::A
367+
backend implementing the sketched factorization (e.g. `Native()` for the generic
368+
sketch-then-decompose pipeline, `CUSOLVER()` for the fused `gesvdr` kernel).
369+
"""
370+
@kwdef struct SketchedAlgorithm{
371+
A <: AbstractAlgorithm, S <: SketchingStrategy,
372+
T <: TruncationStrategy, D <: Driver,
373+
} <: AbstractAlgorithm
374+
alg::A = DefaultAlgorithm()
368375
sketch::S
369-
trunc::T
376+
trunc::T = notrunc()
377+
driver::D = DefaultDriver()
370378
end
371379

372380
# utility conversion constructor

src/implementations/svd.jl

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,35 +308,47 @@ check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::SketchedAlgori
308308

309309
function svd_trunc_no_error!(A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgorithm)
310310
check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg)
311+
return gesvdr!(alg.driver, A, S, U, Vᴴ; alg.sketch, alg.alg, alg.trunc)
312+
end
313+
314+
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm)
315+
U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg)
316+
Na = norm(A)
317+
Ns = norm(S)
318+
return U, S, Vᴴ, sqrt(max(zero(Na), (Na + Ns) * (Na - Ns)))
319+
end
320+
321+
# gesvdr! drivers
322+
# ---------------
323+
default_driver(::Type{<:SketchedAlgorithm}, ::Type{<:AbstractArray}) = Native()
324+
325+
gesvdr!(::DefaultDriver, A, S, U, Vᴴ; kwargs...) =
326+
gesvdr!(default_driver(SketchedAlgorithm, A), A, S, U, Vᴴ; kwargs...)
327+
328+
function gesvdr!(
329+
::Native, A::AbstractMatrix, S, U, Vᴴ;
330+
sketch::SketchingStrategy, alg::AbstractAlgorithm,
331+
trunc::TruncationStrategy
332+
)
311333
m, n = size(A)
312334
if m n
313-
Q, B = left_sketch!(A, (U, Vᴴ), alg.sketch)
335+
Q, B = left_sketch!(A, (U, Vᴴ), sketch)
314336
k = size(B, 1)
315337
U′ = similar(B, (k, k))
316338
Vᴴ′ = similar(B)
317-
USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg)
318-
(Uout′, Sout, Vᴴout), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc)
339+
Uout′, Sout, Vᴴout, _ = svd_trunc!(B, (U′, S, Vᴴ′), TruncatedAlgorithm(alg, trunc))
319340
Uout = Q * Uout′
320341
else
321-
B, Pᴴ = right_sketch!(A, (U, Vᴴ), alg.sketch)
342+
B, Pᴴ = right_sketch!(A, (U, Vᴴ), sketch)
322343
k = size(B, 2)
323344
U′ = similar(B)
324345
Vᴴ′ = similar(B, (k, k))
325-
USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg)
326-
(Uout, Sout, Vᴴout′), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc)
346+
Uout, Sout, Vᴴout′, _ = svd_trunc!(B, (U′, S, Vᴴ′), TruncatedAlgorithm(alg, trunc))
327347
Vᴴout = Vᴴout′ * Pᴴ
328348
end
329-
get(alg.alg.kwargs, :fixgauge, true) && gaugefix!(svd_trunc!, Uout, Vᴴout)
330349
return Uout, Sout, Vᴴout
331350
end
332351

333-
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm)
334-
U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg)
335-
Na = norm(A)
336-
Ns = norm(S)
337-
return U, S, Vᴴ, sqrt(max(zero(Na), (Na + Ns) * (Na - Ns)))
338-
end
339-
340352
# Deprecations
341353
# ------------
342354
for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, :Bisection)
@@ -380,6 +392,40 @@ for (algtype, newtype, drivertype) in (
380392
end
381393
end
382394

395+
# CUSOLVER_Randomized → SketchedAlgorithm with driver = CUSOLVER()
396+
function _cusolver_randomized_to_sketched(alg::CUSOLVER_Randomized)
397+
k = alg.kwargs.k
398+
p = alg.kwargs.p
399+
niters = alg.kwargs.niters
400+
return SketchedAlgorithm(
401+
QRIteration(),
402+
GaussianSketching(k + p; numiter = niters + 1),
403+
truncrank(k);
404+
driver = CUSOLVER(),
405+
)
406+
end
407+
408+
for f! in (:svd_trunc!, :svd_trunc_no_error!)
409+
@eval Base.@deprecate(
410+
$f!(A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized),
411+
$f!(A, USVᴴ, _cusolver_randomized_to_sketched(alg))
412+
)
413+
end
414+
415+
@inline function select_algorithm(::typeof(svd_trunc!), A, alg::CUSOLVER_Randomized; kwargs...)
416+
Base.depwarn(
417+
"`CUSOLVER_Randomized` is deprecated; use \
418+
`SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters+1), truncrank(k); driver=CUSOLVER())` instead.",
419+
:select_algorithm,
420+
)
421+
isempty(kwargs) ||
422+
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
423+
return _cusolver_randomized_to_sketched(alg)
424+
end
425+
@inline function select_algorithm(::typeof(svd_trunc_no_error!), A, alg::CUSOLVER_Randomized; kwargs...)
426+
return select_algorithm(svd_trunc!, A, alg; kwargs...)
427+
end
428+
383429
# GLA_QRIteration SVD deprecations (eigh methods remain in the GLA extension)
384430
Base.@deprecate(
385431
svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration),

src/interface/decompositions.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,6 @@ for more information.
408408
"""
409409
@algdef CUSOLVER_Randomized
410410

411-
does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true
412-
413411
"""
414412
CUSOLVER_Simple(; fixgauge = default_fixgauge())
415413

test/decompositions/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ if !is_buildkite
3333
m, n = 54, 63
3434
rtol = sqrt(TestSuite.precision(T)) # extra square root
3535
algs = [
36-
SketchedAlgorithm(DefaultAlgorithm(), GaussianSketching(m ÷ 2, numiter = 4), truncrank(m ÷ 4)),
36+
SketchedAlgorithm(; sketch = GaussianSketching(m ÷ 2, numiter = 4), trunc = truncrank(m ÷ 4)),
3737
]
3838
TestSuite.test_sketched_svd(T, (m, n), algs; rtol)
3939
TestSuite.test_sketched_svd(T, (n, m), algs)

0 commit comments

Comments
 (0)