@@ -139,14 +139,14 @@ for (f!, f, f_full, pb, adj) in (
139139 D, dD = arrayify (D_, dD_)
140140 # update primal
141141 DV = $ f_full (A, Mooncake. primal (alg_dalg))
142- output = copy ( diagview (DV[1 ]))
142+ copy! (D, diagview (DV[1 ]))
143143 V = DV[2 ]
144144 function $adj (:: Mooncake.NoRData )
145145 $ pb (dA, A, (D, V), (dD, nothing ))
146146 MatrixAlgebraKit. zero! (dD)
147147 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
148148 end
149- return Mooncake . CoDual (output, dD_) , $ adj
149+ return D_dD , $ adj
150150 end
151151 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
152152 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
@@ -288,13 +288,13 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Co
288288 A, dA = arrayify (A_, dA_)
289289 S, dS = arrayify (S_, dS_)
290290 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
291- output = copy ( diagview (nS))
291+ copy! (S, diagview (nS))
292292 function dsvd_vals_adjoint (:: Mooncake.NoRData )
293293 svd_pullback! (dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
294294 MatrixAlgebraKit. zero! (dS)
295295 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
296296 end
297- return Mooncake . CoDual (output, dS) , dsvd_vals_adjoint
297+ return S_dS , dsvd_vals_adjoint
298298end
299299
300300@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals), Any, MatrixAlgebraKit. AbstractAlgorithm}
0 commit comments