# Full SVD of a diagonal matrix `A`, written in place into the preallocated outputs
# `USVᴴ = (U, S, Vᴴ)`; returns `(U, S, Vᴴ)`.
#
# With `p` the permutation that orders the diagonal `Ad` of `A` by decreasing absolute
# value, and after `zero!` has been called on `U` and `Vᴴ` (presumably clearing them):
#   * `U`  receives `one(eltype(U))` at every position `(p[i], i)`,
#   * the diagonal of `S` receives `abs(Ad[p[i]])` at position `i`,
#   * `Vᴴ` receives `sign_safe(Ad[p[i]])` at every position `(i, p[i])`.
# Entries are written through vectors of linear indices instead of an element loop.
#
# NOTE(review): the linear-index arithmetic assumes `U` and `Vᴴ` are `n × n` with
# column-major linear indexing, `n = size(A, 1)` — presumably enforced by
# `check_input`; confirm.
function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
    check_input(svd_full!, A, USVᴴ, alg)
    Ad = diagview(A)
    U, S, Vᴴ = USVᴴ
    p = sortperm(Ad; by=abs, rev=true)
    zero!(U)
    zero!(Vᴴ)
    n = size(A, 1)

    # Linear index of entry (i, p[i]) in an n × n column-major array.
    # Vᴴ is filled before S is touched: `Sd` below may be the same view as `Ad`.
    pV = (1:n) .+ (p .- 1) .* n
    Vᴴ[pV] .= sign_safe.(view(Ad, p))

    Sd = diagview(S)
    if Ad === Sd
        # Source and destination are the same view: take absolute values in place
        # first, then reorder, rather than reading `view(Ad, p)` while writing `Sd`.
        @. Sd = abs(Ad)
        permute!(Sd, p)
    else
        Sd .= abs.(view(Ad, p))
    end

    # `p` is no longer needed as a permutation; turn it in place into the linear
    # indices of the entries (p[i], i).
    p .+= (0:(n - 1)) .* n
    U[p] .= Ref(one(eltype(U)))

    return U, S, Vᴴ
end
265272function svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: DiagonalAlgorithm )
@@ -284,12 +291,13 @@ const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration,
284291 CUSOLVER_Randomized}
# Union of the ROCSOLVER SVD algorithm types.
const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration,
                                     ROCSOLVER_Jacobi}
# Union of the CUSOLVER and ROCSOLVER SVD algorithm types; used below as the `alg`
# dispatch type of the GPU `svd_full!` / `svd_compact!` methods.
const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm}

# Single-member unions: currently only CUSOLVER types are listed.
const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
const GPU_Randomized = Union{CUSOLVER_Randomized}
291298
292- function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ, alg:: CUSOLVER_Randomized )
299+ function check_input (:: typeof (svd_trunc!), A:: AbstractMatrix , USVᴴ,
300+ alg:: CUSOLVER_Randomized )
293301 m, n = size (A)
294302 minmn = min (m, n)
295303 U, S, Vᴴ = USVᴴ
@@ -303,7 +311,8 @@ function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOL
303311 return nothing
304312end
305313
306- function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized} )
314+ function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix ,
315+ alg:: TruncatedAlgorithm{<:CUSOLVER_Randomized} )
307316 m, n = size (A)
308317 minmn = min (m, n)
309318 U = similar (A, (m, m))
@@ -312,10 +321,22 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat
312321 return (U, S, Vᴴ)
313322end
314323
# Generic fallbacks for the GPU SVD kernels. Each one unconditionally throws a
# `MethodError` carrying the function itself and the positional arguments
# `(A, S, U, Vᴴ)`; keyword arguments, where accepted, are not part of the error.
# Presumably backend-specific methods with narrower argument types are defined
# elsewhere and take precedence over these — confirm against the GPU extensions.
function _gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
                     Vᴴ::AbstractMatrix)
    throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ)))
end
function _gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
                       Vᴴ::AbstractMatrix; kwargs...)
    throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ)))
end
function _gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
                       Vᴴ::AbstractMatrix; kwargs...)
    throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ)))
end
function _gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
                      Vᴴ::AbstractMatrix; kwargs...)
    throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
end
319340# GPU SVD implementation
320341function MatrixAlgebraKit. svd_full! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
321342 check_input (svd_full!, A, USVᴴ, alg)
@@ -369,7 +390,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
369390 throw (ArgumentError (" Unsupported SVD algorithm" ))
370391 end
371392 # TODO : make this controllable using a `gaugefix` keyword argument
372- gaugefix! (svd_compact!, U, S, Vᴴ, size (A)... )
393+ gaugefix! (svd_compact!, U, S, Vᴴ, size (A)... )
373394 return USVᴴ
374395end
# Fold the collection `x` with `_largest`, starting from `zero(eltype(x))`, and
# return the result; for an empty `x` the result is that initial zero.
# NOTE(review): going by the names, `_largest` presumably keeps whichever of its two
# arguments has the larger absolute value — confirm against its definition.
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))
0 commit comments