@@ -168,6 +168,9 @@ for (f!, f, f_full, pb, adj) in (
168168 end
169169end
170170
171+ _warn_pullback_truncerror (dϵ:: Real ; tol = MatrixAlgebraKit. defaulttol (dϵ)) =
172+ abs (dϵ) ≤ tol || @warn " Pullback ignores non-zero tangents for truncation error"
173+
171174for f in (:eig , :eigh )
172175 f_trunc = Symbol (f, :_trunc )
173176 f_trunc! = Symbol (f_trunc, :! )
@@ -200,7 +203,7 @@ for f in (:eig, :eigh)
200203 copy! (A, Ac)
201204 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
202205 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
203- abs (dy[3 ]) > MatrixAlgebraKit . defaulttol (dy[ 3 ]) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error "
206+ _warn_pullback_truncerror (dy[3 ])
204207 D′, dD′ = arrayify (Dtrunc, dDtrunc_)
205208 V′, dV′ = arrayify (Vtrunc, dVtrunc_)
206209 $ f_trunc_pullback! (dA, A, (D′, V′), (dD′, dV′))
@@ -235,8 +238,7 @@ for f in (:eig, :eigh)
235238 local $ f_adjoint!
236239 let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
237240 function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
238- abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
239- @warn " Pullback for `$f! ` ignores non-zero tangents for truncation error"
241+ _warn_pullback_truncerror (dϵ)
240242
241243 # compute pullbacks
242244 $ f_pullback! (dA, Ac, DVc, dDVtrunc, ind)
@@ -265,7 +267,7 @@ for f in (:eig, :eigh)
265267 function $f_adjoint! (dy:: Tuple{NoRData, NoRData, T} ) where {T <: Real }
266268 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
267269 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
268- abs (dy[3 ]) > MatrixAlgebraKit . defaulttol (dy[ 3 ]) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error "
270+ _warn_pullback_truncerror (dy[3 ])
269271 D, dD = arrayify (Dtrunc, dDtrunc_)
270272 V, dV = arrayify (Vtrunc, dVtrunc_)
271273 $ f_trunc_pullback! (dA, A, (D, V), (dD, dV))
@@ -292,8 +294,7 @@ for f in (:eig, :eigh)
292294 local $ f_adjoint!
293295 let ind = ind, dDVtrunc = last .(arrayify .(DVtrunc, Base. front (Mooncake. tangent (DVtrunc_dDVtrunc))))
294296 function $f_adjoint! ((_, _, dϵ):: Tuple{NoRData, NoRData, Real} )
295- abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
296- @warn " Pullback for `$f_trunc ` ignores non-zero tangents for truncation error"
297+ _warn_pullback_truncerror (dϵ)
297298 $ f_pullback! (dA, A, DV, dDVtrunc, ind)
298299 zero! .(dDVtrunc) # since this is allocated in this function this is probably not required
299300 return ntuple (Returns (NoRData ()), 3 )
@@ -554,7 +555,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
554555 copy! (A, Ac)
555556 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
556557 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
557- abs (dy[4 ]) > MatrixAlgebraKit . defaulttol (dy[ 4 ]) && @warn " Pullback for svd_trunc does not yet support non-zero tangent for the truncation error "
558+ _warn_pullback_truncerror (dy[4 ])
558559 U′, dU′ = arrayify (Utrunc, dUtrunc_)
559560 S′, dS′ = arrayify (Strunc, dStrunc_)
560561 Vᴴ′, dVᴴ′ = arrayify (Vᴴtrunc, dVᴴtrunc_)
@@ -596,8 +597,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
596597 local svd_trunc_adjoint
597598 let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
598599 function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
599- abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
600- @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
600+ _warn_pullback_truncerror (dϵ)
601601
602602 # compute pullbacks
603603 svd_pullback! (dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
@@ -629,7 +629,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
629629 function svd_trunc_adjoint (dy:: Tuple{NoRData, NoRData, NoRData, T} ) where {T <: Real }
630630 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
631631 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
632- abs (dy[4 ]) > MatrixAlgebraKit . defaulttol (dy[ 4 ]) && @warn " Pullback for svd_trunc does not yet support non-zero tangent for the truncation error "
632+ _warn_pullback_truncerror (dy[4 ])
633633 U, dU = arrayify (Utrunc, dUtrunc_)
634634 S, dS = arrayify (Strunc, dStrunc_)
635635 Vᴴ, dVᴴ = arrayify (Vᴴtrunc, dVᴴtrunc_)
@@ -658,8 +658,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
658658 local svd_trunc_adjoint
659659 let ind = ind, dUSVᴴtrunc = last .(arrayify .(USVᴴtrunc, Base. front (Mooncake. tangent (USVᴴtrunc_dUSVᴴtrunc))))
660660 function svd_trunc_adjoint ((_, _, _, dϵ):: Tuple{NoRData, NoRData, NoRData, Real} )
661- abs (dϵ) ≤ MatrixAlgebraKit. defaulttol (dϵ) ||
662- @warn " Pullback for `svd_trunc` ignores non-zero tangents for truncation error"
661+ _warn_pullback_truncerror (dϵ)
663662 svd_pullback! (dA, A, USVᴴ, dUSVᴴtrunc, ind)
664663 zero! .(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
665664 return ntuple (Returns (NoRData ()), 3 )
0 commit comments