Skip to content

Commit 37fb719

Browse files
committed
add utility @is_rev_primitive
1 parent b98491f commit 37fb719

1 file changed

Lines changed: 24 additions & 20 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MatrixAlgebraKitMooncakeExt
22

33
using Mooncake
4-
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
4+
using Mooncake: CoDual, Dual, NoRData, rrule!!, frule!!, arrayify
55
using MatrixAlgebraKit
66
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero!
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
@@ -15,7 +15,11 @@ using LinearAlgebra
1515

1616
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
1717

18-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
18+
macro is_rev_primitive(sig)
19+
return esc(:(Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode $sig))
20+
end
21+
22+
@is_rev_primitive Tuple{typeof(copy_input), Any, Any}
1923
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2024
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
2125
Ac_dAc = Mooncake.zero_fcodual(Ac)
@@ -41,7 +45,7 @@ for (f!, f, pb, adj) in (
4145
)
4246

4347
@eval begin
44-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, AbstractAlgorithm}
48+
@is_rev_primitive Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, AbstractAlgorithm}
4549
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
4650
A, dA = arrayify(A_dA)
4751
args = Mooncake.primal(args_dargs)
@@ -63,7 +67,7 @@ for (f!, f, pb, adj) in (
6367
end
6468
return args_dargs, $adj
6569
end
66-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
70+
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
6771
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
6872
A, dA = arrayify(A_dA)
6973
output = $f(A, Mooncake.primal(alg_dalg))
@@ -92,7 +96,7 @@ for (f!, f, pb, adj) in (
9296
(:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint),
9397
)
9498
@eval begin
95-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
99+
@is_rev_primitive Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
96100
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
97101
A, dA = arrayify(A_dA)
98102
Ac = copy(A)
@@ -108,7 +112,7 @@ for (f!, f, pb, adj) in (
108112
end
109113
return arg_darg, $adj
110114
end
111-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
115+
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
112116
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
113117
A, dA = arrayify(A_dA)
114118
output = $f(A, Mooncake.primal(alg_dalg))
@@ -129,7 +133,7 @@ for (f!, f, f_full, pb, adj) in (
129133
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
130134
)
131135
@eval begin
132-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
136+
@is_rev_primitive Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
133137
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
134138
# compute primal
135139
A, dA = arrayify(A_dA)
@@ -147,7 +151,7 @@ for (f!, f, f_full, pb, adj) in (
147151
end
148152
return D_dD, $adj
149153
end
150-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
154+
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
151155
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
152156
# compute primal
153157
A, dA = arrayify(A_dA)
@@ -182,8 +186,8 @@ for f in (:eig, :eigh)
182186
f_trunc_no_error! = Symbol(f_trunc_no_error, :!)
183187

184188
@eval begin
185-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
186-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
189+
@is_rev_primitive Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
190+
@is_rev_primitive Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
187191
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
188192
# compute primal
189193
A, dA = arrayify(A_dA)
@@ -298,8 +302,8 @@ for f in (:eig, :eigh)
298302

299303
return DVtrunc_dDVtrunc, $f_adjoint!
300304
end
301-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
305+
@is_rev_primitive Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
306+
@is_rev_primitive Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
303307
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
304308
# compute primal
305309
A, dA = arrayify(A_dA)
@@ -415,7 +419,7 @@ for (f!, f) in (
415419
(:svd_compact!, :svd_compact),
416420
)
417421
@eval begin
418-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, AbstractAlgorithm}
422+
@is_rev_primitive Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, AbstractAlgorithm}
419423
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
420424
A, dA = arrayify(A_dA)
421425
Ac = copy(A)
@@ -450,7 +454,7 @@ for (f!, f) in (
450454
end
451455
return CoDual(output, dUSVᴴ), svd_adjoint
452456
end
453-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
457+
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
454458
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
455459
A, dA = arrayify(A_dA)
456460
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
@@ -487,7 +491,7 @@ for (f!, f) in (
487491
end
488492
end
489493

490-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, AbstractAlgorithm}
494+
@is_rev_primitive Tuple{typeof(svd_vals!), Any, Any, AbstractAlgorithm}
491495
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
492496
# compute primal
493497
A, dA = arrayify(A_dA)
@@ -504,7 +508,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
504508
return S_dS, svd_vals_adjoint
505509
end
506510

507-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, AbstractAlgorithm}
511+
@is_rev_primitive Tuple{typeof(svd_vals), Any, AbstractAlgorithm}
508512
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
509513
# compute primal
510514
A, dA = arrayify(A_dA)
@@ -524,7 +528,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
524528
return S_codual, svd_vals_adjoint
525529
end
526530

527-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, AbstractAlgorithm}
531+
@is_rev_primitive Tuple{typeof(svd_trunc!), Any, Any, AbstractAlgorithm}
528532
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
529533
# compute primal
530534
A, dA = arrayify(A_dA)
@@ -604,7 +608,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
604608
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
605609
end
606610

607-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, AbstractAlgorithm}
611+
@is_rev_primitive Tuple{typeof(svd_trunc), Any, AbstractAlgorithm}
608612
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
609613
# compute primal
610614
A, dA = arrayify(A_dA)
@@ -655,7 +659,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655659
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
656660
end
657661

658-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
662+
@is_rev_primitive Tuple{typeof(svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
659663
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
660664
# compute primal
661665
A, dA = arrayify(A_dA)
@@ -731,7 +735,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
731735
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
732736
end
733737

734-
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, AbstractAlgorithm}
738+
@is_rev_primitive Tuple{typeof(svd_trunc_no_error), Any, AbstractAlgorithm}
735739
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
736740
# compute primal
737741
A, dA = arrayify(A_dA)

0 commit comments

Comments
 (0)