Skip to content

Commit facdda6

Browse files
author
Katharine Hyatt
committed
Testfile cleanup
1 parent 7e584f1 commit facdda6

1 file changed

Lines changed: 56 additions & 62 deletions

File tree

test/testsuite/mooncake.jl

Lines changed: 56 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -62,113 +62,107 @@ make_mooncake_fdata(x) = make_mooncake_tangent(x)
6262
make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),))
6363
make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x)
6464

65+
# copies a preset tangent into a Mooncake CoDual
66+
# for use in the pullback.
67+
function copy_tangent(var::Mooncake.CoDual, Δargs)
68+
dargs = make_mooncake_fdata(deepcopy(Δargs))
69+
copyto!(Mooncake.tangent(var), dargs)
70+
return
71+
end
72+
73+
function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple)
74+
dargs = make_mooncake_fdata.(deepcopy(Δargs))
75+
for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs)
76+
if var_tangent isa Mooncake.FData
77+
for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg))
78+
copyto!(var_f, darg_f)
79+
end
80+
else
81+
copyto!(var_tangent, darg)
82+
end
83+
end
84+
return
85+
end
86+
6587
# no `alg` argument
6688
function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
6789
dA_copy = make_mooncake_fdata(copy(ΔA))
6890
A_copy = copy(A)
69-
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
70-
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy))
71-
if args isa Tuple
72-
for (copy_out_, dargs_copy_) in zip(Mooncake.tangent(copy_out), dargs_copy)
73-
if copy_out_ isa Matrix
74-
copyto!(copy_out_, dargs_copy_)
75-
elseif copy_out_ isa Mooncake.FData
76-
for (c_f, a_f) in zip(Mooncake._fields(copy_out_), Mooncake._fields(dargs_copy_))
77-
copyto!(c_f, a_f)
78-
end
79-
end
80-
end
81-
else
82-
copyto!(Mooncake.tangent(copy_out), dargs_copy)
83-
end
91+
A_dA = Mooncake.CoDual(A_copy, dA_copy)
92+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA)
93+
# copy Δargs into tangent of the output variable for the pullback check
94+
copy_tangent(copy_out, Δargs)
8495
copy_pb!!(rdata)
96+
@test Mooncake.primal(A_dA) == A
8597
return dA_copy, Mooncake.tangent(copy_out)
8698
end
8799

88100
# `alg` argument
89101
function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata)
90102
dA_copy = make_mooncake_fdata(copy(ΔA))
91103
A_copy = copy(A)
92-
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
93-
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
94-
if args isa Tuple
95-
for (copy_out_, dargs_copy_) in zip(Mooncake.tangent(copy_out), dargs_copy)
96-
if copy_out_ isa Matrix
97-
copyto!(copy_out_, dargs_copy_)
98-
elseif copy_out_ isa Mooncake.FData
99-
for (c_f, a_f) in zip(Mooncake._fields(copy_out_), Mooncake._fields(dargs_copy_))
100-
copyto!(c_f, a_f)
101-
end
102-
end
103-
end
104-
else
105-
copyto!(Mooncake.tangent(copy_out), dargs_copy)
106-
end
104+
A_dA = Mooncake.CoDual(A_copy, dA_copy)
105+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA, Mooncake.CoDual(alg, Mooncake.NoFData()))
106+
# copy Δargs into tangent of the output variable for the pullback check
107+
copy_tangent(copy_out, Δargs)
107108
copy_pb!!(rdata)
109+
@test Mooncake.primal(A_dA) == A
108110
return dA_copy, Mooncake.tangent(copy_out)
109111
end
110112

111113
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs)
112114
dA_inplace = make_mooncake_fdata(copy(ΔA))
113115
A_inplace = copy(A)
114-
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
116+
args_copy = deepcopy(args)
117+
dargs_inplace = make_mooncake_fdata(deepcopy(Δargs))
115118
# not every f! has a handwritten rrule!!
116119
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)}
117120
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
121+
A_dA = Mooncake.CoDual(A_inplace, dA_inplace)
122+
args_dargs = Mooncake.CoDual(args_copy, dargs_inplace)
118123
if has_handwritten_rule
119-
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
124+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs)
120125
else
121126
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)}
122127
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
123128
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
124-
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
125-
end
126-
if args isa Tuple
127-
for (inplace_out_, ȳ_) in zip(Mooncake.tangent(inplace_out), ȳ)
128-
if inplace_out_ isa Matrix
129-
copyto!(inplace_out_, ȳ_)
130-
elseif inplace_out_ isa Mooncake.FData
131-
for (i_f, a_f) in zip(Mooncake._fields(inplace_out_), Mooncake._fields(make_mooncake_fdata(ȳ_)))
132-
copyto!(i_f, a_f)
133-
end
134-
end
135-
end
136-
else
137-
copyto!(Mooncake.tangent(inplace_out), ȳ)
129+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs)
138130
end
131+
# copy reference derivative of output ȳ into inplace_out
132+
# needed for inplace methods like svd_trunc! that generate
133+
# new output variables
134+
copy_tangent(inplace_out, ȳ)
139135
inplace_pb!!(rdata)
136+
@test Mooncake.primal(A_dA) == A
137+
@test Mooncake.primal(args_dargs) == args_copy
140138
return dA_inplace, Mooncake.tangent(inplace_out)
141139
end
142140

143141
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs)
144142
dA_inplace = make_mooncake_fdata(copy(ΔA))
145143
A_inplace = copy(A)
146-
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(Δargs) : make_mooncake_fdata(Δargs)
144+
args_copy = deepcopy(args)
145+
dargs_inplace = make_mooncake_fdata(deepcopy(Δargs))
147146
# not every f! has a handwritten rrule!!
148147
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
149148
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
149+
A_dA = Mooncake.CoDual(A_inplace, dA_inplace)
150+
args_dargs = Mooncake.CoDual(args_copy, dargs_inplace)
150151
if has_handwritten_rule
151-
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
152+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData()))
152153
else
153154
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
154155
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
155156
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
156-
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
157-
end
158-
if args isa Tuple
159-
for (inplace_out_, ȳ_) in zip(Mooncake.tangent(inplace_out), ȳ)
160-
if inplace_out_ isa Matrix
161-
copyto!(inplace_out_, ȳ_)
162-
elseif inplace_out_ isa Mooncake.FData
163-
for (i_f, a_f) in zip(Mooncake._fields(inplace_out_), Mooncake._fields(make_mooncake_fdata(ȳ_)))
164-
copyto!(i_f, a_f)
165-
end
166-
end
167-
end
168-
else
169-
copyto!(Mooncake.tangent(inplace_out), ȳ)
157+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData()))
170158
end
159+
# copy reference derivative of output ȳ into inplace_out
160+
# needed for inplace methods like svd_trunc! that generate
161+
# new output variables
162+
copy_tangent(inplace_out, ȳ)
171163
inplace_pb!!(rdata)
164+
@test Mooncake.primal(A_dA) == A
165+
@test Mooncake.primal(args_dargs) == args_copy
172166
return dA_inplace, Mooncake.tangent(inplace_out)
173167
end
174168

0 commit comments

Comments
 (0)