Skip to content

Commit 6f99876

Browse files
committed
Custom rule for initialize_output
1 parent 045b79d commit 6f99876

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2626
return CoDual(Ac, dAc), copy_input_pb
2727
end
2828

29+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
30+
function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual)
31+
output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg))
32+
doutput = Mooncake.zero_tangent(output)
33+
function initialize_output_pb(::NoRData)
34+
return NoRData(), NoRData(), NoRData(), NoRData()
35+
end
36+
return CoDual(output, doutput), initialize_output_pb
37+
end
38+
39+
2940
# two-argument in-place factorizations like LQ, QR, EIG
3041
for (f!, f, pb, adj) in (
3142
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

0 commit comments

Comments
 (0)