@@ -168,118 +168,252 @@ for (f!, f, f_full, pb, adj) in (
168168 end
169169end
170170
171- for (f!, f, f_ne!, f_ne, pb, adj) in (
172- (:eig_trunc! , :eig_trunc , :eig_trunc_no_error! , :eig_trunc_no_error , :eig_trunc_pullback! , :eig_trunc_adjoint ),
173- (:eigh_trunc! , :eigh_trunc , :eigh_trunc_no_error! , :eigh_trunc_no_error , :eigh_trunc_pullback! , :eigh_trunc_adjoint ),
174- )
171+ for f in (:eig , :eigh )
172+ f_trunc = Symbol (f, :_trunc )
173+ f_trunc! = Symbol (f_trunc, :! )
174+ f_full = Symbol (f, :_full )
175+ f_full! = Symbol (f_full, :! )
176+ f_pullback! = Symbol (f, :_pullback! )
177+ f_trunc_pullback! = Symbol (f_trunc, :_pullback! )
178+ f_adjoint! = Symbol (f, :_adjoint! )
179+ f_trunc_no_error = Symbol (f_trunc, :_no_error )
180+ f_trunc_no_error! = Symbol (f_trunc_no_error, :! )
181+
175182 @eval begin
176- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f !), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
177- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f ), Any, MatrixAlgebraKit. AbstractAlgorithm}
178- function Mooncake. rrule!! (:: CoDual{typeof($f !)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
183+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc !), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
184+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc ), Any, MatrixAlgebraKit. AbstractAlgorithm}
# Reverse-mode rule for the in-place truncated decomposition `$f_trunc!(A, DV, alg)` with a
# generic algorithm. The forward pass runs the primal on `A`/`DV`; the returned closure
# accumulates into `dA` and restores the mutated arguments to their pre-call contents.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
    # unpack primals / tangents
    A, dA = arrayify(A_dA)
    DV = Mooncake.primal(DV_dDV)
    alg = Mooncake.primal(alg_dalg)

    # snapshot the arguments before the primal call so the pullback can restore them
    Ac = copy(A)
    DVc = copy.(DV)

    # compute primal
    output = $f_trunc!(A, DV, alg)

    # `zero_fcodual` pairs the output with zero forward data.
    # NOTE(review): presumably equivalent to CoDual(output, fdata(zero_tangent(output))),
    # which is what the sibling rules still spell out - confirm against Mooncake.
    output_codual = Mooncake.zero_fcodual(output)

    function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real})
        # restore `A` first: the pullback below reads it
        copy!(A, Ac)
        Dtrunc, Vtrunc, _ = Mooncake.primal(output_codual)
        dDtrunc_, dVtrunc_, _ = Mooncake.tangent(output_codual)
        # The message is built at `@eval` time. A `$name` inside a string literal in a quoted
        # block is ordinary string interpolation resolved at run time, where the loop
        # variable is not in scope.
        abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) &&
            @warn $("Pullback for $(f_trunc!) does not yet support non-zero tangent for the truncation error")
        D′, dD′ = arrayify(Dtrunc, dDtrunc_)
        V′, dV′ = arrayify(Vtrunc, dVtrunc_)
        $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
        # restore the output buffers and clear the consumed output tangents
        copy!(DV[1], DVc[1])
        copy!(DV[2], DVc[2])
        zero!(dD′)
        zero!(dV′)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return output_codual, $f_adjoint!
end
208- function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
# Reverse-mode rule for `$f_trunc!(A, DV, alg)` when `alg` is a `TruncatedAlgorithm`.
# The forward pass computes the full decomposition with `alg.alg`, truncates it with
# `alg.trunc`, and returns `(DVtrunc..., ϵ)`. The pullback differentiates through the full
# decomposition restricted to the kept indices `ind`, then restores `A` and `DV`.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
    # unpack variables
    A, dA = arrayify(A_dA)
    DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV))
    DV = first.(DV_dDV_arr)
    alg = Mooncake.primal(alg_dalg)

    # store state prior to primal call
    Ac = copy(A)
    DVc = copy.(DV)

    # compute primal - capture the full decomposition and ind.
    # Bound to a separate name so `DV` keeps referring to the caller's buffers, which are
    # what has to be restored in the pullback.
    DVfull = $f_full!(A, DV, alg.alg)
    DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DVfull, alg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(DVfull[1]), ind)

    # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input
    DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))

    # define pullback
    local $f_adjoint!
    let ind = ind, DVfull = DVfull, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
        function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
            # The message is built at `@eval` time. A `$name` inside a string literal in a
            # quoted block is resolved at run time, where the loop variables do not exist.
            abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
                @warn $("Pullback for `$(f_trunc!)` ignores non-zero tangents for truncation error")

            # compute pullbacks - uses the pre-call copy of `A` and the computed
            # decomposition (not `DVc`, which holds the buffers' contents from *before*
            # the primal call). Must run before the restore below in case `DVfull`
            # aliases `DV`.
            $f_pullback!(dA, Ac, DVfull, dDVtrunc, ind)
            zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

            # restore state
            copy!(A, Ac)
            copy!.(DV, DVc)

            return ntuple(Returns(NoRData()), 4)
        end
    end

    return DVtrunc_dDVtrunc, $f_adjoint!
end
# Reverse-mode rule for the out-of-place truncated decomposition `$f_trunc(A, alg)` with a
# generic algorithm. The forward pass calls the primal; the returned closure accumulates
# into `dA` via `$f_trunc_pullback!` and clears the output tangents.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
    # compute primal
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)
    output = $f_trunc(A, alg)
    # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
    # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
    # pass). For many types this is done automatically when the forward step returns, but
    # not for nested structs with various fields (like Diagonal{Complex})
    output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
    function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
        Dtrunc, Vtrunc, _ = Mooncake.primal(output_codual)
        dDtrunc_, dVtrunc_, _ = Mooncake.tangent(output_codual)
        # The message is built at `@eval` time. A `$name` inside a string literal in a quoted
        # block is ordinary string interpolation resolved at run time, where the loop
        # variable is not in scope.
        abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) &&
            @warn $("Pullback for $(f_trunc) does not yet support non-zero tangent for the truncation error")
        D, dD = arrayify(Dtrunc, dDtrunc_)
        V, dV = arrayify(Vtrunc, dVtrunc_)
        $f_trunc_pullback!(dA, A, (D, V), (dD, dV))
        # clear the consumed output tangents
        zero!(dD)
        zero!(dV)
        return NoRData(), NoRData(), NoRData()
    end
    return output_codual, $f_adjoint!
end
231- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_ne!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
232- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_ne), Any, MatrixAlgebraKit. AbstractAlgorithm}
233- function Mooncake. rrule!! (:: CoDual{typeof($f_ne!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
# Reverse-mode rule for `$f_trunc(A, alg)` when `alg` is a `TruncatedAlgorithm`.
# The forward pass computes the full decomposition with `alg.alg`, truncates it with
# `alg.trunc`, and returns `(DVtrunc..., ϵ)`. The pullback differentiates through the full
# decomposition restricted to the kept indices `ind`.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
    # unpack variables
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)

    # compute primal - capture full DV and ind
    DV = $f_full(A, alg.alg)
    DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind)

    # pack output
    DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))

    # define pullback
    local $f_adjoint!
    let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
        function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
            # The message is built at `@eval` time. A `$name` inside a string literal in a
            # quoted block is resolved at run time, where the loop variables do not exist.
            abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
                @warn $("Pullback for `$(f_trunc)` ignores non-zero tangents for truncation error")
            $f_pullback!(dA, A, DV, dDVtrunc, ind)
            zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
            return ntuple(Returns(NoRData()), 3)
        end
    end

    return DVtrunc_dDVtrunc, $f_adjoint!
end
305+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
306+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error), Any, MatrixAlgebraKit. AbstractAlgorithm}
# Reverse-mode rule for the in-place truncated decomposition without truncation error,
# `$f_trunc_no_error!(A, DV, alg)`, with a generic algorithm. The forward pass runs the
# primal on `A`/`DV`; the returned closure accumulates into `dA` and restores the mutated
# arguments to their pre-call contents.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
    # compute primal
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)
    DV = Mooncake.primal(DV_dDV)
    # snapshot the arguments before the primal call so the pullback can restore them
    Ac = copy(A)
    DVc = copy.(DV)
    output = $f_trunc_no_error!(A, DV, alg)
    # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
    # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
    # pass). For many types this is done automatically when the forward step returns, but
    # not for nested structs with various fields (like Diagonal{Complex})
    output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
    function $f_adjoint!(::NoRData)
        # restore `A` first: the pullback below reads it
        copy!(A, Ac)
        Dtrunc, Vtrunc = Mooncake.primal(output_codual)
        dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
        D′, dD′ = arrayify(Dtrunc, dDtrunc_)
        V′, dV′ = arrayify(Vtrunc, dVtrunc_)
        # (D′, V′) are the *truncated* factors, so the truncated pullback is required here,
        # as in the sibling rules for `$f_trunc!`, `$f_trunc` and `$f_trunc_no_error`.
        $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
        # restore the output buffers and clear the consumed output tangents
        copy!(DV[1], DVc[1])
        copy!(DV[2], DVc[2])
        zero!(dD′)
        zero!(dV′)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return output_codual, $f_adjoint!
end
# Reverse-mode rule for `$f_trunc_no_error!(A, DV, alg)` when `alg` is a
# `TruncatedAlgorithm`. The forward pass computes the full decomposition with `alg.alg` and
# truncates it with `alg.trunc`. The pullback differentiates through the full
# decomposition restricted to the kept indices `ind`, then restores `A` and `DV`.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
    # unpack variables
    A, dA = arrayify(A_dA)
    DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV))
    DV = first.(DV_dDV_arr)
    alg = Mooncake.primal(alg_dalg)

    # store state prior to primal call
    Ac = copy(A)
    DVc = copy.(DV)

    # compute primal - capture the full decomposition and ind.
    # Bound to a separate name so `DV` keeps referring to the caller's buffers, which are
    # what has to be restored in the pullback.
    DVfull = $f_full!(A, DV, alg.alg)
    DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DVfull, alg.trunc)

    # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input
    DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)

    # define pullback
    local $f_adjoint!
    let ind = ind, DVfull = DVfull, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
        function $f_adjoint!(::NoRData)
            # compute pullbacks - uses the pre-call copy of `A` and the computed
            # decomposition (not `DVc`, which holds the buffers' contents from *before*
            # the primal call). Must run before the restore below in case `DVfull`
            # aliases `DV`.
            $f_pullback!(dA, Ac, DVfull, dDVtrunc, ind)
            zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

            # restore state
            copy!(A, Ac)
            copy!.(DV, DVc)

            return ntuple(Returns(NoRData()), 4)
        end
    end

    return DVtrunc_dDVtrunc, $f_adjoint!
end
262- function Mooncake. rrule!! (:: CoDual{typeof($f_ne )} , A_dA:: CoDual , alg_dalg:: CoDual )
# Reverse-mode rule for the out-of-place truncated decomposition without truncation error,
# `$f_trunc_no_error(A, alg)`, with a generic algorithm. Returns the primal output paired
# with zero forward data, plus a closure that accumulates into `dA`.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
    A, dA = arrayify(A_dA)
    output = $f_trunc_no_error(A, Mooncake.primal(alg_dalg))

    # The explicit fdata conversion turns a complicated Tangent type (e.g. of a Diagonal of
    # ComplexF32) into the correct **forwards** data type, since this is the forward pass.
    # For many types this happens automatically when the forward step returns, but not for
    # nested structs with various fields (like Diagonal{Complex}).
    zero_fdata = Mooncake.fdata(Mooncake.zero_tangent(output))
    output_codual = CoDual(output, zero_fdata)

    function $f_adjoint!(::NoRData)
        primals = Mooncake.primal(output_codual)
        tangents = Mooncake.tangent(output_codual)
        Dtrunc, Vtrunc = primals
        dDtrunc_, dVtrunc_ = tangents
        D, dD = arrayify(Dtrunc, dDtrunc_)
        V, dV = arrayify(Vtrunc, dVtrunc_)
        $f_trunc_pullback!(dA, A, (D, V), (dD, dV))
        # clear the consumed output tangents
        zero!(dD)
        zero!(dV)
        return ntuple(Returns(NoRData()), 3)
    end

    return output_codual, $f_adjoint!
end
# Reverse-mode rule for `$f_trunc_no_error(A, alg)` when `alg` is a `TruncatedAlgorithm`.
# The forward pass computes the full decomposition with `alg.alg` and truncates it with
# `alg.trunc`. The pullback differentiates through the full decomposition restricted to
# the kept indices.
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
    A, dA = arrayify(A_dA)
    alg = Mooncake.primal(alg_dalg)

    # forward pass: full decomposition, then truncation (keeping the selected indices)
    DV = $f_full(A, alg.alg)
    DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)

    # output paired with zero forward data
    DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)

    # array views of the output tangents, captured by the pullback
    trunc_pairs = arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))

    local $f_adjoint!
    let ind = ind, dDVtrunc = last.(trunc_pairs)
        function $f_adjoint!(::NoRData)
            $f_pullback!(dA, A, DV, dDVtrunc, ind)
            zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
            return NoRData(), NoRData(), NoRData()
        end
    end

    return DVtrunc_dDVtrunc, $f_adjoint!
end
284418 end
285419end
0 commit comments