@@ -94,7 +94,10 @@ for eig in (:eig, :eigh)
9494 eig_pb = Symbol (eig, " _pullback" )
9595 eig_t! = Symbol (eig, " _trunc!" )
9696 eig_t_pb = Symbol (eig, " _trunc_pullback" )
97+ eig_t_ne! = Symbol (eig, " _trunc_no_error!" )
98+ eig_t_ne_pb = Symbol (eig, " _trunc_no_error_pullback" )
9799 _make_eig_t_pb = Symbol (" _make_" , eig_t_pb)
100+ _make_eig_t_ne_pb = Symbol (" _make_" , eig_t_ne_pb)
98101 eig_v = Symbol (eig, " _vals" )
99102 eig_v! = Symbol (eig_v, " !" )
100103 eig_v_pb = Symbol (eig_v, " _pullback" )
@@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
136139 end
137140 return $ eig_t_pb
138141 end
142+ function ChainRulesCore. rrule (:: typeof ($ eig_t_ne!), A, DV, alg:: TruncatedAlgorithm )
143+ Ac = copy_input ($ eig_f, A)
144+ DV = $ (eig_f!)(Ac, DV, alg. alg)
145+ DV′, ind = MatrixAlgebraKit. truncate ($ eig_t!, DV, alg. trunc)
146+ return DV′, $ (_make_eig_t_ne_pb)(A, DV, ind)
147+ end
148+ function $ (_make_eig_t_ne_pb)(A, DV, ind)
149+ function $eig_t_ne_pb (ΔDV)
150+ ΔA = zero (A)
151+ ΔD, ΔV = ΔDV
152+ MatrixAlgebraKit.$ eig_pb! (ΔA, A, DV, unthunk .((ΔD, ΔV)), ind)
153+ return NoTangent (), ΔA, ZeroTangent (), NoTangent ()
154+ end
155+ function $eig_t_ne_pb (:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
156+ return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
157+ end
158+ return $ eig_t_ne_pb
159+ end
139160 function ChainRulesCore. rrule (:: typeof ($ eig_v!), A, D, alg)
140161 DV = $ eig_f (A, alg)
141162 function $eig_v_pb (ΔD)
0 commit comments