Skip to content

Commit 562afe2

Browse files
author
Katharine Hyatt
committed
Move SVD gauge fix to its own function
1 parent 9ee9b16 commit 562afe2

3 files changed

Lines changed: 38 additions & 51 deletions

File tree

src/common/gauge.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,35 @@ function gaugefix!(V::AbstractMatrix)
66
end
77
return V
88
end
9+
10+
function gaugefix!(::Val{:full}, U, S, Vᴴ, m::Int, n::Int)
11+
for j in 1:max(m, n)
12+
if j <= min(m, n)
13+
u = view(U, :, j)
14+
v = view(Vᴴ, j, :)
15+
s = conj(sign(argmax(abs, u)))
16+
u .*= s
17+
v .*= conj(s)
18+
elseif j <= m
19+
u = view(U, :, j)
20+
s = conj(sign(argmax(abs, u)))
21+
u .*= s
22+
else
23+
v = view(Vᴴ, j, :)
24+
s = conj(sign(argmax(abs, v)))
25+
v .*= s
26+
end
27+
end
28+
return (U, S, Vᴴ)
29+
end
30+
31+
function gaugefix!(::Val{:compact}, U, S, Vᴴ, m::Int, n::Int)
32+
for j in 1:size(U, 2)
33+
u = view(U, :, j)
34+
v = view(Vᴴ, j, :)
35+
s = conj(sign(argmax(abs, u)))
36+
u .*= s
37+
v .*= conj(s)
38+
end
39+
return (U, S, Vᴴ)
40+
end

src/implementations/svd.jl

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
101101
S[i, 1] = zero(eltype(S))
102102
end
103103
# TODO: make this controllable using a `gaugefix` keyword argument
104-
for j in 1:max(m, n)
105-
if j <= minmn
106-
u = view(U, :, j)
107-
v = view(Vᴴ, j, :)
108-
s = conj(sign(argmax(abs, u)))
109-
u .*= s
110-
v .*= conj(s)
111-
elseif j <= m
112-
u = view(U, :, j)
113-
s = conj(sign(argmax(abs, u)))
114-
u .*= s
115-
else
116-
v = view(Vᴴ, j, :)
117-
s = conj(sign(argmax(abs, v)))
118-
v .*= s
119-
end
120-
end
104+
gaugefix!(Val(:full), U, S, Vᴴ, m, n)
121105
return USVᴴ
122106
end
123107

@@ -142,13 +126,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
142126
throw(ArgumentError("Unsupported SVD algorithm"))
143127
end
144128
# TODO: make this controllable using a `gaugefix` keyword argument
145-
for j in 1:size(U, 2)
146-
u = view(U, :, j)
147-
v = view(Vᴴ, j, :)
148-
s = conj(sign(argmax(abs, u)))
149-
u .*= s
150-
v .*= conj(s)
151-
end
129+
gaugefix!(Val(:compact), U, S, Vᴴ, m, n)
152130
return USVᴴ
153131
end
154132

@@ -249,23 +227,7 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor
249227
diagview(S) .= view(S, 1:minmn, 1)
250228
view(S, 2:minmn, 1) .= zero(eltype(S))
251229
# TODO: make this controllable using a `gaugefix` keyword argument
252-
for j in 1:max(m, n)
253-
if j <= minmn
254-
u = view(U, :, j)
255-
v = view(Vᴴ, j, :)
256-
s = conj(sign(_argmaxabs(u)))
257-
u .*= s
258-
v .*= conj(s)
259-
elseif j <= m
260-
u = view(U, :, j)
261-
s = conj(sign(_argmaxabs(u)))
262-
u .*= s
263-
else
264-
v = view(Vᴴ, j, :)
265-
s = conj(sign(_argmaxabs(v)))
266-
v .*= s
267-
end
268-
end
230+
gaugefix!(Val(:full), U, S, Vᴴ, m, n)
269231
return USVᴴ
270232
end
271233

@@ -286,14 +248,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl
286248
throw(ArgumentError("Unsupported SVD algorithm"))
287249
end
288250
# TODO: make this controllable using a `gaugefix` keyword argument
289-
minmn = min(size(A)...)
290-
for j in 1:minmn # make this more general to account for the larger U in CUSOVLER_Randomized
291-
u = view(U, :, j)
292-
v = view(Vᴴ, j, :)
293-
s = conj(sign(_argmaxabs(u)))
294-
u .*= s
295-
v .*= conj(s)
296-
end
251+
gaugefix!(Val(:compact), U, S, Vᴴ, m, n)
297252
return USVᴴ
298253
end
299254
_argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x)))

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ if !is_buildkite
3535
@safetestset "Image and Null Space" begin
3636
include("orthnull.jl")
3737
end
38-
#=@safetestset "ChainRules" begin
38+
@safetestset "ChainRules" begin
3939
include("chainrules.jl")
40-
end=#
40+
end
4141
@safetestset "MatrixAlgebraKit.jl" begin
4242
@safetestset "Code quality (Aqua.jl)" begin
4343
using MatrixAlgebraKit

0 commit comments

Comments
 (0)