@@ -65,7 +65,7 @@ make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(
6565
6666# no `alg` argument
6767function _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, :: Nothing , rdata)
68- dA_copy = make_mooncake_tangent (copy (ΔA))
68+ dA_copy = make_mooncake_fdata (copy (ΔA))
6969 A_copy = copy (A)
7070 dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
7171 copy_out, copy_pb!! = rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy), Mooncake. CoDual (args, dargs_copy))
7575
7676# `alg` argument
7777function _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
78- dA_copy = make_mooncake_tangent (copy (ΔA))
78+ dA_copy = make_mooncake_fdata (copy (ΔA))
7979 A_copy = copy (A)
8080 dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
8181 copy_out, copy_pb!! = rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy), Mooncake. CoDual (args, dargs_copy), Mooncake. CoDual (alg, Mooncake. NoFData ()))
@@ -84,7 +84,7 @@ function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
8484end
8585
8686function _get_inplace_derivative (f!, A, ΔA, args, Δargs, :: Nothing , rdata)
87- dA_inplace = make_mooncake_tangent (copy (ΔA))
87+ dA_inplace = make_mooncake_fdata (copy (ΔA))
8888 A_inplace = copy (A)
8989 dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
9090 # not every f! has a handwritten rrule!!
@@ -103,7 +103,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
103103end
104104
105105function _get_inplace_derivative (f!, A, ΔA, args, Δargs, alg, rdata)
106- dA_inplace = make_mooncake_tangent (copy (ΔA))
106+ dA_inplace = make_mooncake_fdata (copy (ΔA))
107107 A_inplace = copy (A)
108108 dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
109109 # not every f! has a handwritten rrule!!
@@ -143,9 +143,9 @@ function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Moo
143143 sig = isnothing (alg) ? Tuple{typeof (f_c), typeof (A), typeof (args)} : Tuple{typeof (f_c), typeof (A), typeof (args), typeof (alg)}
144144 rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
145145 rrule = Mooncake. build_rrule (rvs_interp, sig)
146- ΔA = randn! (similar (A))
146+ ΔA = A isa Diagonal ? Diagonal ( randn! ( similar (A . diag))) : randn! (similar (A))
147147
148- dA_copy = _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
148+ dA_copy = _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
149149 dA_inplace = _get_inplace_derivative (f!, A, ΔA, args, Δargs, alg, rdata)
150150
151151 dA_inplace_ = Mooncake. arrayify (A, dA_inplace)[2 ]
0 commit comments