@@ -54,9 +54,9 @@ for (f!, f, pb, adj) in (
5454 $ f! (A, args, Mooncake. primal (alg_dalg))
5555 function $adj (:: NoRData )
5656 copy! (A, Ac)
57+ $ pb (dA, A, (arg1, arg2), (darg1, darg2))
5758 copy! (arg1, arg1c)
5859 copy! (arg2, arg2c)
59- $ pb (dA, A, (arg1, arg2), (darg1, darg2))
6060 zero! (darg1)
6161 zero! (darg2)
6262 return NoRData (), NoRData (), NoRData (), NoRData ()
@@ -101,9 +101,9 @@ for (f!, f, pb, adj) in (
101101 $ f! (A, arg, Mooncake. primal (alg_dalg))
102102 function $adj (:: NoRData )
103103 copy! (A, Ac)
104- copy! (arg, argc)
105104 $ pb (dA, A, arg, darg)
106- MatrixAlgebraKit. zero! (darg)
105+ copy! (arg, argc)
106+ zero! (darg)
107107 return NoRData (), NoRData (), NoRData (), NoRData ()
108108 end
109109 return arg_darg, $ adj
@@ -116,7 +116,7 @@ for (f!, f, pb, adj) in (
116116 function $adj (:: NoRData )
117117 arg, darg = arrayify (output_codual)
118118 $ pb (dA, A, arg, darg)
119- MatrixAlgebraKit . zero! (darg)
119+ zero! (darg)
120120 return NoRData (), NoRData (), NoRData ()
121121 end
122122 return output_codual, $ adj
@@ -134,14 +134,15 @@ for (f!, f, f_full, pb, adj) in (
134134 # compute primal
135135 A, dA = arrayify (A_dA)
136136 D, dD = arrayify (D_dD)
137+ Dc = copy (D)
137138 # update primal
138139 DV = $ f_full (A, Mooncake. primal (alg_dalg))
139140 copy! (D, diagview (DV[1 ]))
140141 V = DV[2 ]
141142 function $adj (:: NoRData )
142- copy! (D, diagview (DV[1 ]))
143143 $ pb (dA, A, DV, dD)
144- MatrixAlgebraKit. zero! (dD)
144+ copy! (D, Dc)
145+ zero! (dD)
145146 return NoRData (), NoRData (), NoRData (), NoRData ()
146147 end
147148 return D_dD, $ adj
@@ -158,7 +159,7 @@ for (f!, f, f_full, pb, adj) in (
158159 function $adj (:: NoRData )
159160 D, dD = arrayify (output_codual)
160161 $ pb (dA, A, DV, dD)
161- MatrixAlgebraKit . zero! (dD)
162+ zero! (dD)
162163 return NoRData (), NoRData (), NoRData ()
163164 end
164165 return output_codual, $ adj
@@ -189,16 +190,16 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
189190 output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
190191 function $adj (dy:: Tuple{NoRData, NoRData, T} ) where {T <: Real }
191192 copy! (A, Ac)
192- copy! (DV[1 ], DVc[1 ])
193- copy! (DV[2 ], DVc[2 ])
194193 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
195194 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
196195 abs (dy[3 ]) > MatrixAlgebraKit. defaulttol (dy[3 ]) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error"
197196 D′, dD′ = arrayify (Dtrunc, dDtrunc_)
198197 V′, dV′ = arrayify (Vtrunc, dVtrunc_)
199198 $ pb (dA, A, (D′, V′), (dD′, dV′))
200- MatrixAlgebraKit. zero! (dD)
201- MatrixAlgebraKit. zero! (dV)
199+ copy! (DV[1 ], DVc[1 ])
200+ copy! (DV[2 ], DVc[2 ])
201+ zero! (dD)
202+ zero! (dV)
202203 return NoRData (), NoRData (), NoRData ()
203204 end
204205 return output_codual, $ adj
@@ -220,8 +221,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
220221 D, dD = arrayify (Dtrunc, dDtrunc_)
221222 V, dV = arrayify (Vtrunc, dVtrunc_)
222223 $ pb (dA, A, (D, V), (dD, dV))
223- MatrixAlgebraKit . zero! (dD)
224- MatrixAlgebraKit . zero! (dV)
224+ zero! (dD)
225+ zero! (dV)
225226 return NoRData (), NoRData (), NoRData ()
226227 end
227228 return output_codual, $ adj
@@ -244,15 +245,15 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
244245 output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
245246 function $adj (:: NoRData )
246247 copy! (A, Ac)
247- copy! (DV[1 ], DVc[1 ])
248- copy! (DV[2 ], DVc[2 ])
249248 Dtrunc, Vtrunc = Mooncake. primal (output_codual)
250249 dDtrunc_, dVtrunc_ = Mooncake. tangent (output_codual)
251250 D′, dD′ = arrayify (Dtrunc, dDtrunc_)
252251 V′, dV′ = arrayify (Vtrunc, dVtrunc_)
253252 $ pb (dA, A, (D′, V′), (dD′, dV′))
254- MatrixAlgebraKit. zero! (dD)
255- MatrixAlgebraKit. zero! (dV)
253+ copy! (DV[1 ], DVc[1 ])
254+ copy! (DV[2 ], DVc[2 ])
255+ zero! (dD)
256+ zero! (dV)
256257 return NoRData (), NoRData (), NoRData ()
257258 end
258259 return output_codual, $ adj
@@ -273,8 +274,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
273274 D, dD = arrayify (Dtrunc, dDtrunc_)
274275 V, dV = arrayify (Vtrunc, dVtrunc_)
275276 $ pb (dA, A, (D, V), (dD, dV))
276- MatrixAlgebraKit . zero! (dD)
277- MatrixAlgebraKit . zero! (dV)
277+ zero! (dD)
278+ zero! (dV)
278279 return NoRData (), NoRData (), NoRData ()
279280 end
280281 return output_codual, $ adj
@@ -300,9 +301,6 @@ for (f!, f) in (
300301 output = $ f! (A, Mooncake. primal (alg_dalg))
301302 function svd_adjoint (:: NoRData )
302303 copy! (A, Ac)
303- copy! (U, USVᴴc[1 ])
304- copy! (S, USVᴴc[2 ])
305- copy! (Vᴴ, USVᴴc[3 ])
306304 if $ (f! == svd_compact!)
307305 svd_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
308306 else # full
@@ -315,9 +313,12 @@ for (f!, f) in (
315313 vdVᴴ = view (dVᴴ, 1 : minmn, :)
316314 svd_pullback! (dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
317315 end
318- MatrixAlgebraKit. zero! (dU)
319- MatrixAlgebraKit. zero! (dS)
320- MatrixAlgebraKit. zero! (dVᴴ)
316+ copy! (U, USVᴴc[1 ])
317+ copy! (S, USVᴴc[2 ])
318+ copy! (Vᴴ, USVᴴc[3 ])
319+ zero! (dU)
320+ zero! (dS)
321+ zero! (dVᴴ)
321322 return NoRData (), NoRData (), NoRData (), NoRData ()
322323 end
323324 return CoDual (output, dUSVᴴ), svd_adjoint
@@ -349,9 +350,9 @@ for (f!, f) in (
349350 vdVᴴ = view (dVᴴ, 1 : minmn, :)
350351 svd_pullback! (dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
351352 end
352- MatrixAlgebraKit . zero! (dU)
353- MatrixAlgebraKit . zero! (dS)
354- MatrixAlgebraKit . zero! (dVᴴ)
353+ zero! (dU)
354+ zero! (dS)
355+ zero! (dVᴴ)
355356 return NoRData (), NoRData (), NoRData ()
356357 end
357358 return USVᴴ_codual, svd_adjoint
@@ -364,12 +365,13 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
364365 # compute primal
365366 A, dA = arrayify (A_dA)
366367 S, dS = arrayify (S_dS)
368+ Sc = copy (S)
367369 USVᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
368370 copy! (S, diagview (USVᴴ[2 ]))
369371 function svd_vals_adjoint (:: NoRData )
370372 svd_vals_pullback! (dA, A, USVᴴ, dS)
371- MatrixAlgebraKit . zero! (dS)
372- copy! (S, diagview (USVᴴ[ 2 ]) )
373+ zero! (dS)
374+ copy! (S, Sc )
373375 return NoRData (), NoRData (), NoRData (), NoRData ()
374376 end
375377 return S_dS, svd_vals_adjoint
@@ -389,7 +391,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
389391 function svd_vals_adjoint (:: NoRData )
390392 S, dS = arrayify (S_codual)
391393 svd_vals_pullback! (dA, A, USVᴴ, dS)
392- MatrixAlgebraKit . zero! (dS)
394+ zero! (dS)
393395 return NoRData (), NoRData (), NoRData ()
394396 end
395397 return S_codual, svd_vals_adjoint
@@ -415,19 +417,19 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
415417 output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
416418 function svd_trunc_adjoint (dy:: Tuple{NoRData, NoRData, NoRData, T} ) where {T <: Real }
417419 copy! (A, Ac)
418- copy! (U, USVᴴc[1 ])
419- copy! (S, USVᴴc[2 ])
420- copy! (Vᴴ, USVᴴc[3 ])
421420 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
422421 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
423422 abs (dy[4 ]) > MatrixAlgebraKit. defaulttol (dy[4 ]) && @warn " Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
424423 U′, dU′ = arrayify (Utrunc, dUtrunc_)
425424 S′, dS′ = arrayify (Strunc, dStrunc_)
426425 Vᴴ′, dVᴴ′ = arrayify (Vᴴtrunc, dVᴴtrunc_)
427426 svd_trunc_pullback! (dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
428- MatrixAlgebraKit. zero! (dU)
429- MatrixAlgebraKit. zero! (dS)
430- MatrixAlgebraKit. zero! (dVᴴ)
427+ copy! (U, USVᴴc[1 ])
428+ copy! (S, USVᴴc[2 ])
429+ copy! (Vᴴ, USVᴴc[3 ])
430+ zero! (dU)
431+ zero! (dS)
432+ zero! (dVᴴ)
431433 return NoRData (), NoRData (), NoRData ()
432434 end
433435 return output_codual, svd_trunc_adjoint
@@ -454,9 +456,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
454456 S, dS = arrayify (Strunc, dStrunc_)
455457 Vᴴ, dVᴴ = arrayify (Vᴴtrunc, dVᴴtrunc_)
456458 svd_trunc_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
457- MatrixAlgebraKit . zero! (dU)
458- MatrixAlgebraKit . zero! (dS)
459- MatrixAlgebraKit . zero! (dVᴴ)
459+ zero! (dU)
460+ zero! (dS)
461+ zero! (dVᴴ)
460462 return NoRData (), NoRData (), NoRData ()
461463 end
462464 return output_codual, svd_trunc_adjoint
@@ -482,18 +484,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, US
482484 output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
483485 function svd_trunc_adjoint (:: NoRData )
484486 copy! (A, Ac)
485- copy! (U, USVᴴc[1 ])
486- copy! (S, USVᴴc[2 ])
487- copy! (Vᴴ, USVᴴc[3 ])
488487 Utrunc, Strunc, Vᴴtrunc = Mooncake. primal (output_codual)
489488 dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake. tangent (output_codual)
490489 U′, dU′ = arrayify (Utrunc, dUtrunc_)
491490 S′, dS′ = arrayify (Strunc, dStrunc_)
492491 Vᴴ′, dVᴴ′ = arrayify (Vᴴtrunc, dVᴴtrunc_)
493492 svd_trunc_pullback! (dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
494- MatrixAlgebraKit. zero! (dU)
495- MatrixAlgebraKit. zero! (dS)
496- MatrixAlgebraKit. zero! (dVᴴ)
493+ copy! (U, USVᴴc[1 ])
494+ copy! (S, USVᴴc[2 ])
495+ copy! (Vᴴ, USVᴴc[3 ])
496+ zero! (dU)
497+ zero! (dS)
498+ zero! (dVᴴ)
497499 return NoRData (), NoRData (), NoRData ()
498500 end
499501 return output_codual, svd_trunc_adjoint
@@ -519,9 +521,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
519521 S, dS = arrayify (Strunc, dStrunc_)
520522 Vᴴ, dVᴴ = arrayify (Vᴴtrunc, dVᴴtrunc_)
521523 svd_trunc_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
522- MatrixAlgebraKit . zero! (dU)
523- MatrixAlgebraKit . zero! (dS)
524- MatrixAlgebraKit . zero! (dVᴴ)
524+ zero! (dU)
525+ zero! (dS)
526+ zero! (dVᴴ)
525527 return NoRData (), NoRData (), NoRData ()
526528 end
527529 return output_codual, svd_trunc_adjoint
0 commit comments