@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22
33using MatrixAlgebraKit
44using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5- TruncatedAlgorithm, findtruncated, findtruncated_svd
5+ TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr!
66using ChainRulesCore
77using LinearAlgebra
88
@@ -113,15 +113,20 @@ for eig in (:eig, :eigh)
113113 Ac = copy_input ($ eig_f, A)
114114 DV = $ (eig_f!)(Ac, DV, alg. alg)
115115 DV′, ind = MatrixAlgebraKit. truncate ($ eig_t!, DV, alg. trunc)
116- return DV′, $ (_make_eig_t_pb)(A, DV, ind)
116+ ϵ = compute_truncerr! (diagview (copy (DV[1 ])), ind)
117+ return (DV′... , ϵ), $ (_make_eig_t_pb)(A, DV, ind)
117118 end
118119 function $ (_make_eig_t_pb)(A, DV, ind)
119- function $eig_t_pb (ΔDV )
120+ function $eig_t_pb (ΔDVϵ )
120121 ΔA = zero (A)
121- MatrixAlgebraKit.$ eig_pb! (ΔA, A, DV, unthunk .(ΔDV), ind)
122+ ΔD, ΔV, Δϵ = ΔDVϵ
123+ if ! MatrixAlgebraKit. iszerotangent (Δϵ) && ! iszero (unthunk (Δϵ))
124+ throw (ArgumentError (" Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error" ))
125+ end
126+ MatrixAlgebraKit.$ eig_pb! (ΔA, A, DV, unthunk .((ΔD, ΔV)), ind)
122127 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
123128 end
124- function $eig_t_pb (:: Tuple{ZeroTangent, ZeroTangent} ) # is this extra definition useful?
129+ function $eig_t_pb (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent } ) # is this extra definition useful?
125130 return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
126131 end
127132 return $ eig_t_pb
@@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
152157 Ac = copy_input (svd_compact, A)
153158 USVᴴ = svd_compact! (Ac, USVᴴ, alg. alg)
154159 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
155- return USVᴴ′, _make_svd_trunc_pullback (A, USVᴴ, ind)
160+ ϵ = compute_truncerr! (diagview (copy (USVᴴ[2 ])), ind)
161+ return (USVᴴ′... , ϵ), _make_svd_trunc_pullback (A, USVᴴ, ind)
156162end
157163function _make_svd_trunc_pullback (A, USVᴴ, ind)
158- function svd_trunc_pullback (ΔUSVᴴ )
164+ function svd_trunc_pullback (ΔUSVᴴϵ )
159165 ΔA = zero (A)
160- MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .(ΔUSVᴴ), ind)
166+ ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
167+ if ! MatrixAlgebraKit. iszerotangent (Δϵ) && ! iszero (unthunk (Δϵ))
168+ throw (ArgumentError (" Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" ))
169+ end
170+ MatrixAlgebraKit. svd_pullback! (ΔA, A, USVᴴ, unthunk .((ΔU, ΔS, ΔVᴴ)), ind)
161171 return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
162172 end
163- function svd_trunc_pullback (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
173+ function svd_trunc_pullback (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent } ) # is this extra definition useful?
164174 return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
165175 end
166176 return svd_trunc_pullback
0 commit comments