@@ -19,14 +19,17 @@ macro is_rev_primitive(sig)
1919 return esc (:(Mooncake. @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode $ sig))
2020end
2121
22+ # return n copies of NoRData()
23+ @inline n_NoRData (n) = ntuple (Returns (NoRData ()), n)
24+
2225@is_rev_primitive Tuple{typeof (copy_input), Any, Any}
2326function Mooncake. rrule!! (:: CoDual{typeof(copy_input)} , f_df:: CoDual , A_dA:: CoDual )
2427 Ac = copy_input (Mooncake. primal (f_df), Mooncake. primal (A_dA))
2528 Ac_dAc = Mooncake. zero_fcodual (Ac)
2629 dAc = Mooncake. tangent (Ac_dAc)
2730 function copy_input_pb (:: NoRData )
2831 Mooncake. increment!! (Mooncake. tangent (A_dA), dAc)
29- return NoRData (), NoRData (), NoRData ( )
32+ return n_NoRData ( 3 )
3033 end
3134 return Ac_dAc, copy_input_pb
3235end
@@ -63,7 +66,7 @@ for (f!, f, pb, adj) in (
6366 copy! (arg2, arg2c)
6467 zero! (darg1)
6568 zero! (darg2)
66- return NoRData (), NoRData (), NoRData (), NoRData ( )
69+ return n_NoRData ( 4 )
6770 end
6871 return args_dargs, $ adj
6972 end
@@ -84,7 +87,7 @@ for (f!, f, pb, adj) in (
8487 $ pb (dA, A, (arg1, arg2), (darg1, darg2))
8588 zero! (darg1)
8689 zero! (darg2)
87- return NoRData (), NoRData (), NoRData ( )
90+ return n_NoRData ( 3 )
8891 end
8992 return output_codual, $ adj
9093 end
@@ -108,7 +111,7 @@ for (f!, f, pb, adj) in (
108111 $ pb (dA, A, arg, darg)
109112 copy! (arg, argc)
110113 zero! (darg)
111- return NoRData (), NoRData (), NoRData (), NoRData ( )
114+ return n_NoRData ( 4 )
112115 end
113116 return arg_darg, $ adj
114117 end
@@ -121,7 +124,7 @@ for (f!, f, pb, adj) in (
121124 arg, darg = arrayify (output_codual)
122125 $ pb (dA, A, arg, darg)
123126 zero! (darg)
124- return NoRData (), NoRData (), NoRData ( )
127+ return n_NoRData ( 3 )
125128 end
126129 return output_codual, $ adj
127130 end
@@ -147,7 +150,7 @@ for (f!, f, f_full, pb, adj) in (
147150 $ pb (dA, A, DV, dD)
148151 copy! (D, Dc)
149152 zero! (dD)
150- return NoRData (), NoRData (), NoRData (), NoRData ( )
153+ return n_NoRData ( 4 )
151154 end
152155 return D_dD, $ adj
153156 end
@@ -164,7 +167,7 @@ for (f!, f, f_full, pb, adj) in (
164167 D, dD = arrayify (output_codual)
165168 $ pb (dA, A, DV, dD)
166169 zero! (dD)
167- return NoRData (), NoRData (), NoRData ( )
170+ return n_NoRData ( 3 )
168171 end
169172 return output_codual, $ adj
170173 end
@@ -214,7 +217,7 @@ for f in (:eig, :eigh)
214217 copy! (DV[2 ], DVc[2 ])
215218 zero! (dD′)
216219 zero! (dV′)
217- return NoRData (), NoRData (), NoRData (), NoRData ( )
220+ return n_NoRData ( 4 )
218221 end
219222 return output_codual, $ f_adjoint!
220223 end
@@ -250,7 +253,7 @@ for f in (:eig, :eigh)
250253 copy! (A, Ac)
251254 copy! .(DV, DVc)
252255
253- return ntuple ( Returns ( NoRData ()), 4 )
256+ return n_NoRData ( 4 )
254257 end
255258
256259 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -274,7 +277,7 @@ for f in (:eig, :eigh)
274277 $ f_trunc_pullback! (dA, A, (D, V), (dD, dV))
275278 zero! (dD)
276279 zero! (dV)
277- return NoRData (), NoRData (), NoRData ( )
280+ return n_NoRData ( 3 )
278281 end
279282 return output_codual, $ f_adjoint!
280283 end
@@ -297,7 +300,7 @@ for f in (:eig, :eigh)
297300 _warn_pullback_truncerror (dϵ)
298301 $ f_pullback! (dA, A, DV, dDVtrunc, ind)
299302 zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
300- return ntuple ( Returns ( NoRData ()), 3 )
303+ return n_NoRData ( 3 )
301304 end
302305
303306 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -329,7 +332,7 @@ for f in (:eig, :eigh)
329332 copy! (DV[2 ], DVc[2 ])
330333 zero! (dD′)
331334 zero! (dV′)
332- return NoRData (), NoRData (), NoRData (), NoRData ( )
335+ return n_NoRData ( 4 )
333336 end
334337 return output_codual, $ f_adjoint!
335338 end
@@ -362,7 +365,7 @@ for f in (:eig, :eigh)
362365 copy! (A, Ac)
363366 copy! .(DV, DVc)
364367
365- return ntuple ( Returns ( NoRData ()), 4 )
368+ return n_NoRData ( 4 )
366369 end
367370
368371 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -385,7 +388,7 @@ for f in (:eig, :eigh)
385388 $ f_trunc_pullback! (dA, A, (D, V), (dD, dV))
386389 zero! (dD)
387390 zero! (dV)
388- return NoRData (), NoRData (), NoRData ( )
391+ return n_NoRData ( 3 )
389392 end
390393 return output_codual, $ f_adjoint!
391394 end
@@ -406,7 +409,7 @@ for f in (:eig, :eigh)
406409 function $f_adjoint! (:: NoRData )
407410 $ f_pullback! (dA, A, DV, dDVtrunc, ind)
408411 zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
409- return ntuple ( Returns ( NoRData ()), 3 )
412+ return n_NoRData ( 3 )
410413 end
411414
412415 return DVtrunc_dDVtrunc, $ f_adjoint!
@@ -450,7 +453,7 @@ for (f!, f) in (
450453 zero! (dU)
451454 zero! (dS)
452455 zero! (dVᴴ)
453- return NoRData (), NoRData (), NoRData (), NoRData ( )
456+ return n_NoRData ( 4 )
454457 end
455458 return CoDual (output, dUSVᴴ), svd_adjoint
456459 end
@@ -484,7 +487,7 @@ for (f!, f) in (
484487 zero! (dU)
485488 zero! (dS)
486489 zero! (dVᴴ)
487- return NoRData (), NoRData (), NoRData ( )
490+ return n_NoRData ( 3 )
488491 end
489492 return USVᴴ_codual, svd_adjoint
490493 end
@@ -503,7 +506,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
503506 svd_vals_pullback! (dA, A, USVᴴ, dS)
504507 zero! (dS)
505508 copy! (S, Sc)
506- return NoRData (), NoRData (), NoRData (), NoRData ( )
509+ return n_NoRData ( 4 )
507510 end
508511 return S_dS, svd_vals_adjoint
509512end
@@ -523,7 +526,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
523526 S, dS = arrayify (S_codual)
524527 svd_vals_pullback! (dA, A, USVᴴ, dS)
525528 zero! (dS)
526- return NoRData (), NoRData (), NoRData ( )
529+ return n_NoRData ( 3 )
527530 end
528531 return S_codual, svd_vals_adjoint
529532end
@@ -564,7 +567,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
564567 zero! (dU′)
565568 zero! (dS′)
566569 zero! (dVᴴ′)
567- return NoRData (), NoRData (), NoRData ( )
570+ return n_NoRData ( 3 )
568571 end
569572 return output_codual, svd_trunc_adjoint
570573end
@@ -602,7 +605,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
602605 copy! (A, Ac)
603606 copy! .(USVᴴ, USVᴴc)
604607
605- return ntuple ( Returns ( NoRData ()), 4 )
608+ return n_NoRData ( 4 )
606609 end
607610
608611 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -630,7 +633,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
630633 zero! (dU)
631634 zero! (dS)
632635 zero! (dVᴴ)
633- return NoRData (), NoRData (), NoRData ( )
636+ return n_NoRData ( 3 )
634637 end
635638 return output_codual, svd_trunc_adjoint
636639end
@@ -653,7 +656,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
653656 _warn_pullback_truncerror (dϵ)
654657 svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
655658 zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
656- return ntuple ( Returns ( NoRData ()), 3 )
659+ return n_NoRData ( 3 )
657660 end
658661
659662 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -694,7 +697,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
694697 zero! (dU′)
695698 zero! (dS′)
696699 zero! (dVᴴ′)
697- return NoRData (), NoRData (), NoRData ( )
700+ return n_NoRData ( 3 )
698701 end
699702 return output_codual, svd_trunc_adjoint
700703end
@@ -729,7 +732,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
729732 copy! (A, Ac)
730733 copy! .(USVᴴ, USVᴴc)
731734
732- return ntuple ( Returns ( NoRData ()), 4 )
735+ return n_NoRData ( 4 )
733736 end
734737
735738 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -756,7 +759,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
756759 zero! (dU)
757760 zero! (dS)
758761 zero! (dVᴴ)
759- return NoRData (), NoRData (), NoRData ( )
762+ return n_NoRData ( 3 )
760763 end
761764 return output_codual, svd_trunc_adjoint
762765end
@@ -777,7 +780,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
777780 function svd_trunc_adjoint (:: NoRData )
778781 svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
779782 zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
780- return ntuple ( Returns ( NoRData ()), 3 )
783+ return n_NoRData ( 3 )
781784 end
782785
783786 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
0 commit comments