@@ -76,8 +76,8 @@ for (f, f_full, pb, adj) in (
7676 arg = $ f (copy (A), arg, Mooncake. primal (alg_dalg))
7777 function $adj (:: Mooncake.NoRData )
7878 $ pb (dA, A, arg, darg; kwargs... )
79- A .= Ac
80- arg .= argc
79+ A .= Ac
80+ arg .= argc
8181 MatrixAlgebraKit. zero! (darg)
8282 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
8383 end
8989@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. eig_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
9090function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual ; kwargs... )
9191 # compute primal
92- D_ = Mooncake. primal (D_dD)
93- dD_ = Mooncake. tangent (D_dD)
94- A_ = Mooncake. primal (A_dA)
95- dA_ = Mooncake. tangent (A_dA)
92+ D_ = Mooncake. primal (D_dD)
93+ dD_ = Mooncake. tangent (D_dD)
94+ A_ = Mooncake. primal (A_dA)
95+ dA_ = Mooncake. tangent (A_dA)
9696 A, dA = arrayify (A_, dA_)
9797 D, dD = arrayify (D_, dD_)
98- Ac = copy (A)
99- Dc = copy (D)
98+ Ac = copy (A)
99+ Dc = copy (D)
100100 # update primal
101- DV = eig_full (A, Mooncake. primal (alg_dalg); kwargs... )
102- V = DV[2 ]
101+ DV = eig_full (A, Mooncake. primal (alg_dalg); kwargs... )
102+ V = DV[2 ]
103103 eig_vals! (A, D, Mooncake. primal (alg_dalg))
104104 function deig_vals_adjoint (:: Mooncake.NoRData )
105105 A .= Ac
@@ -114,19 +114,19 @@ end
114114@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. eigh_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
115115function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual ; kwargs... )
116116 # compute primal
117- D_ = Mooncake. primal (D_dD)
118- dD_ = Mooncake. tangent (D_dD)
119- A_ = Mooncake. primal (A_dA)
120- dA_ = Mooncake. tangent (A_dA)
117+ D_ = Mooncake. primal (D_dD)
118+ dD_ = Mooncake. tangent (D_dD)
119+ A_ = Mooncake. primal (A_dA)
120+ dA_ = Mooncake. tangent (A_dA)
121121 A, dA = arrayify (A_, dA_)
122- Ac = copy (A)
122+ Ac = copy (A)
123123 D, dD = arrayify (D_, dD_)
124- Dc = copy (D)
125- DV = eigh_full (A, Mooncake. primal (alg_dalg); kwargs... )
126- D .= diagview (DV[1 ])
127- V = DV[2 ]
124+ Dc = copy (D)
125+ DV = eigh_full (A, Mooncake. primal (alg_dalg); kwargs... )
126+ D .= diagview (DV[1 ])
127+ V = DV[2 ]
128128 function deigh_vals_adjoint (:: Mooncake.NoRData )
129- A .= Ac
129+ A .= Ac
130130 eigh_pullback! (dA, A, (D, V), (dD, nothing ); kwargs... )
131131 D .= Dc
132132 MatrixAlgebraKit. zero! (dD)
@@ -154,7 +154,7 @@ for f in (svd_full!, svd_compact!)
154154 minmn = min (size (A)... )
155155 function dsvd_adjoint (:: Mooncake.NoRData )
156156 A .= Ac
157- if ($ f == svd_compact!)
157+ if ($ f == svd_compact!)
158158 svd_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
159159 else # full
160160 vU = view (U, :, 1 : minmn)
@@ -165,8 +165,8 @@ for f in (svd_full!, svd_compact!)
165165 vdVᴴ = view (dVᴴ, 1 : minmn, :)
166166 svd_pullback! (dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
167167 end
168- U .= Uc
169- S .= Sc
168+ U .= Uc
169+ S .= Sc
170170 Vᴴ .= Vᴴc
171171 MatrixAlgebraKit. zero! (dU)
172172 MatrixAlgebraKit. zero! (dS)
@@ -181,17 +181,17 @@ end
181181@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
182182function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual ; kwargs... )
183183 # compute primal
184- S_ = Mooncake. primal (S_dS)
184+ S_ = Mooncake. primal (S_dS)
185185 dS_ = Mooncake. tangent (S_dS)
186- A_ = Mooncake. primal (A_dA)
186+ A_ = Mooncake. primal (A_dA)
187187 dA_ = Mooncake. tangent (A_dA)
188188 A, dA = arrayify (A_, dA_)
189189 S, dS = arrayify (S_, dS_)
190- Ac = copy (A)
190+ Ac = copy (A)
191191 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg); kwargs... )
192192 S .= diagview (nS)
193193 function dsvd_vals_adjoint (:: Mooncake.NoRData )
194- A .= Ac
194+ A .= Ac
195195 svd_pullback! (dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
196196 MatrixAlgebraKit. zero! (dS)
197197 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
0 commit comments