Skip to content

Commit 7f6f16d

Browse files
committed
n_NoRData helper
1 parent 37fb719 commit 7f6f16d

1 file changed

Lines changed: 30 additions & 27 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ macro is_rev_primitive(sig)
1919
return esc(:(Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode $sig))
2020
end
2121

22+
# return n copies of NoRData()
23+
@inline n_NoRData(n) = ntuple(Returns(NoRData()), n)
24+
2225
@is_rev_primitive Tuple{typeof(copy_input), Any, Any}
2326
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2427
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
2528
Ac_dAc = Mooncake.zero_fcodual(Ac)
2629
dAc = Mooncake.tangent(Ac_dAc)
2730
function copy_input_pb(::NoRData)
2831
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
29-
return NoRData(), NoRData(), NoRData()
32+
return n_NoRData(3)
3033
end
3134
return Ac_dAc, copy_input_pb
3235
end
@@ -63,7 +66,7 @@ for (f!, f, pb, adj) in (
6366
copy!(arg2, arg2c)
6467
zero!(darg1)
6568
zero!(darg2)
66-
return NoRData(), NoRData(), NoRData(), NoRData()
69+
return n_NoRData(4)
6770
end
6871
return args_dargs, $adj
6972
end
@@ -84,7 +87,7 @@ for (f!, f, pb, adj) in (
8487
$pb(dA, A, (arg1, arg2), (darg1, darg2))
8588
zero!(darg1)
8689
zero!(darg2)
87-
return NoRData(), NoRData(), NoRData()
90+
return n_NoRData(3)
8891
end
8992
return output_codual, $adj
9093
end
@@ -108,7 +111,7 @@ for (f!, f, pb, adj) in (
108111
$pb(dA, A, arg, darg)
109112
copy!(arg, argc)
110113
zero!(darg)
111-
return NoRData(), NoRData(), NoRData(), NoRData()
114+
return n_NoRData(4)
112115
end
113116
return arg_darg, $adj
114117
end
@@ -121,7 +124,7 @@ for (f!, f, pb, adj) in (
121124
arg, darg = arrayify(output_codual)
122125
$pb(dA, A, arg, darg)
123126
zero!(darg)
124-
return NoRData(), NoRData(), NoRData()
127+
return n_NoRData(3)
125128
end
126129
return output_codual, $adj
127130
end
@@ -147,7 +150,7 @@ for (f!, f, f_full, pb, adj) in (
147150
$pb(dA, A, DV, dD)
148151
copy!(D, Dc)
149152
zero!(dD)
150-
return NoRData(), NoRData(), NoRData(), NoRData()
153+
return n_NoRData(4)
151154
end
152155
return D_dD, $adj
153156
end
@@ -164,7 +167,7 @@ for (f!, f, f_full, pb, adj) in (
164167
D, dD = arrayify(output_codual)
165168
$pb(dA, A, DV, dD)
166169
zero!(dD)
167-
return NoRData(), NoRData(), NoRData()
170+
return n_NoRData(3)
168171
end
169172
return output_codual, $adj
170173
end
@@ -214,7 +217,7 @@ for f in (:eig, :eigh)
214217
copy!(DV[2], DVc[2])
215218
zero!(dD′)
216219
zero!(dV′)
217-
return NoRData(), NoRData(), NoRData(), NoRData()
220+
return n_NoRData(4)
218221
end
219222
return output_codual, $f_adjoint!
220223
end
@@ -250,7 +253,7 @@ for f in (:eig, :eigh)
250253
copy!(A, Ac)
251254
copy!.(DV, DVc)
252255

253-
return ntuple(Returns(NoRData()), 4)
256+
return n_NoRData(4)
254257
end
255258

256259
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -274,7 +277,7 @@ for f in (:eig, :eigh)
274277
$f_trunc_pullback!(dA, A, (D, V), (dD, dV))
275278
zero!(dD)
276279
zero!(dV)
277-
return NoRData(), NoRData(), NoRData()
280+
return n_NoRData(3)
278281
end
279282
return output_codual, $f_adjoint!
280283
end
@@ -297,7 +300,7 @@ for f in (:eig, :eigh)
297300
_warn_pullback_truncerror(dϵ)
298301
$f_pullback!(dA, A, DV, dDVtrunc, ind)
299302
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
300-
return ntuple(Returns(NoRData()), 3)
303+
return n_NoRData(3)
301304
end
302305

303306
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -329,7 +332,7 @@ for f in (:eig, :eigh)
329332
copy!(DV[2], DVc[2])
330333
zero!(dD′)
331334
zero!(dV′)
332-
return NoRData(), NoRData(), NoRData(), NoRData()
335+
return n_NoRData(4)
333336
end
334337
return output_codual, $f_adjoint!
335338
end
@@ -362,7 +365,7 @@ for f in (:eig, :eigh)
362365
copy!(A, Ac)
363366
copy!.(DV, DVc)
364367

365-
return ntuple(Returns(NoRData()), 4)
368+
return n_NoRData(4)
366369
end
367370

368371
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -385,7 +388,7 @@ for f in (:eig, :eigh)
385388
$f_trunc_pullback!(dA, A, (D, V), (dD, dV))
386389
zero!(dD)
387390
zero!(dV)
388-
return NoRData(), NoRData(), NoRData()
391+
return n_NoRData(3)
389392
end
390393
return output_codual, $f_adjoint!
391394
end
@@ -406,7 +409,7 @@ for f in (:eig, :eigh)
406409
function $f_adjoint!(::NoRData)
407410
$f_pullback!(dA, A, DV, dDVtrunc, ind)
408411
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
409-
return ntuple(Returns(NoRData()), 3)
412+
return n_NoRData(3)
410413
end
411414

412415
return DVtrunc_dDVtrunc, $f_adjoint!
@@ -450,7 +453,7 @@ for (f!, f) in (
450453
zero!(dU)
451454
zero!(dS)
452455
zero!(dVᴴ)
453-
return NoRData(), NoRData(), NoRData(), NoRData()
456+
return n_NoRData(4)
454457
end
455458
return CoDual(output, dUSVᴴ), svd_adjoint
456459
end
@@ -484,7 +487,7 @@ for (f!, f) in (
484487
zero!(dU)
485488
zero!(dS)
486489
zero!(dVᴴ)
487-
return NoRData(), NoRData(), NoRData()
490+
return n_NoRData(3)
488491
end
489492
return USVᴴ_codual, svd_adjoint
490493
end
@@ -503,7 +506,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
503506
svd_vals_pullback!(dA, A, USVᴴ, dS)
504507
zero!(dS)
505508
copy!(S, Sc)
506-
return NoRData(), NoRData(), NoRData(), NoRData()
509+
return n_NoRData(4)
507510
end
508511
return S_dS, svd_vals_adjoint
509512
end
@@ -523,7 +526,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
523526
S, dS = arrayify(S_codual)
524527
svd_vals_pullback!(dA, A, USVᴴ, dS)
525528
zero!(dS)
526-
return NoRData(), NoRData(), NoRData()
529+
return n_NoRData(3)
527530
end
528531
return S_codual, svd_vals_adjoint
529532
end
@@ -564,7 +567,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
564567
zero!(dU′)
565568
zero!(dS′)
566569
zero!(dVᴴ′)
567-
return NoRData(), NoRData(), NoRData()
570+
return n_NoRData(3)
568571
end
569572
return output_codual, svd_trunc_adjoint
570573
end
@@ -602,7 +605,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
602605
copy!(A, Ac)
603606
copy!.(USVᴴ, USVᴴc)
604607

605-
return ntuple(Returns(NoRData()), 4)
608+
return n_NoRData(4)
606609
end
607610

608611
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -630,7 +633,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
630633
zero!(dU)
631634
zero!(dS)
632635
zero!(dVᴴ)
633-
return NoRData(), NoRData(), NoRData()
636+
return n_NoRData(3)
634637
end
635638
return output_codual, svd_trunc_adjoint
636639
end
@@ -653,7 +656,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
653656
_warn_pullback_truncerror(dϵ)
654657
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
655658
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
656-
return ntuple(Returns(NoRData()), 3)
659+
return n_NoRData(3)
657660
end
658661

659662
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -694,7 +697,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
694697
zero!(dU′)
695698
zero!(dS′)
696699
zero!(dVᴴ′)
697-
return NoRData(), NoRData(), NoRData()
700+
return n_NoRData(3)
698701
end
699702
return output_codual, svd_trunc_adjoint
700703
end
@@ -729,7 +732,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
729732
copy!(A, Ac)
730733
copy!.(USVᴴ, USVᴴc)
731734

732-
return ntuple(Returns(NoRData()), 4)
735+
return n_NoRData(4)
733736
end
734737

735738
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
@@ -756,7 +759,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
756759
zero!(dU)
757760
zero!(dS)
758761
zero!(dVᴴ)
759-
return NoRData(), NoRData(), NoRData()
762+
return n_NoRData(3)
760763
end
761764
return output_codual, svd_trunc_adjoint
762765
end
@@ -777,7 +780,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
777780
function svd_trunc_adjoint(::NoRData)
778781
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
779782
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
780-
return ntuple(Returns(NoRData()), 3)
783+
return n_NoRData(3)
781784
end
782785

783786
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint

0 commit comments

Comments
 (0)