Skip to content

Commit a836012

Browse files
committed
GPU-friendly SVD + correct gaugefix
1 parent f95d1b3 commit a836012

1 file changed

Lines changed: 2 additions & 5 deletions

File tree

src/implementations/svd.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,8 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)
255255
T = eltype(Vᴴ)
256256
zero!(U)
257257
zero!(Vᴴ)
258-
@inbounds for (i, pi) in enumerate(p)
259-
s = Ad[pi]
260-
U[pi, i] = sign_safe(s)
261-
Vᴴ[i, pi] = one(T)
262-
end
258+
U[CartesianIndex.(enumerate(p))] .= Ref(one(T))
259+
Vᴴ[CartesianIndex.(reverse.(enumerate(p)))] .= sign_safe.(view(Ad, p))
263260
return U, S, Vᴴ
264261
end
265262
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::DiagonalAlgorithm)

0 commit comments

Comments
 (0)