-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathad_utils.jl
More file actions
31 lines (30 loc) · 1.1 KB
/
ad_utils.jl
File metadata and controls
31 lines (30 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
    remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S))

Modify the cotangent `ΔU` in place and return the tuple `(ΔU, ΔVᴴ)`. Going by the
function name, the goal is to strip the component tied to the gauge freedom of the
factorization `U * S * Vᴴ`. `ΔVᴴ` is read but not modified.

Steps performed:
1. Form `U' * ΔU + Vᴴ * ΔVᴴ'` and pass it through `project_antihermitian!`.
   That helper is defined elsewhere and presumably keeps only the anti-Hermitian
   part; verify against its definition.
2. Zero every entry `(i, j)` whose gap `abs(σ[j] - σ[i])`, with `σ = diagview(S)`,
   is at least `degeneracy_atol`. Only entries coupling (near-)equal diagonal values
   of `S` survive.
3. Subtract `U * Ω` from `ΔU`, where `Ω` is the masked matrix from step 2.

# Keywords
- `degeneracy_atol`: absolute gap below which two diagonal entries of `S` count as
  degenerate. Defaults to `MatrixAlgebraKit.default_pullback_gaugetol(S)`.
"""
function remove_svdgauge_dependence!(
        ΔU, ΔVᴴ, U, S, Vᴴ;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
    )
    σ = diagview(S)
    # Accumulate U' * ΔU + Vᴴ * ΔVᴴ' into a single buffer (5-arg mul!: C = A*B*α + C*β).
    overlap = U' * ΔU
    mul!(overlap, Vᴴ, ΔVᴴ', true, true)
    Ω = project_antihermitian!(overlap)
    # Keep only entries that couple (near-)degenerate values of S.
    far_apart = abs.(transpose(σ) .- σ) .>= degeneracy_atol
    Ω[far_apart] .= 0
    # ΔU ← ΔU - U * Ω
    mul!(ΔU, U, Ω, -1, 1)
    return ΔU, ΔVᴴ
end
"""
    remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D))

Modify the cotangent `ΔV` in place and return it. Going by the function name, the goal
is to remove the component tied to the gauge freedom of an eigendecomposition with
eigenvalues `D` and eigenvectors `V`.

The overlap `V' * ΔV` is formed. Every entry `(i, j)` whose gap `abs(d[j] - d[i])`,
with `d = diagview(D)`, is at least `degeneracy_atol` is then zeroed, so only entries
coupling (near-)degenerate eigenvalues remain. The masked overlap `G` is mapped back
through `V / (V' * V)` and subtracted from `ΔV`. Because
`V' * (ΔV - V * inv(V' * V) * G) == V' * ΔV - G`, the updated `V' * ΔV` is zero on
the kept entries.

# Keywords
- `degeneracy_atol`: absolute gap below which two eigenvalues count as degenerate.
  Defaults to `MatrixAlgebraKit.default_pullback_gaugetol(D)`. The previous default
  referenced an undefined variable `S`.
"""
function remove_eiggauge_dependence!(
        ΔV, D, V;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
    )
    d = diagview(D)
    gaugepart = V' * ΔV
    # Keep only entries that couple (near-)degenerate eigenvalues.
    gaugepart[abs.(transpose(d) .- d) .>= degeneracy_atol] .= 0
    # `V / (V' * V)` is V * inv(V' * V). Dividing out the Gram matrix makes the
    # subtraction exact even when V's columns are not orthonormal.
    mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
    return ΔV
end
"""
    remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D))

Modify the cotangent `ΔV` in place and return it. Going by the function name, the goal
is to remove the component tied to the gauge freedom of a Hermitian eigendecomposition
with eigenvalues `D` and eigenvectors `V`.

Steps performed:
1. Form the overlap `V' * ΔV` and pass it through `project_antihermitian!`.
   That helper is defined elsewhere and presumably keeps only the anti-Hermitian
   part; verify against its definition.
2. Zero every entry `(i, j)` whose gap `abs(d[j] - d[i])`, with `d = diagview(D)`,
   is at least `degeneracy_atol`.
3. Subtract `V * G` from `ΔV`, where `G` is the masked matrix from step 2.

No Gram-matrix correction is applied, so this appears to assume `V` has orthonormal
columns (NOTE(review): confirm with callers).

# Keywords
- `degeneracy_atol`: absolute gap below which two eigenvalues count as degenerate.
  Defaults to `MatrixAlgebraKit.default_pullback_gaugetol(D)`. The previous default
  referenced an undefined variable `S`.
"""
function remove_eighgauge_dependence!(
        ΔV, D, V;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
    )
    d = diagview(D)
    gaugepart = project_antihermitian!(V' * ΔV)
    # Keep only entries that couple (near-)degenerate eigenvalues.
    gaugepart[abs.(transpose(d) .- d) .>= degeneracy_atol] .= 0
    # ΔV ← ΔV - V * gaugepart
    mul!(ΔV, V, gaugepart, -1, 1)
    return ΔV
end
"""
    precision(T::Type{<:Number})

Return `sqrt(eps(real(T)))`: the square root of the machine epsilon of the real type
underlying `T`. For example, both `Float64` and `ComplexF64` give `sqrt(eps(Float64))`.
"""
function precision(::Type{T}) where {T <: Number}
    R = real(T)
    return sqrt(eps(R))
end