Skip to content

Commit 410221d

Browse files
committed
Format
1 parent 66dc4c1 commit 410221d

2 files changed

Lines changed: 28 additions & 28 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ for (f, f_full, pb, adj) in (
7676
arg = $f(copy(A), arg, Mooncake.primal(alg_dalg))
7777
function $adj(::Mooncake.NoRData)
7878
$pb(dA, A, arg, darg; kwargs...)
79-
A .= Ac
80-
arg .= argc
79+
A .= Ac
80+
arg .= argc
8181
MatrixAlgebraKit.zero!(darg)
8282
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
8383
end
@@ -89,17 +89,17 @@ end
8989
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
9090
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
9191
# compute primal
92-
D_ = Mooncake.primal(D_dD)
93-
dD_ = Mooncake.tangent(D_dD)
94-
A_ = Mooncake.primal(A_dA)
95-
dA_ = Mooncake.tangent(A_dA)
92+
D_ = Mooncake.primal(D_dD)
93+
dD_ = Mooncake.tangent(D_dD)
94+
A_ = Mooncake.primal(A_dA)
95+
dA_ = Mooncake.tangent(A_dA)
9696
A, dA = arrayify(A_, dA_)
9797
D, dD = arrayify(D_, dD_)
98-
Ac = copy(A)
99-
Dc = copy(D)
98+
Ac = copy(A)
99+
Dc = copy(D)
100100
# update primal
101-
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
102-
V = DV[2]
101+
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
102+
V = DV[2]
103103
eig_vals!(A, D, Mooncake.primal(alg_dalg))
104104
function deig_vals_adjoint(::Mooncake.NoRData)
105105
A .= Ac
@@ -114,19 +114,19 @@ end
114114
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
115115
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
116116
# compute primal
117-
D_ = Mooncake.primal(D_dD)
118-
dD_ = Mooncake.tangent(D_dD)
119-
A_ = Mooncake.primal(A_dA)
120-
dA_ = Mooncake.tangent(A_dA)
117+
D_ = Mooncake.primal(D_dD)
118+
dD_ = Mooncake.tangent(D_dD)
119+
A_ = Mooncake.primal(A_dA)
120+
dA_ = Mooncake.tangent(A_dA)
121121
A, dA = arrayify(A_, dA_)
122-
Ac = copy(A)
122+
Ac = copy(A)
123123
D, dD = arrayify(D_, dD_)
124-
Dc = copy(D)
125-
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
126-
D .= diagview(DV[1])
127-
V = DV[2]
124+
Dc = copy(D)
125+
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
126+
D .= diagview(DV[1])
127+
V = DV[2]
128128
function deigh_vals_adjoint(::Mooncake.NoRData)
129-
A .= Ac
129+
A .= Ac
130130
eigh_pullback!(dA, A, (D, V), (dD, nothing); kwargs...)
131131
D .= Dc
132132
MatrixAlgebraKit.zero!(dD)
@@ -154,7 +154,7 @@ for f in (svd_full!, svd_compact!)
154154
minmn = min(size(A)...)
155155
function dsvd_adjoint(::Mooncake.NoRData)
156156
A .= Ac
157-
if ($f == svd_compact!)
157+
if ($f == svd_compact!)
158158
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
159159
else # full
160160
vU = view(U, :, 1:minmn)
@@ -165,8 +165,8 @@ for f in (svd_full!, svd_compact!)
165165
vdVᴴ = view(dVᴴ, 1:minmn, :)
166166
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
167167
end
168-
U .= Uc
169-
S .= Sc
168+
U .= Uc
169+
S .= Sc
170170
Vᴴ .= Vᴴc
171171
MatrixAlgebraKit.zero!(dU)
172172
MatrixAlgebraKit.zero!(dS)
@@ -181,17 +181,17 @@ end
181181
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
182182
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
183183
# compute primal
184-
S_ = Mooncake.primal(S_dS)
184+
S_ = Mooncake.primal(S_dS)
185185
dS_ = Mooncake.tangent(S_dS)
186-
A_ = Mooncake.primal(A_dA)
186+
A_ = Mooncake.primal(A_dA)
187187
dA_ = Mooncake.tangent(A_dA)
188188
A, dA = arrayify(A_, dA_)
189189
S, dS = arrayify(S_, dS_)
190-
Ac = copy(A)
190+
Ac = copy(A)
191191
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
192192
S .= diagview(nS)
193193
function dsvd_vals_adjoint(::Mooncake.NoRData)
194-
A .= Ac
194+
A .= Ac
195195
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
196196
MatrixAlgebraKit.zero!(dS)
197197
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()

src/pullbacks/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function eig_pullback!(
4747
Δgauge = norm(view(VᴴΔV, mask), Inf)
4848
Δgauge < gauge_atol ||
4949
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
50-
50+
5151
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, tol))
5252

5353
if !iszerotangent(ΔDmat)

0 commit comments

Comments
 (0)