Skip to content

Commit 7d2befd

Browse files
committed
Fix svd_trunc
1 parent 153895a commit 7d2befd

4 files changed

Lines changed: 26 additions & 15 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ for (fname, elty, relty) in
164164

165165
AMDGPU.unsafe_free!(dev_residual)
166166
AMDGPU.unsafe_free!(dev_n_sweeps)
167-
return U, S, Vᴴ
167+
return (S, U, Vᴴ)
168168
end
169169
end
170170
end

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ for (bname, fname, elty, relty) in
242242
if jobz == 'V'
243243
adjoint!(Vᴴ, Ṽ)
244244
end
245-
return U, S, Vᴴ
245+
return S, U, Vᴴ
246246
end
247247
end
248248
end

src/common/gauge.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function gaugefix!(V::AbstractMatrix)
22
for j in axes(V, 2)
33
v = view(V, :, j)
4-
s = conj(sign(argmax(abs, v)))
4+
s = conj(sign(_argmaxabs(v)))
55
@inbounds v .*= s
66
end
77
return V
@@ -12,16 +12,16 @@ function gaugefix!(::Val{:full}, U, S, Vᴴ, m::Int, n::Int)
1212
if j <= min(m, n)
1313
u = view(U, :, j)
1414
v = view(Vᴴ, j, :)
15-
s = conj(sign(argmax(abs, u)))
15+
s = conj(sign(_argmaxabs(u)))
1616
u .*= s
1717
v .*= conj(s)
1818
elseif j <= m
1919
u = view(U, :, j)
20-
s = conj(sign(argmax(abs, u)))
20+
s = conj(sign(_argmaxabs(u)))
2121
u .*= s
2222
else
2323
v = view(Vᴴ, j, :)
24-
s = conj(sign(argmax(abs, v)))
24+
s = conj(sign(_argmaxabs(v)))
2525
v .*= s
2626
end
2727
end
@@ -32,7 +32,18 @@ function gaugefix!(::Val{:compact}, U, S, Vᴴ, m::Int, n::Int)
3232
for j in 1:size(U, 2)
3333
u = view(U, :, j)
3434
v = view(Vᴴ, j, :)
35-
s = conj(sign(argmax(abs, u)))
35+
s = conj(sign(_argmaxabs(u)))
36+
u .*= s
37+
v .*= conj(s)
38+
end
39+
return (U, S, Vᴴ)
40+
end
41+
42+
function gaugefix!(::Val{:trunc}, U, S, Vᴴ, m::Int, n::Int)
43+
for j in 1:min(m, n)
44+
u = view(U, :, j)
45+
v = view(Vᴴ, j, :)
46+
s = conj(sign(_argmaxabs(u)))
3647
u .*= s
3748
v .*= conj(s)
3849
end

src/implementations/svd.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,21 +175,21 @@ const GPU_SVDPolar = Union{CUSOLVER_SVDPolar}
175175
const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi}
176176
const GPU_Randomized = Union{CUSOLVER_Randomized}
177177

178-
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{CUSOLVER_Randomized})
178+
function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized)
179179
m, n = size(A)
180180
minmn = min(m, n)
181181
U, S, Vᴴ = USVᴴ
182182
@assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
183183
@check_size(U, (m, m))
184184
@check_scalar(U, A)
185-
@check_size(S, (minmn,minmn))
185+
@check_size(S, (minmn, minmn))
186186
@check_scalar(S, A, real)
187187
@check_size(Vᴴ, (n, n))
188188
@check_scalar(Vᴴ, A)
189189
return nothing
190190
end
191191

192-
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{CUSOLVER_Randomized})
192+
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized})
193193
m, n = size(A)
194194
minmn = min(m, n)
195195
U = similar(A, (m, m))
@@ -232,12 +232,12 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
232232
end
233233

234234
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
235-
check_input(svd_trunc!, A, USVᴴ, alg)
235+
check_input(svd_trunc!, A, USVᴴ, alg.alg)
236236
U, S, Vᴴ = USVᴴ
237-
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.kwargs...)
237+
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
238238
# TODO: make this controllable using a `gaugefix` keyword argument
239-
gaugefix!(Val(:compact), U, S, Vᴴ, m, n)
240-
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
239+
gaugefix!(Val(:trunc), U, S, Vᴴ, size(A)...)
240+
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
241241
end
242242

243243
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
@@ -255,7 +255,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
255255
throw(ArgumentError("Unsupported SVD algorithm"))
256256
end
257257
# TODO: make this controllable using a `gaugefix` keyword argument
258-
gaugefix!(Val(:compact), U, S, Vᴴ, m, n)
258+
gaugefix!(Val(:compact), U, S, Vᴴ, size(A)...)
259259
return USVᴴ
260260
end
261261
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))

0 commit comments

Comments
 (0)