Skip to content

Commit bf151a0

Browse files
committed
add specializations eig_trunc(!) for TruncatedAlgorithm
1 parent a23735c commit bf151a0

1 file changed

Lines changed: 163 additions & 29 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 163 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -168,118 +168,252 @@ for (f!, f, f_full, pb, adj) in (
168168
end
169169
end
170170

171-
for (f!, f, f_ne!, f_ne, pb, adj) in (
172-
(:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
173-
(:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
174-
)
171+
for f in (:eig, :eigh)
172+
f_trunc = Symbol(f, :_trunc)
173+
f_trunc! = Symbol(f_trunc, :!)
174+
f_full = Symbol(f, :_full)
175+
f_full! = Symbol(f_full, :!)
176+
f_pullback! = Symbol(f, :_pullback!)
177+
f_trunc_pullback! = Symbol(f_trunc, :_pullback!)
178+
f_adjoint! = Symbol(f, :_adjoint!)
179+
f_trunc_no_error = Symbol(f_trunc, :_no_error)
180+
f_trunc_no_error! = Symbol(f_trunc_no_error, :!)
181+
175182
@eval begin
176-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
177-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
178-
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
183+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
184+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
185+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
179186
# compute primal
180187
A, dA = arrayify(A_dA)
181188
DV = Mooncake.primal(DV_dDV)
182189
dDV = Mooncake.tangent(DV_dDV)
183190
Ac = copy(A)
184191
DVc = copy.(DV)
185192
alg = Mooncake.primal(alg_dalg)
186-
output = $f!(A, DV, alg)
193+
output = $f_trunc!(A, DV, alg)
187194
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
188195
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
189196
# pass). For many types this is done automatically when the forward step returns, but
190197
# not for nested structs with various fields (like Diagonal{Complex})
191-
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
192-
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
198+
output_codual = Mooncake.zero_fcodual(output)
199+
function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real})
193200
copy!(A, Ac)
194201
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
195202
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
196203
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
197204
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
198205
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
199-
$pb(dA, A, (D′, V′), (dD′, dV′))
206+
$f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
200207
copy!(DV[1], DVc[1])
201208
copy!(DV[2], DVc[2])
202209
zero!(dD′)
203210
zero!(dV′)
204211
return NoRData(), NoRData(), NoRData(), NoRData()
205212
end
206-
return output_codual, $adj
213+
return output_codual, $f_adjoint!
207214
end
208-
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
215+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
216+
# unpack variables
217+
A, dA = arrayify(A_dA)
218+
DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV))
219+
DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr)
220+
alg = Mooncake.primal(alg_dalg)
221+
222+
# store state prior to primal call
223+
Ac = copy(A)
224+
DVc = copy.(DV)
225+
226+
# compute primal - capture full DV and ind
227+
DV = $f_full!(A, DV, alg.alg)
228+
DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)
229+
ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind)
230+
231+
# pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input
232+
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
233+
234+
# define pullback
235+
local $f_adjoint!
236+
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
237+
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
238+
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
239+
@warn "Pullback for `$f!` ignores non-zero tangents for truncation error"
240+
241+
# compute pullbacks
242+
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
243+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
244+
245+
# restore state
246+
copy!(A, Ac)
247+
copy!.(DV, DVc)
248+
249+
return ntuple(Returns(NoRData()), 4)
250+
end
251+
end
252+
253+
return DVtrunc_dDVtrunc, $f_adjoint!
254+
end
255+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
209256
# compute primal
210257
A, dA = arrayify(A_dA)
211258
alg = Mooncake.primal(alg_dalg)
212-
output = $f(A, alg)
259+
output = $f_trunc(A, alg)
213260
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
214261
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
215262
# pass). For many types this is done automatically when the forward step returns, but
216263
# not for nested structs with various fields (like Diagonal{Complex})
217264
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
218-
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
265+
function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
219266
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
220267
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
221268
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
222269
D, dD = arrayify(Dtrunc, dDtrunc_)
223270
V, dV = arrayify(Vtrunc, dVtrunc_)
224-
$pb(dA, A, (D, V), (dD, dV))
271+
$f_trunc_pullback!(dA, A, (D, V), (dD, dV))
225272
zero!(dD)
226273
zero!(dV)
227274
return NoRData(), NoRData(), NoRData()
228275
end
229-
return output_codual, $adj
276+
return output_codual, $f_adjoint!
230277
end
231-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
232-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
233-
function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
278+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
279+
# unpack variables
280+
A, dA = arrayify(A_dA)
281+
alg = Mooncake.primal(alg_dalg)
282+
283+
# compute primal - capture full DV and ind
284+
DV = $f_full(A, alg.alg)
285+
DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)
286+
ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind)
287+
288+
# pack output
289+
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
290+
291+
# define pullback
292+
local $f_adjoint!
293+
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
294+
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
295+
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
296+
@warn "Pullback for `$f_trunc` ignores non-zero tangents for truncation error"
297+
$f_pullback!(dA, A, DV, dDVtrunc, ind)
298+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
299+
return ntuple(Returns(NoRData()), 3)
300+
end
301+
end
302+
303+
return DVtrunc_dDVtrunc, $f_adjoint!
304+
end
305+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
306+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
307+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
234308
# compute primal
235309
A, dA = arrayify(A_dA)
236310
alg = Mooncake.primal(alg_dalg)
237311
DV = Mooncake.primal(DV_dDV)
238312
dDV = Mooncake.tangent(DV_dDV)
239313
Ac = copy(A)
240314
DVc = copy.(DV)
241-
output = $f_ne!(A, DV, alg)
315+
output = $f_trunc_no_error!(A, DV, alg)
242316
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
243317
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
244318
# pass). For many types this is done automatically when the forward step returns, but
245319
# not for nested structs with various fields (like Diagonal{Complex})
246320
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
247-
function $adj(::NoRData)
321+
function $f_adjoint!(::NoRData)
248322
copy!(A, Ac)
249323
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
250324
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
251325
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
252326
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
253-
$pb(dA, A, (D′, V′), (dD′, dV′))
327+
$f_pullback!(dA, A, (D′, V′), (dD′, dV′))
254328
copy!(DV[1], DVc[1])
255329
copy!(DV[2], DVc[2])
256330
zero!(dD′)
257331
zero!(dV′)
258332
return NoRData(), NoRData(), NoRData(), NoRData()
259333
end
260-
return output_codual, $adj
334+
return output_codual, $f_adjoint!
335+
end
336+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
337+
# unpack variables
338+
A, dA = arrayify(A_dA)
339+
DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV))
340+
DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr)
341+
alg = Mooncake.primal(alg_dalg)
342+
343+
# store state prior to primal call
344+
Ac = copy(A)
345+
DVc = copy.(DV)
346+
347+
# compute primal - capture full DV and ind
348+
DV = $f_full!(A, DV, alg.alg)
349+
DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)
350+
351+
# pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input
352+
DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)
353+
354+
# define pullback
355+
local $f_adjoint!
356+
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
357+
function $f_adjoint!(::NoRData)
358+
# compute pullbacks
359+
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
360+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
361+
362+
# restore state
363+
copy!(A, Ac)
364+
copy!.(DV, DVc)
365+
366+
return ntuple(Returns(NoRData()), 4)
367+
end
368+
end
369+
370+
return DVtrunc_dDVtrunc, $f_adjoint!
261371
end
262-
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
372+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
263373
# compute primal
264374
A, dA = arrayify(A_dA)
265375
alg = Mooncake.primal(alg_dalg)
266-
output = $f_ne(A, alg)
376+
output = $f_trunc_no_error(A, alg)
267377
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
268378
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
269379
# pass). For many types this is done automatically when the forward step returns, but
270380
# not for nested structs with various fields (like Diagonal{Complex})
271381
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
272-
function $adj(::NoRData)
382+
function $f_adjoint!(::NoRData)
273383
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
274384
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
275385
D, dD = arrayify(Dtrunc, dDtrunc_)
276386
V, dV = arrayify(Vtrunc, dVtrunc_)
277-
$pb(dA, A, (D, V), (dD, dV))
387+
$f_trunc_pullback!(dA, A, (D, V), (dD, dV))
278388
zero!(dD)
279389
zero!(dV)
280390
return NoRData(), NoRData(), NoRData()
281391
end
282-
return output_codual, $adj
392+
return output_codual, $f_adjoint!
393+
end
394+
function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
395+
# unpack variables
396+
A, dA = arrayify(A_dA)
397+
alg = Mooncake.primal(alg_dalg)
398+
399+
# compute primal - capture full DV and ind
400+
DV = $f_full(A, alg.alg)
401+
DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc)
402+
403+
# pack output
404+
DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc)
405+
406+
# define pullback
407+
local $f_adjoint!
408+
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
409+
function $f_adjoint!(::NoRData)
410+
$f_pullback!(dA, A, DV, dDVtrunc, ind)
411+
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
412+
return ntuple(Returns(NoRData()), 3)
413+
end
414+
end
415+
416+
return DVtrunc_dDVtrunc, $f_adjoint!
283417
end
284418
end
285419
end

0 commit comments

Comments
 (0)