Skip to content

Commit b98491f

Browse files
committed
use imported DefaultCtx
1 parent b8e08d6 commit b98491f

1 file changed

Lines changed: 20 additions & 20 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using LinearAlgebra
1515

1616
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
1717

18-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
18+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
1919
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2020
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
2121
Ac_dAc = Mooncake.zero_fcodual(Ac)
@@ -27,7 +27,7 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2727
return Ac_dAc, copy_input_pb
2828
end
2929

30-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
30+
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
3131
# two-argument in-place factorizations like LQ, QR, EIG
3232
for (f!, f, pb, adj) in (
3333
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
@@ -41,7 +41,7 @@ for (f!, f, pb, adj) in (
4141
)
4242

4343
@eval begin
44-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, AbstractAlgorithm}
44+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, AbstractAlgorithm}
4545
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
4646
A, dA = arrayify(A_dA)
4747
args = Mooncake.primal(args_dargs)
@@ -63,7 +63,7 @@ for (f!, f, pb, adj) in (
6363
end
6464
return args_dargs, $adj
6565
end
66-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
66+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
6767
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
6868
A, dA = arrayify(A_dA)
6969
output = $f(A, Mooncake.primal(alg_dalg))
@@ -92,7 +92,7 @@ for (f!, f, pb, adj) in (
9292
(:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint),
9393
)
9494
@eval begin
95-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
95+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
9696
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
9797
A, dA = arrayify(A_dA)
9898
Ac = copy(A)
@@ -108,7 +108,7 @@ for (f!, f, pb, adj) in (
108108
end
109109
return arg_darg, $adj
110110
end
111-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
111+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
112112
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
113113
A, dA = arrayify(A_dA)
114114
output = $f(A, Mooncake.primal(alg_dalg))
@@ -129,7 +129,7 @@ for (f!, f, f_full, pb, adj) in (
129129
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
130130
)
131131
@eval begin
132-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
132+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
133133
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
134134
# compute primal
135135
A, dA = arrayify(A_dA)
@@ -147,7 +147,7 @@ for (f!, f, f_full, pb, adj) in (
147147
end
148148
return D_dD, $adj
149149
end
150-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
150+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
151151
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
152152
# compute primal
153153
A, dA = arrayify(A_dA)
@@ -182,8 +182,8 @@ for f in (:eig, :eigh)
182182
f_trunc_no_error! = Symbol(f_trunc_no_error, :!)
183183

184184
@eval begin
185-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
186-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
185+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
186+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
187187
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
188188
# compute primal
189189
A, dA = arrayify(A_dA)
@@ -298,8 +298,8 @@ for f in (:eig, :eigh)
298298

299299
return DVtrunc_dDVtrunc, $f_adjoint!
300300
end
301-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
301+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
303303
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
304304
# compute primal
305305
A, dA = arrayify(A_dA)
@@ -415,7 +415,7 @@ for (f!, f) in (
415415
(:svd_compact!, :svd_compact),
416416
)
417417
@eval begin
418-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, AbstractAlgorithm}
418+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, AbstractAlgorithm}
419419
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
420420
A, dA = arrayify(A_dA)
421421
Ac = copy(A)
@@ -450,7 +450,7 @@ for (f!, f) in (
450450
end
451451
return CoDual(output, dUSVᴴ), svd_adjoint
452452
end
453-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
453+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
454454
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
455455
A, dA = arrayify(A_dA)
456456
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
@@ -487,7 +487,7 @@ for (f!, f) in (
487487
end
488488
end
489489

490-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, AbstractAlgorithm}
490+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, AbstractAlgorithm}
491491
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
492492
# compute primal
493493
A, dA = arrayify(A_dA)
@@ -504,7 +504,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
504504
return S_dS, svd_vals_adjoint
505505
end
506506

507-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, AbstractAlgorithm}
507+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, AbstractAlgorithm}
508508
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
509509
# compute primal
510510
A, dA = arrayify(A_dA)
@@ -524,7 +524,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
524524
return S_codual, svd_vals_adjoint
525525
end
526526

527-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, AbstractAlgorithm}
527+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, AbstractAlgorithm}
528528
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
529529
# compute primal
530530
A, dA = arrayify(A_dA)
@@ -604,7 +604,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
604604
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
605605
end
606606

607-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, AbstractAlgorithm}
607+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, AbstractAlgorithm}
608608
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
609609
# compute primal
610610
A, dA = arrayify(A_dA)
@@ -655,7 +655,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655655
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
656656
end
657657

658-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
658+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
659659
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
660660
# compute primal
661661
A, dA = arrayify(A_dA)
@@ -731,7 +731,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
731731
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
732732
end
733733

734-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, AbstractAlgorithm}
734+
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, AbstractAlgorithm}
735735
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
736736
# compute primal
737737
A, dA = arrayify(A_dA)

0 commit comments

Comments
 (0)