Skip to content

Commit f93de2c

Browse files
authored
Improve svd gaugefixing performance (#238)
* Improve svd gaugefixing performance * Fixes * Reduce allocs * Restore the alloc * Use abs2 for complex comparisons
1 parent fd5a0e4 commit f93de2c

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

src/common/gauge.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ is real and positive.
1010

1111
# Helper functions
1212
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
13-
_largest(x, y) = abs(x) < abs(y) ? y : x
13+
_largest(x::Real, y::Real) = abs(x) < abs(y) ? y : x
14+
_largest(x::Complex, y::Complex) = abs2(x) < abs2(y) ? y : x
1415

1516
function gaugefix!(::typeof(qr_householder!), Q, R, Rd)
1617
ax = Base.OneTo(length(Rd))
@@ -67,12 +68,10 @@ end
6768

6869
function gaugefix!(::Union{typeof(svd_compact!), typeof(svd_trunc!)}, U, Vᴴ)
6970
@assert axes(U, 2) == axes(Vᴴ, 1)
70-
for j in axes(U, 2)
71-
u = view(U, :, j)
72-
v = view(Vᴴ, j, :)
73-
s = sign(_argmaxabs(u))
74-
u .*= conj(s)
75-
v .*= s
76-
end
71+
signs = reduce(_largest, U; dims = 1, init = zero(eltype(U)))
72+
@. signs = sign(signs)
73+
signs_t = transpose(signs)
74+
@. U = U * conj(signs)
75+
@. Vᴴ = signs_t * Vᴴ
7776
return (U, Vᴴ)
7877
end

0 commit comments

Comments
 (0)