Skip to content

Commit 6b352e9

Browse files
committed
pullback truncation error warning
1 parent bf151a0 commit 6b352e9

1 file changed

Lines changed: 11 additions & 12 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ for (f!, f, f_full, pb, adj) in (
168168
end
169169
end
170170

171+
_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) =
172+
abs(dϵ) tol || @warn "Pullback ignores non-zero tangents for truncation error"
173+
171174
for f in (:eig, :eigh)
172175
f_trunc = Symbol(f, :_trunc)
173176
f_trunc! = Symbol(f_trunc, :!)
@@ -200,7 +203,7 @@ for f in (:eig, :eigh)
200203
copy!(A, Ac)
201204
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
202205
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
203-
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
206+
_warn_pullback_truncerror(dy[3])
204207
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
205208
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
206209
$f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′))
@@ -235,8 +238,7 @@ for f in (:eig, :eigh)
235238
local $f_adjoint!
236239
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
237240
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
238-
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
239-
@warn "Pullback for `$f!` ignores non-zero tangents for truncation error"
241+
_warn_pullback_truncerror(dϵ)
240242

241243
# compute pullbacks
242244
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
@@ -265,7 +267,7 @@ for f in (:eig, :eigh)
265267
function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
266268
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
267269
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
268-
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
270+
_warn_pullback_truncerror(dy[3])
269271
D, dD = arrayify(Dtrunc, dDtrunc_)
270272
V, dV = arrayify(Vtrunc, dVtrunc_)
271273
$f_trunc_pullback!(dA, A, (D, V), (dD, dV))
@@ -292,8 +294,7 @@ for f in (:eig, :eigh)
292294
local $f_adjoint!
293295
let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc))))
294296
function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real})
295-
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
296-
@warn "Pullback for `$f_trunc` ignores non-zero tangents for truncation error"
297+
_warn_pullback_truncerror(dϵ)
297298
$f_pullback!(dA, A, DV, dDVtrunc, ind)
298299
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
299300
return ntuple(Returns(NoRData()), 3)
@@ -554,7 +555,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
554555
copy!(A, Ac)
555556
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
556557
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
557-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
558+
_warn_pullback_truncerror(dy[4])
558559
U′, dU′ = arrayify(Utrunc, dUtrunc_)
559560
S′, dS′ = arrayify(Strunc, dStrunc_)
560561
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
@@ -596,8 +597,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
596597
local svd_trunc_adjoint
597598
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
598599
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
599-
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
600-
@warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
600+
_warn_pullback_truncerror(dϵ)
601601

602602
# compute pullbacks
603603
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
@@ -629,7 +629,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
629629
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
630630
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
631631
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
632-
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
632+
_warn_pullback_truncerror(dy[4])
633633
U, dU = arrayify(Utrunc, dUtrunc_)
634634
S, dS = arrayify(Strunc, dStrunc_)
635635
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
@@ -658,8 +658,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
658658
local svd_trunc_adjoint
659659
let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))))
660660
function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
661-
abs(dϵ) MatrixAlgebraKit.defaulttol(dϵ) ||
662-
@warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
661+
_warn_pullback_truncerror(dϵ)
663662
svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
664663
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
665664
return ntuple(Returns(NoRData()), 3)

0 commit comments

Comments
 (0)