Skip to content

Commit ac5577a

Browse files
author
Katharine Hyatt
committed
Add tests/fixes for inplace rules
1 parent b900e8a commit ac5577a

3 files changed

Lines changed: 189 additions & 25 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 152 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -54,11 +54,11 @@ for (f!, f, pb, adj) in (
5454
$f!(A, args, Mooncake.primal(alg_dalg))
5555
function $adj(::NoRData)
5656
copy!(A, Ac)
57-
$pb(dA, A, (arg1, arg2), (darg1, darg2))
5857
copy!(arg1, arg1c)
5958
copy!(arg2, arg2c)
60-
MatrixAlgebraKit.zero!(darg1)
61-
MatrixAlgebraKit.zero!(darg2)
59+
$pb(dA, A, (arg1, arg2), (darg1, darg2))
60+
zero!(darg1)
61+
zero!(darg2)
6262
return NoRData(), NoRData(), NoRData(), NoRData()
6363
end
6464
return args_dargs, $adj
@@ -78,8 +78,8 @@ for (f!, f, pb, adj) in (
7878
arg1, darg1 = arrayify(arg1, darg1_)
7979
arg2, darg2 = arrayify(arg2, darg2_)
8080
$pb(dA, A, (arg1, arg2), (darg1, darg2))
81-
MatrixAlgebraKit.zero!(darg1)
82-
MatrixAlgebraKit.zero!(darg2)
81+
zero!(darg1)
82+
zero!(darg2)
8383
return NoRData(), NoRData(), NoRData()
8484
end
8585
return output_codual, $adj
@@ -101,8 +101,8 @@ for (f!, f, pb, adj) in (
101101
$f!(A, arg, Mooncake.primal(alg_dalg))
102102
function $adj(::NoRData)
103103
copy!(A, Ac)
104-
$pb(dA, A, arg, darg)
105104
copy!(arg, argc)
105+
$pb(dA, A, arg, darg)
106106
MatrixAlgebraKit.zero!(darg)
107107
return NoRData(), NoRData(), NoRData(), NoRData()
108108
end
@@ -139,6 +139,7 @@ for (f!, f, f_full, pb, adj) in (
139139
copy!(D, diagview(DV[1]))
140140
V = DV[2]
141141
function $adj(::NoRData)
142+
copy!(D, diagview(DV[1]))
142143
$pb(dA, A, DV, dD)
143144
MatrixAlgebraKit.zero!(dD)
144145
return NoRData(), NoRData(), NoRData(), NoRData()
@@ -165,12 +166,43 @@ for (f!, f, f_full, pb, adj) in (
165166
end
166167
end
167168

168-
for (f, f_ne, pb, adj) in (
169-
(:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
170-
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
169+
for (f!, f, f_ne!, f_ne, pb, adj) in (
170+
(:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
171+
(:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
171172
)
172173
@eval begin
174+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
173175
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
176+
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
177+
# compute primal
178+
A, dA = arrayify(A_dA)
179+
DV = Mooncake.primal(DV_dDV)
180+
dDV = Mooncake.tangent(DV_dDV)
181+
Ac = copy(A)
182+
DVc = copy.(DV)
183+
alg = Mooncake.primal(alg_dalg)
184+
output = $f!(A, DV, alg)
185+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
186+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
187+
# pass). For many types this is done automatically when the forward step returns, but
188+
# not for nested structs with various fields (like Diagonal{Complex})
189+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
190+
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
191+
copy!(A, Ac)
192+
copy!(DV[1], DVc[1])
193+
copy!(DV[2], DVc[2])
194+
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
195+
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
196+
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
197+
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
198+
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
199+
$pb(dA, A, (D′, V′), (dD′, dV′))
200+
MatrixAlgebraKit.zero!(dD)
201+
MatrixAlgebraKit.zero!(dV)
202+
return NoRData(), NoRData(), NoRData()
203+
end
204+
return output_codual, $adj
205+
end
174206
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
175207
# compute primal
176208
A, dA = arrayify(A_dA)
@@ -194,7 +226,37 @@ for (f, f_ne, pb, adj) in (
194226
end
195227
return output_codual, $adj
196228
end
229+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
197230
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
231+
function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
232+
# compute primal
233+
A, dA = arrayify(A_dA)
234+
alg = Mooncake.primal(alg_dalg)
235+
DV = Mooncake.primal(DV_dDV)
236+
dDV = Mooncake.tangent(DV_dDV)
237+
Ac = copy(A)
238+
DVc = copy.(DV)
239+
output = $f_ne(A, DV, alg)
240+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
241+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
242+
# pass). For many types this is done automatically when the forward step returns, but
243+
# not for nested structs with various fields (like Diagonal{Complex})
244+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
245+
function $adj(::NoRData)
246+
copy!(A, Ac)
247+
copy!(DV[1], DVc[1])
248+
copy!(DV[2], DVc[2])
249+
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
250+
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
251+
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
252+
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
253+
$pb(dA, A, (D′, V′), (dD′, dV′))
254+
MatrixAlgebraKit.zero!(dD)
255+
MatrixAlgebraKit.zero!(dV)
256+
return NoRData(), NoRData(), NoRData()
257+
end
258+
return output_codual, $adj
259+
end
198260
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
199261
# compute primal
200262
A, dA = arrayify(A_dA)
@@ -234,9 +296,13 @@ for (f!, f) in (
234296
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
235297
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
236298
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
299+
USVᴴc = copy.(USVᴴ)
237300
output = $f!(A, Mooncake.primal(alg_dalg))
238301
function svd_adjoint(::NoRData)
239302
copy!(A, Ac)
303+
copy!(U, USVᴴc[1])
304+
copy!(S, USVᴴc[2])
305+
copy!(Vᴴ, USVᴴc[3])
240306
if $(f! == svd_compact!)
241307
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
242308
else # full
@@ -303,6 +369,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
303369
function svd_vals_adjoint(::NoRData)
304370
svd_vals_pullback!(dA, A, USVᴴ, dS)
305371
MatrixAlgebraKit.zero!(dS)
372+
copy!(S, diagview(USVᴴ[2]))
306373
return NoRData(), NoRData(), NoRData(), NoRData()
307374
end
308375
return S_dS, svd_vals_adjoint
@@ -328,6 +395,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
328395
return S_codual, svd_vals_adjoint
329396
end
330397

398+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
399+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
400+
# compute primal
401+
A, dA = arrayify(A_dA)
402+
alg = Mooncake.primal(alg_dalg)
403+
Ac = copy(A)
404+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
405+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
406+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
407+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
408+
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
409+
USVᴴc = copy.(USVᴴ)
410+
output = svd_trunc!(A, USVᴴ, alg)
411+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
412+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
413+
# pass). For many types this is done automatically when the forward step returns, but
414+
# not for nested structs with various fields (like Diagonal{Complex})
415+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
416+
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
417+
copy!(A, Ac)
418+
copy!(U, USVᴴc[1])
419+
copy!(S, USVᴴc[2])
420+
copy!(Vᴴ, USVᴴc[3])
421+
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
422+
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
423+
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
424+
U′, dU′ = arrayify(Utrunc, dUtrunc_)
425+
S′, dS′ = arrayify(Strunc, dStrunc_)
426+
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
427+
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
428+
MatrixAlgebraKit.zero!(dU)
429+
MatrixAlgebraKit.zero!(dS)
430+
MatrixAlgebraKit.zero!(dVᴴ)
431+
return NoRData(), NoRData(), NoRData()
432+
end
433+
return output_codual, svd_trunc_adjoint
434+
end
435+
331436
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
332437
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
333438
# compute primal
@@ -357,6 +462,43 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
357462
return output_codual, svd_trunc_adjoint
358463
end
359464

465+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
466+
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
467+
# compute primal
468+
A, dA = arrayify(A_dA)
469+
alg = Mooncake.primal(alg_dalg)
470+
Ac = copy(A)
471+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
472+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
473+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
474+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
475+
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
476+
USVᴴc = copy.(USVᴴ)
477+
output = svd_trunc_no_error!(A, USVᴴ, alg)
478+
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
479+
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
480+
# pass). For many types this is done automatically when the forward step returns, but
481+
# not for nested structs with various fields (like Diagonal{Complex})
482+
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
483+
function svd_trunc_adjoint(::NoRData)
484+
copy!(A, Ac)
485+
copy!(U, USVᴴc[1])
486+
copy!(S, USVᴴc[2])
487+
copy!(Vᴴ, USVᴴc[3])
488+
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
489+
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
490+
U′, dU′ = arrayify(Utrunc, dUtrunc_)
491+
S′, dS′ = arrayify(Strunc, dStrunc_)
492+
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
493+
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
494+
MatrixAlgebraKit.zero!(dU)
495+
MatrixAlgebraKit.zero!(dS)
496+
MatrixAlgebraKit.zero!(dVᴴ)
497+
return NoRData(), NoRData(), NoRData()
498+
end
499+
return output_codual, svd_trunc_adjoint
500+
end
501+
360502
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
361503
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
362504
# compute primal

test/mooncake.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
2727
#n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
2828
end
2929
end
30+

test/testsuite/mooncake.jl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,44 @@ make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x)
7070
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
7171
dA_copy = make_mooncake_fdata(copy(ΔA))
7272
A_copy = copy(A)
73-
dargs_copy = make_mooncake_fdata(deepcopy(Δargs))
74-
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy))
73+
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
74+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy))
75+
if args isa Tuple
76+
copyto!.(Mooncake.tangent(copy_out), dargs_copy)
77+
else
78+
copyto!(Mooncake.tangent(copy_out), dargs_copy)
79+
end
80+
@test Mooncake.primal(copy_out) args
7581
copy_pb!!(rdata)
76-
return dA_copy
82+
return dA_copy, Mooncake.tangent(copy_out)
7783
end
7884

7985
# `alg` argument
8086
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
8187
dA_copy = make_mooncake_fdata(copy(ΔA))
8288
A_copy = copy(A)
83-
dargs_copy = make_mooncake_fdata(deepcopy(Δargs))
84-
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
89+
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
90+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
91+
if args isa Tuple
92+
copyto!.(Mooncake.tangent(copy_out), dargs_copy)
93+
else
94+
copyto!(Mooncake.tangent(copy_out), dargs_copy)
95+
end
96+
if args isa Tuple
97+
for (arg, out) in zip(args, Mooncake.primal(copy_out))
98+
@test out arg
99+
end
100+
else
101+
@test Mooncake.primal(copy_out) args
102+
end
85103
copy_pb!!(rdata)
86-
return dA_copy
104+
return dA_copy, Mooncake.tangent(copy_out)
87105
end
88106

89107
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
90108
dA_inplace = make_mooncake_fdata(copy(ΔA))
91109
A_inplace = copy(A)
92-
dargs_inplace = make_mooncake_fdata(deepcopy(Δargs))
110+
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
93111
# not every f! has a handwritten rrule!!
94112
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)}
95113
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
@@ -102,13 +120,13 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
102120
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
103121
end
104122
inplace_pb!!(rdata)
105-
return dA_inplace
123+
return dA_inplace, dargs_inplace
106124
end
107125

108126
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
109127
dA_inplace = make_mooncake_fdata(copy(ΔA))
110128
A_inplace = copy(A)
111-
dargs_inplace = make_mooncake_fdata(deepcopy(Δargs))
129+
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(Δargs) : make_mooncake_fdata(Δargs)
112130
# not every f! has a handwritten rrule!!
113131
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
114132
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
@@ -121,7 +139,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
121139
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
122140
end
123141
inplace_pb!!(rdata)
124-
return dA_inplace
142+
return dA_inplace, dargs_inplace
125143
end
126144

127145
"""
@@ -142,18 +160,21 @@ The arguments to this function are:
142160
- `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do)
143161
"""
144162
function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData())
145-
f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg)
146-
sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)}
163+
sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)}
147164
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
148165
rrule = Mooncake.build_rrule(rvs_interp, sig)
149-
ΔA = isa(A, Diagonal) ? Diagonal(randn!(similar(A.diag))) : randn!(similar(A))
166+
ΔA = randn(rng, eltype(A), size(A))
150167

151-
dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
152-
dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
168+
copy_args = isa(args, Tuple) ? copy.(args) : copy(args)
169+
inplace_args = isa(args, Tuple) ? copy.(args) : copy(args)
170+
dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, Δargs, alg, rdata)
171+
dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata)
153172

154173
dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2]
155174
dA_copy_ = Mooncake.arrayify(A, dA_copy)[2]
156175
@test dA_inplace_ dA_copy_
176+
@test copy_args == inplace_args
177+
@test dargs_copy == dargs_inplace
157178
return
158179
end
159180

0 commit comments

Comments
 (0)