@@ -13,7 +13,6 @@ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1313using MatrixAlgebraKit: TruncatedAlgorithm
1414using LinearAlgebra
1515
16-
1716Mooncake. tangent_type (:: Type{<:MatrixAlgebraKit.AbstractAlgorithm} ) = Mooncake. NoTangent
1817
1918@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (copy_input), Any, Any}
@@ -235,21 +234,19 @@ for f in (:eig, :eigh)
235234 DVtrunc_dDVtrunc = Mooncake. zero_fcodual ((DVtrunc... , ϵ))
236235
237236 # define pullback
238- local $ f_adjoint!
239- let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
240- function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
241- _warn_pullback_truncerror (dϵ)
237+ dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
238+ function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
239+ _warn_pullback_truncerror (dϵ)
242240
243- # compute pullbacks
244- $ f_pullback! (dA, Ac, DVc, dDVtrunc, ind)
245- zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
241+ # compute pullbacks
242+ $ f_pullback! (dA, Ac, DVc, dDVtrunc, ind)
243+ zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
246244
247- # restore state
248- copy! (A, Ac)
249- copy! .(DV, DVc)
245+ # restore state
246+ copy! (A, Ac)
247+ copy! .(DV, DVc)
250248
251- return ntuple (Returns (NoRData ()), 4 )
252- end
249+ return ntuple (Returns (NoRData ()), 4 )
253250 end
254251
255252 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -291,14 +288,12 @@ for f in (:eig, :eigh)
291288 DVtrunc_dDVtrunc = Mooncake. zero_fcodual ((DVtrunc... , ϵ))
292289
293290 # define pullback
294- local $ f_adjoint!
295- let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
296- function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
297- _warn_pullback_truncerror (dϵ)
298- $ f_pullback! (dA, A, DV, dDVtrunc, ind)
299- zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
300- return ntuple (Returns (NoRData ()), 3 )
301- end
291+ dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
292+ function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
293+ _warn_pullback_truncerror (dϵ)
294+ $ f_pullback! (dA, A, DV, dDVtrunc, ind)
295+ zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
296+ return ntuple (Returns (NoRData ()), 3 )
302297 end
303298
304299 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -353,19 +348,17 @@ for f in (:eig, :eigh)
353348 DVtrunc_dDVtrunc = Mooncake. zero_fcodual (DVtrunc)
354349
355350 # define pullback
356- local $ f_adjoint!
357- let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Mooncake. tangent (DVtrunc_dDVtrunc)))
358- function $f_adjoint! (:: NoRData )
359- # compute pullbacks
360- $ f_pullback! (dA, Ac, DVc, dDVtrunc, ind)
361- zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
362-
363- # restore state
364- copy! (A, Ac)
365- copy! .(DV, DVc)
366-
367- return ntuple (Returns (NoRData ()), 4 )
368- end
351+ dDVtrunc = last .(arrayify .(DVtrunc, Mooncake. tangent (DVtrunc_dDVtrunc)))
352+ function $f_adjoint! (:: NoRData )
353+ # compute pullbacks
354+ $ f_pullback! (dA, Ac, DVc, dDVtrunc, ind)
355+ zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
356+
357+ # restore state
358+ copy! (A, Ac)
359+ copy! .(DV, DVc)
360+
361+ return ntuple (Returns (NoRData ()), 4 )
369362 end
370363
371364 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -405,13 +398,11 @@ for f in (:eig, :eigh)
405398 DVtrunc_dDVtrunc = Mooncake. zero_fcodual (DVtrunc)
406399
407400 # define pullback
408- local $ f_adjoint!
409- let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Mooncake. tangent (DVtrunc_dDVtrunc)))
410- function $f_adjoint! (:: NoRData )
411- $ f_pullback! (dA, A, DV, dDVtrunc, ind)
412- zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
413- return ntuple (Returns (NoRData ()), 3 )
414- end
401+ dDVtrunc = last .(arrayify .(DVtrunc, Mooncake. tangent (DVtrunc_dDVtrunc)))
402+ function $f_adjoint! (:: NoRData )
403+ $ f_pullback! (dA, A, DV, dDVtrunc, ind)
404+ zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
405+ return ntuple (Returns (NoRData ()), 3 )
415406 end
416407
417408 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -594,22 +585,20 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
594585 USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
595586
596587 # define pullback
597- local svd_trunc_adjoint
598- let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
599- function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
600- _warn_pullback_truncerror (dϵ)
588+ dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
589+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
590+ _warn_pullback_truncerror (dϵ)
601591
602- # compute pullbacks
603- svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
604- zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
605- zero! .(dUSVᴴ)
592+ # compute pullbacks
593+ svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
594+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
595+ zero! .(dUSVᴴ)
606596
607- # restore state
608- copy! (A, Ac)
609- copy! .(USVᴴ, USVᴴc)
597+ # restore state
598+ copy! (A, Ac)
599+ copy! .(USVᴴ, USVᴴc)
610600
611- return ntuple (Returns (NoRData ()), 4 )
612- end
601+ return ntuple (Returns (NoRData ()), 4 )
613602 end
614603
615604 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -655,14 +644,12 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655644 USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual ((USVᴴtrunc... , ϵ))
656645
657646 # define pullback
658- local svd_trunc_adjoint
659- let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
660- function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
661- _warn_pullback_truncerror (dϵ)
662- svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
663- zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
664- return ntuple (Returns (NoRData ()), 3 )
665- end
647+ dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
648+ function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
649+ _warn_pullback_truncerror (dϵ)
650+ svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
651+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
652+ return ntuple (Returns (NoRData ()), 3 )
666653 end
667654
668655 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -727,20 +714,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
727714 USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual (USVᴴtrunc)
728715
729716 # define pullback
730- local svd_trunc_adjoint
731- let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
732- function svd_trunc_adjoint (:: NoRData )
733- # compute pullbacks
734- svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
735- zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
736- zero! .(dUSVᴴ)
737-
738- # restore state
739- copy! (A, Ac)
740- copy! .(USVᴴ, USVᴴc)
741-
742- return ntuple (Returns (NoRData ()), 4 )
743- end
717+ dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
718+ function svd_trunc_adjoint (:: NoRData )
719+ # compute pullbacks
720+ svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
721+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
722+ zero! .(dUSVᴴ)
723+
724+ # restore state
725+ copy! (A, Ac)
726+ copy! .(USVᴴ, USVᴴc)
727+
728+ return ntuple (Returns (NoRData ()), 4 )
744729 end
745730
746731 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -784,13 +769,11 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
784769 USVᴴtrunc_dUSVᴴtrunc = Mooncake. zero_fcodual (USVᴴtrunc)
785770
786771 # define pullback
787- local svd_trunc_adjoint
788- let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
789- function svd_trunc_adjoint (:: NoRData )
790- svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
791- zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
792- return ntuple (Returns (NoRData ()), 3 )
793- end
772+ dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc)))
773+ function svd_trunc_adjoint (:: NoRData )
774+ svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
775+ zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
776+ return ntuple (Returns (NoRData ()), 3 )
794777 end
795778
796779 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
0 commit comments