Skip to content

Commit 18ae70e

Browse files
committed
GPU-friendly SVD + correct gaugefix
1 parent f95d1b3 commit 18ae70e

1 file changed

Lines changed: 38 additions & 17 deletions

File tree

src/implementations/svd.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,18 +248,25 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
248248
check_input(svd_full!, A, USVᴴ, alg)
249249
Ad = diagview(A)
250250
U, S, Vᴴ = USVᴴ
251-
Sd = diagview(S)
252-
Sd .= abs.(Ad)
253-
p = sortperm(Sd; rev=true)
254-
permute!(Sd, p)
255-
T = eltype(Vᴴ)
251+
p = sortperm(Ad; by=abs, rev=true)
256252
zero!(U)
257253
zero!(Vᴴ)
258-
@inbounds for (i, pi) in enumerate(p)
259-
s = Ad[pi]
260-
U[pi, i] = sign_safe(s)
261-
Vᴴ[i, pi] = one(T)
254+
n = size(A, 1)
255+
256+
pV = (1:n) .+ (p .- 1) .* n
257+
Vᴴ[pV] .= sign_safe.(view(Ad, p))
258+
259+
Sd = diagview(S)
260+
if Ad === Sd
261+
@. Sd = abs(Ad)
262+
permute!(Sd, p)
263+
else
264+
Sd .= abs.(view(Ad, p))
262265
end
266+
267+
p .+= (0:(n - 1)) .* n
268+
U[p] .= Ref(one(eltype(U)))
269+
263270
return U, S, Vᴴ
264271
end
265272
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
@@ -284,12 +291,13 @@ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
284291
CUSOLVER_Randomized}
285292
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
286293
ROCSOLVER_Jacobi}
287-
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}
294+
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm,ROCSOLVER_SVDAlgorithm}
288295

289296
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
290297
const GPU_Randomized = Union{CUSOLVER_Randomized}
291298

292-
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)
299+
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
300+
alg::CUSOLVER_Randomized)
293301
m, n = size(A)
294302
minmn = min(m, n)
295303
U, S, Vᴴ = USVᴴ
@@ -303,7 +311,8 @@ function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOL
303311
return nothing
304312
end
305313

306-
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized})
314+
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix,
315+
alg::TruncatedAlgorithm{<:CUSOLVER_Randomized})
307316
m, n = size(A)
308317
minmn = min(m, n)
309318
U = similar(A, (m, m))
@@ -312,10 +321,22 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
312321
return (U, S, Vᴴ)
313322
end
314323

315-
_gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ)))
316-
_gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
317-
_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
318-
_gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
324+
function _gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
325+
Vᴴ::AbstractMatrix)
326+
throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ)))
327+
end
328+
function _gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
329+
Vᴴ::AbstractMatrix; kwargs...)
330+
throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
331+
end
332+
function _gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
333+
Vᴴ::AbstractMatrix; kwargs...)
334+
throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
335+
end
336+
function _gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
337+
Vᴴ::AbstractMatrix; kwargs...)
338+
throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
339+
end
319340
# GPU SVD implementation
320341
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
321342
check_input(svd_full!, A, USVᴴ, alg)
@@ -369,7 +390,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
369390
throw(ArgumentError("Unsupported SVD algorithm"))
370391
end
371392
# TODO: make this controllable using a `gaugefix` keyword argument
372-
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
393+
gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
373394
return USVᴴ
374395
end
375396
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))

0 commit comments

Comments
 (0)