@@ -629,6 +629,51 @@ def _gather_for_diff(self):
629629
630630 return self .func (* new_args , evaluate = False )
631631
632+ def _eval_at (self , func ):
633+ # No a basic a*b*c... expression, just defer to superclass
634+ if any (isinstance (f , DifferentiableOp ) for f in self .args ):
635+ return super ()._eval_at (func )
636+
637+ # Split Derivative and Differentiable args
638+ derivs , other = split (self .args , lambda e : isinstance (e , sympy .Derivative ))
639+
640+ if derivs :
641+ derivs = Differentiable ._eval_at (self .func (* derivs ), func )
642+ else :
643+ derivs = 1
644+
645+ if not other :
646+ return derivs
647+ elif len (other ) > 1 :
648+ expr = self .func (* other )._gather_for_diff
649+ else :
650+ expr = other [0 ]
651+
652+ # Non differentiable expr (e.g., number)
653+ if not isinstance (expr , Differentiable ):
654+ return self .func (derivs , expr )
655+
656+ # Build mapper for dimensions that need to be interpolated
657+ mapper = {}
658+ for d in self .dimensions :
659+ try :
660+ if self .indices_ref [d ] is not func .indices_ref [d ]:
661+ mapper [d ] = func .indices_ref [d ]
662+ except KeyError :
663+ pass
664+
665+ # Nothing to interpolate
666+ if not mapper :
667+ return super ()._eval_at (func )
668+
669+ # Interpolate expr at the required indices
670+ interp = expr .diff (* mapper .keys (), deriv_order = [0 for _ in mapper ],
671+ fd_order = [self .interp_order for _ in mapper ],
672+ x0 = mapper )
673+
674+ # Return the full expression with Derivatives
675+ return self .func (derivs , interp )
676+
632677
633678class Pow (DifferentiableOp , sympy .Pow ):
634679 _fd_priority = 0
0 commit comments