@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt
22
33using MatrixAlgebraKit
44using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5- TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr!
5+ TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error
66using ChainRulesCore
77using LinearAlgebra
88
@@ -113,7 +113,7 @@ for eig in (:eig, :eigh)
113113 Ac = copy_input ($ eig_f, A)
114114 DV = $ (eig_f!)(Ac, DV, alg. alg)
115115 DV′, ind = MatrixAlgebraKit. truncate ($ eig_t!, DV, alg. trunc)
116- ϵ = compute_truncerr! ( copy ( diagview (DV[1 ]) ), ind)
116+ ϵ = truncation_error ( diagview (DV[1 ]), ind)
117117 return (DV′... , ϵ), $ (_make_eig_t_pb)(A, DV, ind)
118118 end
119119 function $ (_make_eig_t_pb)(A, DV, ind)
@@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
157157 Ac = copy_input (svd_compact, A)
158158 USVᴴ = svd_compact! (Ac, USVᴴ, alg. alg)
159159 USVᴴ′, ind = MatrixAlgebraKit. truncate (svd_trunc!, USVᴴ, alg. trunc)
160- ϵ = compute_truncerr! ( copy ( diagview (USVᴴ[2 ]) ), ind)
160+ ϵ = truncation_error ( diagview (USVᴴ[2 ]), ind)
161161 return (USVᴴ′... , ϵ), _make_svd_trunc_pullback (A, USVᴴ, ind)
162162end
163163function _make_svd_trunc_pullback (A, USVᴴ, ind)
0 commit comments