Skip to content

Commit c553fdd

Browse files
committed
add global gaugefix toggle
1 parent e22b2cf commit c553fdd

7 files changed

Lines changed: 32 additions & 14 deletions

File tree

docs/src/user_interface/decompositions.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,10 @@ D, V = eigh_full(A; gaugefix = false)
424424
```
425425

426426
The same keyword is available for `eig_full`, `eig_trunc`, `svd_full`, `svd_compact`, and `svd_trunc` functions.
427+
Additionally, the default value can also be controlled with a global toggle using [`MatrixAlgebraKit.default_gaugefix`](@ref).
427428

428429
```@docs; canonical=false
429430
MatrixAlgebraKit.gaugefix!
431+
MatrixAlgebraKit.default_gaugefix
430432
```
431433

432-

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MatrixAlgebraKitGenericLinearAlgebraExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!
4+
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, default_gaugefix
55
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
66
using LinearAlgebra: I, Diagonal, lmul!
77

@@ -17,7 +17,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIte
1717
F = svd!(A)
1818
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt
1919

20-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
20+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
2121
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)
2222

2323
return U, S, Vᴴ
@@ -29,7 +29,7 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIterat
2929
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
3030
diagview(S) .= F.S
3131

32-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
32+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
3333
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)
3434

3535
return U, S, Vᴴ

src/common/defaults.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,20 @@ default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
4141
Default tolerance for deciding to warn if the provided `A` is not hermitian.
4242
"""
4343
default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4)
44+
45+
46+
const DEFAULT_GAUGEFIX = Ref(true)
47+
48+
@doc """
49+
default_gaugefix() -> current_value
50+
default_gaugefix(new_value::Bool) -> previous_value
51+
52+
Global toggle for enabling or disabling the default behavior of gauge fixing the output of the eigen- and singular value decompositions.
53+
""" default_gaugefix
54+
55+
default_gaugefix() = DEFAULT_GAUGEFIX[]
56+
function default_gaugefix(new_value::Bool)
57+
previous_value = DEFAULT_GAUGEFIX[]
58+
DEFAULT_GAUGEFIX[] = new_value
59+
return previous_value
60+
end

src/implementations/eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
8282
check_input(eig_full!, A, DV, alg)
8383
D, V = DV
8484

85-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
85+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
8686
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
8787

8888
if alg isa LAPACK_Simple
@@ -145,7 +145,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)
145145
check_input(eig_full!, A, DV, alg)
146146
D, V = DV
147147

148-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
148+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
149149
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
150150

151151
if alg isa GPU_Simple

src/implementations/eigh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
9292
D, V = DV
9393
Dd = D.diag
9494

95-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
95+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
9696
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
9797

9898
if alg isa LAPACK_MultipleRelativelyRobustRepresentations
@@ -168,7 +168,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm)
168168
D, V = DV
169169
Dd = D.diag
170170

171-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
171+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
172172
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
173173

174174
if alg isa GPU_Jacobi

src/implementations/gen_eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig
5858
check_input(gen_eig_full!, A, B, WV, alg)
5959
W, V = WV
6060

61-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
61+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
6262
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
6363

6464
if alg isa LAPACK_Simple

src/implementations/svd.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
120120
return USVᴴ
121121
end
122122

123-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
123+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
124124
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
125125

126126
if alg isa LAPACK_QRIteration
@@ -153,7 +153,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
153153
check_input(svd_compact!, A, USVᴴ, alg)
154154
U, S, Vᴴ = USVᴴ
155155

156-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
156+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
157157
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
158158

159159
if alg isa LAPACK_QRIteration
@@ -336,7 +336,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
336336
return USVᴴ
337337
end
338338

339-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
339+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
340340
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
341341

342342
if alg isa GPU_QRIteration
@@ -362,7 +362,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
362362
U, S, Vᴴ = USVᴴ
363363
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
364364

365-
do_gauge_fix = get(alg.alg.kwargs, :gaugefix, true)::Bool
365+
do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool
366366
do_gauge_fix && gaugefix!(svd_trunc!, U, Vᴴ)
367367

368368
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
@@ -377,7 +377,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
377377
check_input(svd_compact!, A, USVᴴ, alg)
378378
U, S, Vᴴ = USVᴴ
379379

380-
do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool
380+
do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool
381381
lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)})
382382

383383
if alg isa GPU_QRIteration

0 commit comments

Comments
 (0)