11module MatrixAlgebraKitMooncakeExt
22
33using Mooncake
4- using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
4+ using Mooncake: CoDual, Dual, NoRData, rrule!!, frule!!, arrayify
55using MatrixAlgebraKit
66using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero!
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
@@ -15,7 +15,11 @@ using LinearAlgebra
1515
1616Mooncake. tangent_type (:: Type{<:AbstractAlgorithm} ) = Mooncake. NoTangent
1717
18- @is_primitive DefaultCtx Mooncake. ReverseMode Tuple{typeof (copy_input), Any, Any}
18+ macro is_rev_primitive (sig)
19+ return esc (:(Mooncake. @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode $ sig))
20+ end
21+
22+ @is_rev_primitive Tuple{typeof (copy_input), Any, Any}
1923function Mooncake. rrule!! (:: CoDual{typeof(copy_input)} , f_df:: CoDual , A_dA:: CoDual )
2024 Ac = copy_input (Mooncake. primal (f_df), Mooncake. primal (A_dA))
2125 Ac_dAc = Mooncake. zero_fcodual (Ac)
@@ -41,7 +45,7 @@ for (f!, f, pb, adj) in (
4145 )
4246
4347 @eval begin
44- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any }, AbstractAlgorithm}
48+ @is_rev_primitive Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any }, AbstractAlgorithm}
4549 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
4650 A, dA = arrayify (A_dA)
4751 args = Mooncake. primal (args_dargs)
@@ -63,7 +67,7 @@ for (f!, f, pb, adj) in (
6367 end
6468 return args_dargs, $ adj
6569 end
66- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
70+ @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
6771 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
6872 A, dA = arrayify (A_dA)
6973 output = $ f (A, Mooncake. primal (alg_dalg))
@@ -92,7 +96,7 @@ for (f!, f, pb, adj) in (
9296 (:lq_null! , :lq_null , :lq_null_pullback! , :lq_null_adjoint ),
9397 )
9498 @eval begin
95- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
99+ @is_rev_primitive Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
96100 function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual , arg_darg:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
97101 A, dA = arrayify (A_dA)
98102 Ac = copy (A)
@@ -108,7 +112,7 @@ for (f!, f, pb, adj) in (
108112 end
109113 return arg_darg, $ adj
110114 end
111- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
115+ @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
112116 function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
113117 A, dA = arrayify (A_dA)
114118 output = $ f (A, Mooncake. primal (alg_dalg))
@@ -129,7 +133,7 @@ for (f!, f, f_full, pb, adj) in (
129133 (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_vals_pullback! , :eigh_vals_adjoint ),
130134 )
131135 @eval begin
132- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
136+ @is_rev_primitive Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
133137 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
134138 # compute primal
135139 A, dA = arrayify (A_dA)
@@ -147,7 +151,7 @@ for (f!, f, f_full, pb, adj) in (
147151 end
148152 return D_dD, $ adj
149153 end
150- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
154+ @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
151155 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
152156 # compute primal
153157 A, dA = arrayify (A_dA)
@@ -182,8 +186,8 @@ for f in (:eig, :eigh)
182186 f_trunc_no_error! = Symbol (f_trunc_no_error, :! )
183187
184188 @eval begin
185- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f_trunc!), Any, Any, AbstractAlgorithm}
186- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f_trunc), Any, AbstractAlgorithm}
189+ @is_rev_primitive Tuple{typeof ($ f_trunc!), Any, Any, AbstractAlgorithm}
190+ @is_rev_primitive Tuple{typeof ($ f_trunc), Any, AbstractAlgorithm}
187191 function Mooncake. rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
188192 # compute primal
189193 A, dA = arrayify (A_dA)
@@ -298,8 +302,8 @@ for f in (:eig, :eigh)
298302
299303 return DVtrunc_dDVtrunc, $ f_adjoint!
300304 end
301- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f_trunc_no_error), Any, AbstractAlgorithm}
305+ @is_rev_primitive Tuple{typeof ($ f_trunc_no_error!), Any, Any, AbstractAlgorithm}
306+ @is_rev_primitive Tuple{typeof ($ f_trunc_no_error), Any, AbstractAlgorithm}
303307 function Mooncake. rrule!! (:: CoDual{typeof($f_trunc_no_error!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
304308 # compute primal
305309 A, dA = arrayify (A_dA)
@@ -415,7 +419,7 @@ for (f!, f) in (
415419 (:svd_compact! , :svd_compact ),
416420 )
417421 @eval begin
418- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any , <: Any }, AbstractAlgorithm}
422+ @is_rev_primitive Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any , <: Any }, AbstractAlgorithm}
419423 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
420424 A, dA = arrayify (A_dA)
421425 Ac = copy (A)
@@ -450,7 +454,7 @@ for (f!, f) in (
450454 end
451455 return CoDual (output, dUSVᴴ), svd_adjoint
452456 end
453- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
457+ @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
454458 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
455459 A, dA = arrayify (A_dA)
456460 USVᴴ = $ f (A, Mooncake. primal (alg_dalg))
@@ -487,7 +491,7 @@ for (f!, f) in (
487491 end
488492end
489493
490- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_vals!), Any, Any, AbstractAlgorithm}
494+ @is_rev_primitive Tuple{typeof (svd_vals!), Any, Any, AbstractAlgorithm}
491495function Mooncake. rrule!! (:: CoDual{typeof(svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual )
492496 # compute primal
493497 A, dA = arrayify (A_dA)
@@ -504,7 +508,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
504508 return S_dS, svd_vals_adjoint
505509end
506510
507- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_vals), Any, AbstractAlgorithm}
511+ @is_rev_primitive Tuple{typeof (svd_vals), Any, AbstractAlgorithm}
508512function Mooncake. rrule!! (:: CoDual{typeof(svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
509513 # compute primal
510514 A, dA = arrayify (A_dA)
@@ -524,7 +528,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
524528 return S_codual, svd_vals_adjoint
525529end
526530
527- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_trunc!), Any, Any, AbstractAlgorithm}
531+ @is_rev_primitive Tuple{typeof (svd_trunc!), Any, Any, AbstractAlgorithm}
528532function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
529533 # compute primal
530534 A, dA = arrayify (A_dA)
@@ -604,7 +608,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
604608 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
605609end
606610
607- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_trunc), Any, AbstractAlgorithm}
611+ @is_rev_primitive Tuple{typeof (svd_trunc), Any, AbstractAlgorithm}
608612function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
609613 # compute primal
610614 A, dA = arrayify (A_dA)
@@ -655,7 +659,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655659 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
656660end
657661
658- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
662+ @is_rev_primitive Tuple{typeof (svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
659663function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
660664 # compute primal
661665 A, dA = arrayify (A_dA)
@@ -731,7 +735,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
731735 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
732736end
733737
734- @is_primitive DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_trunc_no_error), Any, AbstractAlgorithm}
738+ @is_rev_primitive Tuple{typeof (svd_trunc_no_error), Any, AbstractAlgorithm}
735739function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
736740 # compute primal
737741 A, dA = arrayify (A_dA)
0 commit comments