Skip to content

Commit fd9544f

Browse files
committed
Comments
1 parent e433589 commit fd9544f

3 files changed

Lines changed: 18 additions & 39 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,7 @@ for (f!, f, pb, adj) in (
111111
output = $f(A, Mooncake.primal(alg_dalg))
112112
output_codual = Mooncake.CoDual(output, Mooncake.zero_tangent(output))
113113
function $adj(::Mooncake.NoRData)
114-
arg = Mooncake.primal(output_codual)
115-
darg_ = Mooncake.tangent(output_codual)
116-
arg, darg = Mooncake.arrayify(arg, darg_)
114+
arg, darg = Mooncake.arrayify(output_codual)
117115
$pb(dA, A, arg, darg)
118116
MatrixAlgebraKit.zero!(darg)
119117
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -131,12 +129,8 @@ for (f!, f, f_full, pb, adj) in (
131129
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
132130
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
133131
# compute primal
134-
D_ = Mooncake.primal(D_dD)
135-
dD_ = Mooncake.tangent(D_dD)
136-
A_ = Mooncake.primal(A_dA)
137-
dA_ = Mooncake.tangent(A_dA)
138-
A, dA = arrayify(A_, dA_)
139-
D, dD = arrayify(D_, dD_)
132+
A, dA = arrayify(A_dA)
133+
D, dD = arrayify(D_dD)
140134
# update primal
141135
DV = $f_full(A, Mooncake.primal(alg_dalg))
142136
copy!(D, diagview(DV[1]))
@@ -151,18 +145,14 @@ for (f!, f, f_full, pb, adj) in (
151145
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
152146
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
153147
# compute primal
154-
A_ = Mooncake.primal(A_dA)
155-
dA_ = Mooncake.tangent(A_dA)
156-
A, dA = arrayify(A_, dA_)
148+
A, dA = arrayify(A_dA)
157149
# update primal
158150
DV = $f_full(A, Mooncake.primal(alg_dalg))
159151
V = DV[2]
160-
output = copy(diagview(DV[1]))
152+
output = diagview(DV[1])
161153
output_codual = Mooncake.CoDual(output, Mooncake.zero_tangent(output))
162154
function $adj(::Mooncake.NoRData)
163-
D = Mooncake.primal(output_codual)
164-
dD_ = Mooncake.tangent(output_codual)
165-
D, dD = Mooncake.arrayify(D, dD_)
155+
D_dD = Mooncake.arrayify(D_dD)
166156
$pb(dA, A, (D, V), (dD, nothing))
167157
MatrixAlgebraKit.zero!(dD)
168158
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -177,12 +167,10 @@ for (f, pb, adj) in (
177167
(eigh_trunc, eigh_trunc_pullback!, :deigh_trunc_adjoint),
178168
)
179169
@eval begin
180-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.TruncatedAlgorithm}
170+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
181171
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
182172
# compute primal
183-
A_ = Mooncake.primal(A_dA)
184-
dA_ = Mooncake.tangent(A_dA)
185-
A, dA = arrayify(A_, dA_)
173+
A, dA = arrayify(A_dA)
186174
alg = Mooncake.primal(alg_dalg)
187175
output = $f(A, alg)
188176
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -193,6 +181,7 @@ for (f, pb, adj) in (
193181
function $adj(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
194182
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
195183
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
184+
abs(dϵ) > MatrixAlgebraKit.defaulttol(dϵ) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
196185
D, dD = Mooncake.arrayify(Dtrunc, dDtrunc_)
197186
V, dV = Mooncake.arrayify(Vtrunc, dVtrunc_)
198187
$pb(dA, A, (D, V), (dD, dV))
@@ -281,12 +270,8 @@ end
281270
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
282271
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
283272
# compute primal
284-
S_ = Mooncake.primal(S_dS)
285-
dS_ = Mooncake.tangent(S_dS)
286-
A_ = Mooncake.primal(A_dA)
287-
dA_ = Mooncake.tangent(A_dA)
288-
A, dA = arrayify(A_, dA_)
289-
S, dS = arrayify(S_, dS_)
273+
A, dA = arrayify(A_dA)
274+
S, dS = arrayify(S_dS)
290275
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
291276
copy!(S, diagview(nS))
292277
function dsvd_vals_adjoint(::Mooncake.NoRData)
@@ -300,28 +285,23 @@ end
300285
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
301286
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
302287
# compute primal
303-
A = Mooncake.primal(A_dA)
304-
dA_ = Mooncake.tangent(A_dA)
305-
A, dA = arrayify(A, dA_)
306-
S = svd_vals(A, Mooncake.primal(alg_dalg))
307-
U, _, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
288+
A, dA = arrayify(A_dA)
289+
U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
308290
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
309291
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
310292
# pass). For many types this is done automatically when the forward step returns, but
311293
# not for nested structs with various fields (like Diagonal{Complex})
312-
S_codual = Mooncake.CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
294+
S_codual = Mooncake.CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(S)))
313295
function dsvd_vals_adjoint(::Mooncake.NoRData)
314-
S = Mooncake.primal(S_codual)
315-
dS_ = Mooncake.tangent(S_codual)
316-
S, dS = Mooncake.arrayify(S, dS_)
296+
S, dS = Mooncake.arrayify(S_codual)
317297
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
318298
MatrixAlgebraKit.zero!(dS)
319299
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
320300
end
321301
return S_codual, dsvd_vals_adjoint
322302
end
323303

324-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_trunc), Any, MatrixAlgebraKit.TruncatedAlgorithm}
304+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
325305
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
326306
# compute primal
327307
A_ = Mooncake.primal(A_dA)
@@ -337,6 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
337317
function dsvd_trunc_adjoint(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
338318
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
339319
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
320+
abs(dϵ) > MatrixAlgebraKit.defaulttol(dϵ) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
340321
U, dU = Mooncake.arrayify(Utrunc, dUtrunc_)
341322
S, dS = Mooncake.arrayify(Strunc, dStrunc_)
342323
Vᴴ, dVᴴ = Mooncake.arrayify(Vᴴtrunc, dVᴴtrunc_)

test/ad_utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,4 @@ function remove_eighgauge_dependence!(
2828
return ΔV
2929
end
3030

31-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
32-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))
31+
precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))

test/mooncake.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata
6565
if has_handwritten_rule
6666
inplace_out, inplace_pb!! = isnothing(alg) ? Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
6767
else
68-
inplace_sig = isnothing(alg) ? Tuple{typeof(f!), typeof(A), typeof(args)} : Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
6968
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
7069
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
7170
inplace_out, inplace_pb!! = isnothing(alg) ? inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))

0 commit comments

Comments
 (0)