@@ -10,10 +10,10 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
1010using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1111using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1212using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
13- using MatrixAlgebraKit: TruncatedAlgorithm
13+ using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm
1414using LinearAlgebra
1515
16- Mooncake. tangent_type (:: Type{<:MatrixAlgebraKit. AbstractAlgorithm} ) = Mooncake. NoTangent
16+ Mooncake. tangent_type (:: Type{<:AbstractAlgorithm} ) = Mooncake. NoTangent
1717
1818@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (copy_input), Any, Any}
1919function Mooncake. rrule!! (:: CoDual{typeof(copy_input)} , f_df:: CoDual , A_dA:: CoDual )
@@ -41,8 +41,8 @@ for (f!, f, pb, adj) in (
4141 )
4242
4343 @eval begin
44- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any }, MatrixAlgebraKit . AbstractAlgorithm}
45- function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit. AbstractAlgorithm} )
44+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any }, AbstractAlgorithm}
45+ function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
4646 A, dA = arrayify (A_dA)
4747 args = Mooncake. primal (args_dargs)
4848 dargs = Mooncake. tangent (args_dargs)
@@ -63,8 +63,8 @@ for (f!, f, pb, adj) in (
6363 end
6464 return args_dargs, $ adj
6565 end
66- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit . AbstractAlgorithm}
67- function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit. AbstractAlgorithm} )
66+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
67+ function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
6868 A, dA = arrayify (A_dA)
6969 output = $ f (A, Mooncake. primal (alg_dalg))
7070 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -92,8 +92,8 @@ for (f!, f, pb, adj) in (
9292 (:lq_null! , :lq_null , :lq_null_pullback! , :lq_null_adjoint ),
9393 )
9494 @eval begin
95- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
96- function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual , arg_darg:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit. AbstractAlgorithm} )
95+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
96+ function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual , arg_darg:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
9797 A, dA = arrayify (A_dA)
9898 Ac = copy (A)
9999 arg, darg = arrayify (arg_darg)
@@ -108,8 +108,8 @@ for (f!, f, pb, adj) in (
108108 end
109109 return arg_darg, $ adj
110110 end
111- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit . AbstractAlgorithm}
112- function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit. AbstractAlgorithm} )
111+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
112+ function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
113113 A, dA = arrayify (A_dA)
114114 output = $ f (A, Mooncake. primal (alg_dalg))
115115 output_codual = CoDual (output, Mooncake. zero_tangent (output))
@@ -129,7 +129,7 @@ for (f!, f, f_full, pb, adj) in (
129129 (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_vals_pullback! , :eigh_vals_adjoint ),
130130 )
131131 @eval begin
132- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
132+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
133133 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
134134 # compute primal
135135 A, dA = arrayify (A_dA)
@@ -147,7 +147,7 @@ for (f!, f, f_full, pb, adj) in (
147147 end
148148 return D_dD, $ adj
149149 end
150- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit . AbstractAlgorithm}
150+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
151151 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
152152 # compute primal
153153 A, dA = arrayify (A_dA)
@@ -182,8 +182,8 @@ for f in (:eig, :eigh)
182182 f_trunc_no_error! = Symbol (f_trunc_no_error, :! )
183183
184184 @eval begin
185- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
186- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc), Any, MatrixAlgebraKit . AbstractAlgorithm}
185+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc!), Any, Any, AbstractAlgorithm}
186+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc), Any, AbstractAlgorithm}
187187 function Mooncake. rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
188188 # compute primal
189189 A, dA = arrayify (A_dA)
@@ -298,8 +298,8 @@ for f in (:eig, :eigh)
298298
299299 return DVtrunc_dDVtrunc, $ f_adjoint!
300300 end
301- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
302- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error), Any, MatrixAlgebraKit . AbstractAlgorithm}
301+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error!), Any, Any, AbstractAlgorithm}
302+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_trunc_no_error), Any, AbstractAlgorithm}
303303 function Mooncake. rrule!! (:: CoDual{typeof($f_trunc_no_error!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
304304 # compute primal
305305 A, dA = arrayify (A_dA)
@@ -415,7 +415,7 @@ for (f!, f) in (
415415 (:svd_compact! , :svd_compact ),
416416 )
417417 @eval begin
418- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any , <: Any }, MatrixAlgebraKit . AbstractAlgorithm}
418+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any , <: Any }, AbstractAlgorithm}
419419 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
420420 A, dA = arrayify (A_dA)
421421 Ac = copy (A)
@@ -450,7 +450,7 @@ for (f!, f) in (
450450 end
451451 return CoDual (output, dUSVᴴ), svd_adjoint
452452 end
453- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit . AbstractAlgorithm}
453+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, AbstractAlgorithm}
454454 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
455455 A, dA = arrayify (A_dA)
456456 USVᴴ = $ f (A, Mooncake. primal (alg_dalg))
@@ -487,7 +487,7 @@ for (f!, f) in (
487487 end
488488end
489489
490- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_vals!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
490+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_vals!), Any, Any, AbstractAlgorithm}
491491function Mooncake. rrule!! (:: CoDual{typeof(svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual )
492492 # compute primal
493493 A, dA = arrayify (A_dA)
@@ -504,7 +504,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
504504 return S_dS, svd_vals_adjoint
505505end
506506
507- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_vals), Any, MatrixAlgebraKit . AbstractAlgorithm}
507+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_vals), Any, AbstractAlgorithm}
508508function Mooncake. rrule!! (:: CoDual{typeof(svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
509509 # compute primal
510510 A, dA = arrayify (A_dA)
@@ -524,7 +524,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
524524 return S_codual, svd_vals_adjoint
525525end
526526
527- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
527+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc!), Any, Any, AbstractAlgorithm}
528528function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
529529 # compute primal
530530 A, dA = arrayify (A_dA)
@@ -604,7 +604,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
604604 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
605605end
606606
607- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc), Any, MatrixAlgebraKit . AbstractAlgorithm}
607+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc), Any, AbstractAlgorithm}
608608function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
609609 # compute primal
610610 A, dA = arrayify (A_dA)
@@ -655,7 +655,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
655655 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
656656end
657657
658- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, MatrixAlgebraKit . AbstractAlgorithm}
658+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, AbstractAlgorithm}
659659function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
660660 # compute primal
661661 A, dA = arrayify (A_dA)
@@ -731,7 +731,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
731731 return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
732732end
733733
734- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error), Any, MatrixAlgebraKit . AbstractAlgorithm}
734+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error), Any, AbstractAlgorithm}
735735function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
736736 # compute primal
737737 A, dA = arrayify (A_dA)
0 commit comments