@@ -62,113 +62,107 @@ make_mooncake_fdata(x) = make_mooncake_tangent(x)
6262make_mooncake_fdata (x:: Diagonal ) = Mooncake. FData ((diag = make_mooncake_tangent (x. diag),))
6363make_mooncake_fdata (x:: Tuple ) = map (make_mooncake_fdata, x)
6464
65+ # copies a preset tangent into a Mooncake CoDual
66+ # for use in the pullback.
67+ function copy_tangent (var:: Mooncake.CoDual , Δargs)
68+ dargs = make_mooncake_fdata (deepcopy (Δargs))
69+ copyto! (Mooncake. tangent (var), dargs)
70+ return
71+ end
72+
73+ function copy_tangent (var:: Mooncake.CoDual , Δargs:: Tuple )
74+ dargs = make_mooncake_fdata .(deepcopy (Δargs))
75+ for (var_tangent, darg) in zip (Mooncake. tangent (var), dargs)
76+ if var_tangent isa Mooncake. FData
77+ for (var_f, darg_f) in zip (Mooncake. _fields (var_tangent), Mooncake. _fields (darg))
78+ copyto! (var_f, darg_f)
79+ end
80+ else
81+ copyto! (var_tangent, darg)
82+ end
83+ end
84+ return
85+ end
86+
6587# no `alg` argument
6688function _get_copying_derivative (f, rrule, A, ΔA, args, Δargs, :: Nothing , rdata)
6789 dA_copy = make_mooncake_fdata (copy (ΔA))
6890 A_copy = copy (A)
69- dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
70- copy_out, copy_pb!! = rrule (Mooncake. CoDual (f, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy))
71- if args isa Tuple
72- for (copy_out_, dargs_copy_) in zip (Mooncake. tangent (copy_out), dargs_copy)
73- if copy_out_ isa Matrix
74- copyto! (copy_out_, dargs_copy_)
75- elseif copy_out_ isa Mooncake. FData
76- for (c_f, a_f) in zip (Mooncake. _fields (copy_out_), Mooncake. _fields (dargs_copy_))
77- copyto! (c_f, a_f)
78- end
79- end
80- end
81- else
82- copyto! (Mooncake. tangent (copy_out), dargs_copy)
83- end
91+ A_dA = Mooncake. CoDual (A_copy, dA_copy)
92+ copy_out, copy_pb!! = rrule (Mooncake. CoDual (f, Mooncake. NoFData ()), A_dA)
93+ # copy Δargs into tangent of the output variable for the pullback check
94+ copy_tangent (copy_out, Δargs)
8495 copy_pb!! (rdata)
96+ @test Mooncake. primal (A_dA) == A
8597 return dA_copy, Mooncake. tangent (copy_out)
8698end
8799
88100# `alg` argument
89101function _get_copying_derivative (f, rrule, A, ΔA, args, Δargs, alg, rdata)
90102 dA_copy = make_mooncake_fdata (copy (ΔA))
91103 A_copy = copy (A)
92- dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
93- copy_out, copy_pb!! = rrule (Mooncake. CoDual (f, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy), Mooncake. CoDual (alg, Mooncake. NoFData ()))
94- if args isa Tuple
95- for (copy_out_, dargs_copy_) in zip (Mooncake. tangent (copy_out), dargs_copy)
96- if copy_out_ isa Matrix
97- copyto! (copy_out_, dargs_copy_)
98- elseif copy_out_ isa Mooncake. FData
99- for (c_f, a_f) in zip (Mooncake. _fields (copy_out_), Mooncake. _fields (dargs_copy_))
100- copyto! (c_f, a_f)
101- end
102- end
103- end
104- else
105- copyto! (Mooncake. tangent (copy_out), dargs_copy)
106- end
104+ A_dA = Mooncake. CoDual (A_copy, dA_copy)
105+ copy_out, copy_pb!! = rrule (Mooncake. CoDual (f, Mooncake. NoFData ()), A_dA, Mooncake. CoDual (alg, Mooncake. NoFData ()))
106+ # copy Δargs into tangent of the output variable for the pullback check
107+ copy_tangent (copy_out, Δargs)
107108 copy_pb!! (rdata)
109+ @test Mooncake. primal (A_dA) == A
108110 return dA_copy, Mooncake. tangent (copy_out)
109111end
110112
111113function _get_inplace_derivative (f!, A, ΔA, args, Δargs, :: Nothing , rdata; ȳ = Δargs)
112114 dA_inplace = make_mooncake_fdata (copy (ΔA))
113115 A_inplace = copy (A)
114- dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
116+ args_copy = deepcopy (args)
117+ dargs_inplace = make_mooncake_fdata (deepcopy (Δargs))
115118 # not every f! has a handwritten rrule!!
116119 inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args)}
117120 has_handwritten_rule = hasmethod (Mooncake. rrule!!, inplace_sig)
121+ A_dA = Mooncake. CoDual (A_inplace, dA_inplace)
122+ args_dargs = Mooncake. CoDual (args_copy, dargs_inplace)
118123 if has_handwritten_rule
119- inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake . CoDual (A_inplace, dA_inplace), Mooncake . CoDual (args, dargs_inplace) )
124+ inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), A_dA, args_dargs )
120125 else
121126 inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args)}
122127 rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
123128 inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
124- inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace, dA_inplace), Mooncake. CoDual (args, dargs_inplace))
125- end
126- if args isa Tuple
127- for (inplace_out_, ȳ_) in zip (Mooncake. tangent (inplace_out), ȳ)
128- if inplace_out_ isa Matrix
129- copyto! (inplace_out_, ȳ_)
130- elseif inplace_out_ isa Mooncake. FData
131- for (i_f, a_f) in zip (Mooncake. _fields (inplace_out_), Mooncake. _fields (make_mooncake_fdata (ȳ_)))
132- copyto! (i_f, a_f)
133- end
134- end
135- end
136- else
137- copyto! (Mooncake. tangent (inplace_out), ȳ)
129+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), A_dA, args_dargs)
138130 end
131+ # copy reference derivative of output ȳ into inplace_out
132+ # needed for inplace methods like svd_trunc! that generate
133+ # new output variables
134+ copy_tangent (inplace_out, ȳ)
139135 inplace_pb!! (rdata)
136+ @test Mooncake. primal (A_dA) == A
137+ @test Mooncake. primal (args_dargs) == args_copy
140138 return dA_inplace, Mooncake. tangent (inplace_out)
141139end
142140
143141function _get_inplace_derivative (f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs)
144142 dA_inplace = make_mooncake_fdata (copy (ΔA))
145143 A_inplace = copy (A)
146- dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(Δargs) : make_mooncake_fdata (Δargs)
144+ args_copy = deepcopy (args)
145+ dargs_inplace = make_mooncake_fdata (deepcopy (Δargs))
147146 # not every f! has a handwritten rrule!!
148147 inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
149148 has_handwritten_rule = hasmethod (Mooncake. rrule!!, inplace_sig)
149+ A_dA = Mooncake. CoDual (A_inplace, dA_inplace)
150+ args_dargs = Mooncake. CoDual (args_copy, dargs_inplace)
150151 if has_handwritten_rule
151- inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake . CoDual (A_inplace, dA_inplace), Mooncake . CoDual (args, dargs_inplace) , Mooncake. CoDual (alg, Mooncake. NoFData ()))
152+ inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), A_dA, args_dargs , Mooncake. CoDual (alg, Mooncake. NoFData ()))
152153 else
153154 inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
154155 rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
155156 inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
156- inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace, dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
157- end
158- if args isa Tuple
159- for (inplace_out_, ȳ_) in zip (Mooncake. tangent (inplace_out), ȳ)
160- if inplace_out_ isa Matrix
161- copyto! (inplace_out_, ȳ_)
162- elseif inplace_out_ isa Mooncake. FData
163- for (i_f, a_f) in zip (Mooncake. _fields (inplace_out_), Mooncake. _fields (make_mooncake_fdata (ȳ_)))
164- copyto! (i_f, a_f)
165- end
166- end
167- end
168- else
169- copyto! (Mooncake. tangent (inplace_out), ȳ)
157+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), A_dA, args_dargs, Mooncake. CoDual (alg, Mooncake. NoFData ()))
170158 end
159+ # copy reference derivative of output ȳ into inplace_out
160+ # needed for inplace methods like svd_trunc! that generate
161+ # new output variables
162+ copy_tangent (inplace_out, ȳ)
171163 inplace_pb!! (rdata)
164+ @test Mooncake. primal (A_dA) == A
165+ @test Mooncake. primal (args_dargs) == args_copy
172166 return dA_inplace, Mooncake. tangent (inplace_out)
173167end
174168
0 commit comments