Skip to content

Commit b8e08d6

Browse files
committed
import AbstractAlgorithm
1 parent a349aef commit b8e08d6

1 file changed

Lines changed: 24 additions & 24 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13-
using MatrixAlgebraKit: TruncatedAlgorithm
13+
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm
1414
using LinearAlgebra
1515

16-
Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent
16+
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
1717

1818
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
1919
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
@@ -41,8 +41,8 @@ for (f!, f, pb, adj) in (
4141
)
4242

4343
@eval begin
44-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
45-
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
44+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, AbstractAlgorithm}
45+
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
4646
A, dA = arrayify(A_dA)
4747
args = Mooncake.primal(args_dargs)
4848
dargs = Mooncake.tangent(args_dargs)
@@ -63,8 +63,8 @@ for (f!, f, pb, adj) in (
6363
end
6464
return args_dargs, $adj
6565
end
66-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
67-
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
66+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
67+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
6868
A, dA = arrayify(A_dA)
6969
output = $f(A, Mooncake.primal(alg_dalg))
7070
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -92,8 +92,8 @@ for (f!, f, pb, adj) in (
9292
(:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint),
9393
)
9494
@eval begin
95-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
96-
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
95+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
96+
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
9797
A, dA = arrayify(A_dA)
9898
Ac = copy(A)
9999
arg, darg = arrayify(arg_darg)
@@ -108,8 +108,8 @@ for (f!, f, pb, adj) in (
108108
end
109109
return arg_darg, $adj
110110
end
111-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
112-
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
111+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
112+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
113113
A, dA = arrayify(A_dA)
114114
output = $f(A, Mooncake.primal(alg_dalg))
115115
output_codual = CoDual(output, Mooncake.zero_tangent(output))
@@ -129,7 +129,7 @@ for (f!, f, f_full, pb, adj) in (
129129
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
130130
)
131131
@eval begin
132-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
132+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
133133
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
134134
# compute primal
135135
A, dA = arrayify(A_dA)
@@ -147,7 +147,7 @@ for (f!, f, f_full, pb, adj) in (
147147
end
148148
return D_dD, $adj
149149
end
150-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
150+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
151151
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
152152
# compute primal
153153
A, dA = arrayify(A_dA)
@@ -182,8 +182,8 @@ for f in (:eig, :eigh)
182182
f_trunc_no_error! = Symbol(f_trunc_no_error, :!)
183183

184184
@eval begin
185-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
186-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
185+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
186+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
187187
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
188188
# compute primal
189189
A, dA = arrayify(A_dA)
@@ -298,8 +298,8 @@ for f in (:eig, :eigh)
298298

299299
return DVtrunc_dDVtrunc, $f_adjoint!
300300
end
301-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
302-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
301+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
303303
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
304304
# compute primal
305305
A, dA = arrayify(A_dA)
@@ -415,7 +415,7 @@ for (f!, f) in (
415415
(:svd_compact!, :svd_compact),
416416
)
417417
@eval begin
418-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
418+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, AbstractAlgorithm}
419419
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
420420
A, dA = arrayify(A_dA)
421421
Ac = copy(A)
@@ -450,7 +450,7 @@ for (f!, f) in (
450450
end
451451
return CoDual(output, dUSVᴴ), svd_adjoint
452452
end
453-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
453+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, AbstractAlgorithm}
454454
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
455455
A, dA = arrayify(A_dA)
456456
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
@@ -487,7 +487,7 @@ for (f!, f) in (
487487
end
488488
end
489489

490-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
490+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, AbstractAlgorithm}
491491
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
492492
# compute primal
493493
A, dA = arrayify(A_dA)
@@ -504,7 +504,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
504504
return S_dS, svd_vals_adjoint
505505
end
506506

507-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
507+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, AbstractAlgorithm}
508508
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
509509
# compute primal
510510
A, dA = arrayify(A_dA)
@@ -524,7 +524,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
524524
return S_codual, svd_vals_adjoint
525525
end
526526

527-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
527+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, AbstractAlgorithm}
528528
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
529529
# compute primal
530530
A, dA = arrayify(A_dA)
@@ -604,7 +604,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
604604
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
605605
end
606606

607-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
607+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, AbstractAlgorithm}
608608
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
609609
# compute primal
610610
A, dA = arrayify(A_dA)
@@ -655,7 +655,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655655
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
656656
end
657657

658-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
658+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
659659
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
660660
# compute primal
661661
A, dA = arrayify(A_dA)
@@ -731,7 +731,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
731731
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
732732
end
733733

734-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
734+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, AbstractAlgorithm}
735735
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
736736
# compute primal
737737
A, dA = arrayify(A_dA)

0 commit comments

Comments
 (0)