@@ -111,9 +111,7 @@ for (f!, f, pb, adj) in (
111111 output = $ f (A, Mooncake. primal (alg_dalg))
112112 output_codual = Mooncake. CoDual (output, Mooncake. zero_tangent (output))
113113 function $adj (:: Mooncake.NoRData )
114- arg = Mooncake. primal (output_codual)
115- darg_ = Mooncake. tangent (output_codual)
116- arg, darg = Mooncake. arrayify (arg, darg_)
114+ arg, darg = Mooncake. arrayify (output_codual)
117115 $ pb (dA, A, arg, darg)
118116 MatrixAlgebraKit. zero! (darg)
119117 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
@@ -131,12 +129,8 @@ for (f!, f, f_full, pb, adj) in (
131129 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
132130 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
133131 # compute primal
134- D_ = Mooncake. primal (D_dD)
135- dD_ = Mooncake. tangent (D_dD)
136- A_ = Mooncake. primal (A_dA)
137- dA_ = Mooncake. tangent (A_dA)
138- A, dA = arrayify (A_, dA_)
139- D, dD = arrayify (D_, dD_)
132+ A, dA = arrayify (A_dA)
133+ D, dD = arrayify (D_dD)
140134 # update primal
141135 DV = $ f_full (A, Mooncake. primal (alg_dalg))
142136 copy! (D, diagview (DV[1 ]))
@@ -151,18 +145,14 @@ for (f!, f, f_full, pb, adj) in (
151145 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
152146 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
153147 # compute primal
154- A_ = Mooncake. primal (A_dA)
155- dA_ = Mooncake. tangent (A_dA)
156- A, dA = arrayify (A_, dA_)
148+ A, dA = arrayify (A_dA)
157149 # update primal
158150 DV = $ f_full (A, Mooncake. primal (alg_dalg))
159151 V = DV[2 ]
160- output = copy ( diagview (DV[1 ]) )
152+ output = diagview (DV[1 ])
161153 output_codual = Mooncake. CoDual (output, Mooncake. zero_tangent (output))
162154 function $adj (:: Mooncake.NoRData )
163- D = Mooncake. primal (output_codual)
164- dD_ = Mooncake. tangent (output_codual)
165- D, dD = Mooncake. arrayify (D, dD_)
155+ D_dD = Mooncake. arrayify (D_dD)
166156 $ pb (dA, A, (D, V), (dD, nothing ))
167157 MatrixAlgebraKit. zero! (dD)
168158 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
@@ -177,12 +167,10 @@ for (f, pb, adj) in (
177167 (eigh_trunc, eigh_trunc_pullback!, :deigh_trunc_adjoint ),
178168 )
179169 @eval begin
180- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. TruncatedAlgorithm }
170+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm }
181171 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
182172 # compute primal
183- A_ = Mooncake. primal (A_dA)
184- dA_ = Mooncake. tangent (A_dA)
185- A, dA = arrayify (A_, dA_)
173+ A, dA = arrayify (A_dA)
186174 alg = Mooncake. primal (alg_dalg)
187175 output = $ f (A, alg)
188176 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -193,6 +181,7 @@ for (f, pb, adj) in (
193181 function $adj (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
194182 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
195183 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
184+ abs (dϵ) > MatrixAlgebraKit. defaulttol (dϵ) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error"
196185 D, dD = Mooncake. arrayify (Dtrunc, dDtrunc_)
197186 V, dV = Mooncake. arrayify (Vtrunc, dVtrunc_)
198187 $ pb (dA, A, (D, V), (dD, dV))
281270@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
282271function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual )
283272 # compute primal
284- S_ = Mooncake. primal (S_dS)
285- dS_ = Mooncake. tangent (S_dS)
286- A_ = Mooncake. primal (A_dA)
287- dA_ = Mooncake. tangent (A_dA)
288- A, dA = arrayify (A_, dA_)
289- S, dS = arrayify (S_, dS_)
273+ A, dA = arrayify (A_dA)
274+ S, dS = arrayify (S_dS)
290275 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
291276 copy! (S, diagview (nS))
292277 function dsvd_vals_adjoint (:: Mooncake.NoRData )
@@ -300,28 +285,23 @@ end
300285@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals), Any, MatrixAlgebraKit. AbstractAlgorithm}
301286function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
302287 # compute primal
303- A = Mooncake. primal (A_dA)
304- dA_ = Mooncake. tangent (A_dA)
305- A, dA = arrayify (A, dA_)
306- S = svd_vals (A, Mooncake. primal (alg_dalg))
307- U, _, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
288+ A, dA = arrayify (A_dA)
289+ U, S, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
308290 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
309291 # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
310292 # pass). For many types this is done automatically when the forward step returns, but
311293 # not for nested structs with various fields (like Diagonal{Complex})
312- S_codual = Mooncake. CoDual (S , Mooncake. fdata (Mooncake. zero_tangent (S)))
294+ S_codual = Mooncake. CoDual (diagview (S) , Mooncake. fdata (Mooncake. zero_tangent (S)))
313295 function dsvd_vals_adjoint (:: Mooncake.NoRData )
314- S = Mooncake. primal (S_codual)
315- dS_ = Mooncake. tangent (S_codual)
316- S, dS = Mooncake. arrayify (S, dS_)
296+ S, dS = Mooncake. arrayify (S_codual)
317297 svd_pullback! (dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
318298 MatrixAlgebraKit. zero! (dS)
319299 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
320300 end
321301 return S_codual, dsvd_vals_adjoint
322302end
323303
324- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_trunc), Any, MatrixAlgebraKit. TruncatedAlgorithm }
304+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_trunc), Any, MatrixAlgebraKit. AbstractAlgorithm }
325305function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
326306 # compute primal
327307 A_ = Mooncake. primal (A_dA)
@@ -337,6 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
337317 function dsvd_trunc_adjoint (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
338318 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
339319 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
320+ abs (dϵ) > MatrixAlgebraKit. defaulttol (dϵ) && @warn " Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
340321 U, dU = Mooncake. arrayify (Utrunc, dUtrunc_)
341322 S, dS = Mooncake. arrayify (Strunc, dStrunc_)
342323 Vᴴ, dVᴴ = Mooncake. arrayify (Vᴴtrunc, dVᴴtrunc_)
0 commit comments