@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33using Mooncake
44using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55using MatrixAlgebraKit
6- using MatrixAlgebraKit: inv_safe, diagview, copy_input
6+ using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2626 return CoDual (Ac, dAc), copy_input_pb
2727end
2828
29+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (initialize_output), Any, Any, Any}
30+ function Mooncake. rrule!! (:: CoDual{typeof(initialize_output)} , f_df:: CoDual , A_dA:: CoDual , alg_dalg:: CoDual )
31+ output = initialize_output (Mooncake. primal (f_df), Mooncake. primal (A_dA), Mooncake. primal (alg_dalg))
32+ doutput = Mooncake. zero_tangent (output)
33+ function initialize_output_pb (:: NoRData )
34+ return NoRData (), NoRData (), NoRData (), NoRData ()
35+ end
36+ return CoDual (output, doutput), initialize_output_pb
37+ end
38+
39+
2940# two-argument in-place factorizations like LQ, QR, EIG
3041for (f!, f, pb, adj) in (
3142 (:qr_full! , :qr_full , :qr_pullback! , :qr_adjoint ),
0 commit comments