Skip to content

Commit c332ebc

Browse files
author
Katharine Hyatt
committed
Fix arg copying order
1 parent ac5577a commit c332ebc

2 files changed

Lines changed: 51 additions & 50 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 51 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -54,9 +54,9 @@ for (f!, f, pb, adj) in (
5454
$f!(A, args, Mooncake.primal(alg_dalg))
5555
function $adj(::NoRData)
5656
copy!(A, Ac)
57+
$pb(dA, A, (arg1, arg2), (darg1, darg2))
5758
copy!(arg1, arg1c)
5859
copy!(arg2, arg2c)
59-
$pb(dA, A, (arg1, arg2), (darg1, darg2))
6060
zero!(darg1)
6161
zero!(darg2)
6262
return NoRData(), NoRData(), NoRData(), NoRData()
@@ -101,9 +101,9 @@ for (f!, f, pb, adj) in (
101101
$f!(A, arg, Mooncake.primal(alg_dalg))
102102
function $adj(::NoRData)
103103
copy!(A, Ac)
104-
copy!(arg, argc)
105104
$pb(dA, A, arg, darg)
106-
MatrixAlgebraKit.zero!(darg)
105+
copy!(arg, argc)
106+
zero!(darg)
107107
return NoRData(), NoRData(), NoRData(), NoRData()
108108
end
109109
return arg_darg, $adj
@@ -116,7 +116,7 @@ for (f!, f, pb, adj) in (
116116
function $adj(::NoRData)
117117
arg, darg = arrayify(output_codual)
118118
$pb(dA, A, arg, darg)
119-
MatrixAlgebraKit.zero!(darg)
119+
zero!(darg)
120120
return NoRData(), NoRData(), NoRData()
121121
end
122122
return output_codual, $adj
@@ -134,14 +134,15 @@ for (f!, f, f_full, pb, adj) in (
134134
# compute primal
135135
A, dA = arrayify(A_dA)
136136
D, dD = arrayify(D_dD)
137+
Dc = copy(D)
137138
# update primal
138139
DV = $f_full(A, Mooncake.primal(alg_dalg))
139140
copy!(D, diagview(DV[1]))
140141
V = DV[2]
141142
function $adj(::NoRData)
142-
copy!(D, diagview(DV[1]))
143143
$pb(dA, A, DV, dD)
144-
MatrixAlgebraKit.zero!(dD)
144+
copy!(D, Dc)
145+
zero!(dD)
145146
return NoRData(), NoRData(), NoRData(), NoRData()
146147
end
147148
return D_dD, $adj
@@ -158,7 +159,7 @@ for (f!, f, f_full, pb, adj) in (
158159
function $adj(::NoRData)
159160
D, dD = arrayify(output_codual)
160161
$pb(dA, A, DV, dD)
161-
MatrixAlgebraKit.zero!(dD)
162+
zero!(dD)
162163
return NoRData(), NoRData(), NoRData()
163164
end
164165
return output_codual, $adj
@@ -189,16 +190,16 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
189190
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
190191
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
191192
copy!(A, Ac)
192-
copy!(DV[1], DVc[1])
193-
copy!(DV[2], DVc[2])
194193
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
195194
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
196195
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
197196
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
198197
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
199198
$pb(dA, A, (D′, V′), (dD′, dV′))
200-
MatrixAlgebraKit.zero!(dD)
201-
MatrixAlgebraKit.zero!(dV)
199+
copy!(DV[1], DVc[1])
200+
copy!(DV[2], DVc[2])
201+
zero!(dD)
202+
zero!(dV)
202203
return NoRData(), NoRData(), NoRData()
203204
end
204205
return output_codual, $adj
@@ -220,8 +221,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
220221
D, dD = arrayify(Dtrunc, dDtrunc_)
221222
V, dV = arrayify(Vtrunc, dVtrunc_)
222223
$pb(dA, A, (D, V), (dD, dV))
223-
MatrixAlgebraKit.zero!(dD)
224-
MatrixAlgebraKit.zero!(dV)
224+
zero!(dD)
225+
zero!(dV)
225226
return NoRData(), NoRData(), NoRData()
226227
end
227228
return output_codual, $adj
@@ -244,15 +245,15 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
244245
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
245246
function $adj(::NoRData)
246247
copy!(A, Ac)
247-
copy!(DV[1], DVc[1])
248-
copy!(DV[2], DVc[2])
249248
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
250249
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
251250
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
252251
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
253252
$pb(dA, A, (D′, V′), (dD′, dV′))
254-
MatrixAlgebraKit.zero!(dD)
255-
MatrixAlgebraKit.zero!(dV)
253+
copy!(DV[1], DVc[1])
254+
copy!(DV[2], DVc[2])
255+
zero!(dD)
256+
zero!(dV)
256257
return NoRData(), NoRData(), NoRData()
257258
end
258259
return output_codual, $adj
@@ -273,8 +274,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in (
273274
D, dD = arrayify(Dtrunc, dDtrunc_)
274275
V, dV = arrayify(Vtrunc, dVtrunc_)
275276
$pb(dA, A, (D, V), (dD, dV))
276-
MatrixAlgebraKit.zero!(dD)
277-
MatrixAlgebraKit.zero!(dV)
277+
zero!(dD)
278+
zero!(dV)
278279
return NoRData(), NoRData(), NoRData()
279280
end
280281
return output_codual, $adj
@@ -300,9 +301,6 @@ for (f!, f) in (
300301
output = $f!(A, Mooncake.primal(alg_dalg))
301302
function svd_adjoint(::NoRData)
302303
copy!(A, Ac)
303-
copy!(U, USVᴴc[1])
304-
copy!(S, USVᴴc[2])
305-
copy!(Vᴴ, USVᴴc[3])
306304
if $(f! == svd_compact!)
307305
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
308306
else # full
@@ -315,9 +313,12 @@ for (f!, f) in (
315313
vdVᴴ = view(dVᴴ, 1:minmn, :)
316314
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
317315
end
318-
MatrixAlgebraKit.zero!(dU)
319-
MatrixAlgebraKit.zero!(dS)
320-
MatrixAlgebraKit.zero!(dVᴴ)
316+
copy!(U, USVᴴc[1])
317+
copy!(S, USVᴴc[2])
318+
copy!(Vᴴ, USVᴴc[3])
319+
zero!(dU)
320+
zero!(dS)
321+
zero!(dVᴴ)
321322
return NoRData(), NoRData(), NoRData(), NoRData()
322323
end
323324
return CoDual(output, dUSVᴴ), svd_adjoint
@@ -349,9 +350,9 @@ for (f!, f) in (
349350
vdVᴴ = view(dVᴴ, 1:minmn, :)
350351
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
351352
end
352-
MatrixAlgebraKit.zero!(dU)
353-
MatrixAlgebraKit.zero!(dS)
354-
MatrixAlgebraKit.zero!(dVᴴ)
353+
zero!(dU)
354+
zero!(dS)
355+
zero!(dVᴴ)
355356
return NoRData(), NoRData(), NoRData()
356357
end
357358
return USVᴴ_codual, svd_adjoint
@@ -364,12 +365,13 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
364365
# compute primal
365366
A, dA = arrayify(A_dA)
366367
S, dS = arrayify(S_dS)
368+
Sc = copy(S)
367369
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
368370
copy!(S, diagview(USVᴴ[2]))
369371
function svd_vals_adjoint(::NoRData)
370372
svd_vals_pullback!(dA, A, USVᴴ, dS)
371-
MatrixAlgebraKit.zero!(dS)
372-
copy!(S, diagview(USVᴴ[2]))
373+
zero!(dS)
374+
copy!(S, Sc)
373375
return NoRData(), NoRData(), NoRData(), NoRData()
374376
end
375377
return S_dS, svd_vals_adjoint
@@ -389,7 +391,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
389391
function svd_vals_adjoint(::NoRData)
390392
S, dS = arrayify(S_codual)
391393
svd_vals_pullback!(dA, A, USVᴴ, dS)
392-
MatrixAlgebraKit.zero!(dS)
394+
zero!(dS)
393395
return NoRData(), NoRData(), NoRData()
394396
end
395397
return S_codual, svd_vals_adjoint
@@ -415,19 +417,19 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
415417
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
416418
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
417419
copy!(A, Ac)
418-
copy!(U, USVᴴc[1])
419-
copy!(S, USVᴴc[2])
420-
copy!(Vᴴ, USVᴴc[3])
421420
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
422421
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
423422
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
424423
U′, dU′ = arrayify(Utrunc, dUtrunc_)
425424
S′, dS′ = arrayify(Strunc, dStrunc_)
426425
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
427426
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
428-
MatrixAlgebraKit.zero!(dU)
429-
MatrixAlgebraKit.zero!(dS)
430-
MatrixAlgebraKit.zero!(dVᴴ)
427+
copy!(U, USVᴴc[1])
428+
copy!(S, USVᴴc[2])
429+
copy!(Vᴴ, USVᴴc[3])
430+
zero!(dU)
431+
zero!(dS)
432+
zero!(dVᴴ)
431433
return NoRData(), NoRData(), NoRData()
432434
end
433435
return output_codual, svd_trunc_adjoint
@@ -454,9 +456,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
454456
S, dS = arrayify(Strunc, dStrunc_)
455457
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
456458
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
457-
MatrixAlgebraKit.zero!(dU)
458-
MatrixAlgebraKit.zero!(dS)
459-
MatrixAlgebraKit.zero!(dVᴴ)
459+
zero!(dU)
460+
zero!(dS)
461+
zero!(dVᴴ)
460462
return NoRData(), NoRData(), NoRData()
461463
end
462464
return output_codual, svd_trunc_adjoint
@@ -482,18 +484,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, US
482484
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
483485
function svd_trunc_adjoint(::NoRData)
484486
copy!(A, Ac)
485-
copy!(U, USVᴴc[1])
486-
copy!(S, USVᴴc[2])
487-
copy!(Vᴴ, USVᴴc[3])
488487
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
489488
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
490489
U′, dU′ = arrayify(Utrunc, dUtrunc_)
491490
S′, dS′ = arrayify(Strunc, dStrunc_)
492491
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
493492
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
494-
MatrixAlgebraKit.zero!(dU)
495-
MatrixAlgebraKit.zero!(dS)
496-
MatrixAlgebraKit.zero!(dVᴴ)
493+
copy!(U, USVᴴc[1])
494+
copy!(S, USVᴴc[2])
495+
copy!(Vᴴ, USVᴴc[3])
496+
zero!(dU)
497+
zero!(dS)
498+
zero!(dVᴴ)
497499
return NoRData(), NoRData(), NoRData()
498500
end
499501
return output_codual, svd_trunc_adjoint
@@ -519,9 +521,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
519521
S, dS = arrayify(Strunc, dStrunc_)
520522
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
521523
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
522-
MatrixAlgebraKit.zero!(dU)
523-
MatrixAlgebraKit.zero!(dS)
524-
MatrixAlgebraKit.zero!(dVᴴ)
524+
zero!(dU)
525+
zero!(dS)
526+
zero!(dVᴴ)
525527
return NoRData(), NoRData(), NoRData()
526528
end
527529
return output_codual, svd_trunc_adjoint

test/mooncake.jl

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,4 +27,3 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
2727
#n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T))
2828
end
2929
end
30-

0 commit comments

Comments
 (0)