@@ -15,19 +15,6 @@ using LinearAlgebra
1515
1616Mooncake. tangent_type (:: Type{<:MatrixAlgebraKit.AbstractAlgorithm} ) = Mooncake. NoTangent
1717
18- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (copy_input), Any, Any}
19- function Mooncake. rrule!! (:: CoDual{typeof(copy_input)} , f_df:: CoDual , A_dA:: CoDual )
20- Ac = copy_input (Mooncake. primal (f_df), Mooncake. primal (A_dA))
21- Ac_dAc = Mooncake. zero_fcodual (Ac)
22- dAc = Mooncake. tangent (Ac_dAc)
23- function copy_input_pb (:: NoRData )
24- Mooncake. increment!! (Mooncake. tangent (A_dA), dAc)
25- return NoRData (), NoRData (), NoRData ()
26- end
27- return Ac_dAc, copy_input_pb
28- end
29-
30- Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (initialize_output), Any, Any, Any}
3118# two-argument in-place factorizations like LQ, QR, EIG
3219for (f!, f, pb, adj) in (
3320 (:qr_full! , :qr_full , :qr_pullback! , :qr_adjoint ),
@@ -53,12 +40,26 @@ for (f!, f, pb, adj) in (
5340 arg2c = copy (arg2)
5441 $ f! (A, args, Mooncake. primal (alg_dalg))
5542 function $adj (:: NoRData )
43+ if ! (A === arg1 || A === arg2)
44+ $ pb (dA, A, (arg1, arg2), (darg1, darg2))
45+ else
46+ ΔA = zero (A)
47+ $ pb (ΔA, A, (arg1, arg2), (darg1, darg2))
48+ dA .= ΔA
49+ end
50+ if A === arg1
51+ zero! (darg2)
52+ copy! (arg2, arg2c)
53+ elseif A === arg2
54+ zero! (darg1)
55+ copy! (arg1, arg1c)
56+ else
57+ zero! (darg1)
58+ zero! (darg2)
59+ copy! (arg2, arg2c)
60+ copy! (arg1, arg1c)
61+ end
5662 copy! (A, Ac)
57- $ pb (dA, A, (arg1, arg2), (darg1, darg2))
58- copy! (arg1, arg1c)
59- copy! (arg2, arg2c)
60- zero! (darg1)
61- zero! (darg2)
6263 return NoRData (), NoRData (), NoRData (), NoRData ()
6364 end
6465 return args_dargs, $ adj
@@ -140,9 +141,19 @@ for (f!, f, f_full, pb, adj) in (
140141 copy! (D, diagview (DV[1 ]))
141142 V = DV[2 ]
142143 function $adj (:: NoRData )
143- $ pb (dA, A, DV, dD)
144- copy! (D, Dc)
145- zero! (dD)
144+ if A != = D
145+ $ pb (dA, A, DV, dD)
146+ else
147+ ΔA = zero (A)
148+ $ pb (ΔA, A, DV, dD)
149+ dA .= A
150+ end
151+ if A != = D
152+ zero! (dD)
153+ copy! (D, Dc)
154+ else
155+ copy! (A, Ac)
156+ end
146157 return NoRData (), NoRData (), NoRData (), NoRData ()
147158 end
148159 return D_dD, $ adj
@@ -199,15 +210,27 @@ for f in (:eig, :eigh)
199210 # not for nested structs with various fields (like Diagonal{Complex})
200211 output_codual = Mooncake. zero_fcodual (output)
201212 function $f_adjoint! (dy:: Tuple{NoRData, NoRData, <:Real} )
202- copy! (A, Ac)
203213 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
204214 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
205215 _warn_pullback_truncerror (dy[3 ])
206216 D′, dD′ = arrayify (Dtrunc, dDtrunc_)
207217 V′, dV′ = arrayify (Vtrunc, dVtrunc_)
208- $ f_trunc_pullback! (dA, A, (D′, V′), (dD′, dV′))
209- copy! (DV[1 ], DVc[1 ])
210- copy! (DV[2 ], DVc[2 ])
218+ D, dD = arrayify (DV[1 ], dDV[1 ])
219+ V, dV = arrayify (DV[2 ], dDV[2 ])
220+ copy! (A, Ac)
221+ if ! (A === D || A === V)
222+ $ f_trunc_pullback! (dA, A, (D′, V′), (dD′, dV′))
223+ else
224+ ΔA = zero (A)
225+ $ f_trunc_pullback! (ΔA, A, (D′, V′), (dD′, dV′))
226+ dA .= ΔA
227+ end
228+ if A === D
229+ copy! (DV[2 ], DVc[2 ])
230+ else
231+ copy! (DV[1 ], DVc[1 ])
232+ copy! (DV[2 ], DVc[2 ])
233+ end
211234 zero! (dD′)
212235 zero! (dV′)
213236 return NoRData (), NoRData (), NoRData (), NoRData ()
@@ -239,12 +262,22 @@ for f in (:eig, :eigh)
239262 _warn_pullback_truncerror (dϵ)
240263
241264 # compute pullbacks
242- $ f_pullback! (dA, Ac, DV, dDVtrunc, ind)
243- zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
244-
265+ if ! (A === DV[1 ] || A === DV[2 ])
266+ $ f_pullback! (dA, Ac, DV, dDVtrunc, ind)
267+ else
268+ ΔA = zero (A)
269+ $ f_pullback! (ΔA, Ac, DV, dDVtrunc, ind)
270+ dA .= ΔA
271+ end
245272 # restore state
246273 copy! (A, Ac)
247- copy! .(DV, DVc)
274+ if A === DV[1 ]
275+ copy! (DV[2 ], DVc[2 ])
276+ zero! (dDV[2 ])
277+ else
278+ copy! .(DV, DVc)
279+ zero! .(dDV)
280+ end
248281
249282 return ntuple (Returns (NoRData ()), 4 )
250283 end
@@ -351,12 +384,23 @@ for f in (:eig, :eigh)
351384 dDVtrunc = last .(arrayify .(DVtrunc, Mooncake. tangent (DVtrunc_dDVtrunc)))
352385 function $f_adjoint! (:: NoRData )
353386 # compute pullbacks
354- $ f_pullback! (dA, Ac, DV, dDVtrunc, ind)
355- zero! .(dDV)
387+ if ! (A === DV[1 ] || A === DV[2 ])
388+ $ f_pullback! (dA, Ac, DV, dDVtrunc, ind)
389+ else
390+ ΔA = zero (A)
391+ $ f_pullback! (ΔA, Ac, DV, dDVtrunc, ind)
392+ dA .= ΔA
393+ end
356394
357395 # restore state
358396 copy! (A, Ac)
359- copy! .(DV, DVc)
397+ if A === DV[1 ]
398+ copy! (DV[2 ], DVc[2 ])
399+ zero! (dDV[2 ])
400+ else
401+ copy! .(DV, DVc)
402+ zero! .(dDV)
403+ end
360404
361405 return ntuple (Returns (NoRData ()), 4 )
362406 end
0 commit comments