1515 # Moved in 1.13
1616 from sympy .core .basic import ordering_of_classes
1717
18+ from devito .finite_differences .interpolation import interp_at , post_x0_indices
1819from devito .finite_differences .tools import coeff_priority , make_shift_x0
1920from devito .logger import warning
2021from devito .tools import (
@@ -140,10 +141,6 @@ def indices_ref(self):
140141 return DimensionTuple (* self .dimensions , getters = self .dimensions )
141142 return highest_priority (self ).indices_ref
142143
143- @cached_property
144- def is_Staggered (self ):
145- return any ([getattr (i , 'is_Staggered' , False ) for i in self ._args_diff ])
146-
147144 @cached_property
148145 def is_TimeDependent (self ):
149146 return any (i .is_Time for i in self .dimensions )
@@ -184,13 +181,11 @@ def coefficients(self):
184181 key = lambda x : coeff_priority .get (x , - 1 )
185182 return sorted (coefficients , key = key , reverse = True )[0 ]
186183
187- def _eval_at (self , func ):
188- if not func .is_Staggered :
189- # Cartesian grid, do no waste time
190- return self
184+ def _eval_at (self , func , ** kwargs ):
191185 return self .func (* [
192- getattr (a , '_eval_at' , lambda x : a )(func ) for a in self .args # noqa: B023
193- ]) # false positive
186+ getattr (a , '_eval_at' , lambda x , ** kw : a )(func , ** kwargs ) # noqa: B023
187+ for a in self .args # false positive: lambda is invoked in-place
188+ ])
194189
195190 def _subs (self , old , new , ** hints ):
196191 if old == self :
@@ -669,6 +664,63 @@ def _gather_for_diff(self):
669664 other = self .func (* other )._eval_at (highest_priority (self ))
670665 return self .func (other , * derivs )
671666
667+ def _eval_at (self , func , interp_mode = 'direct' , ** kwargs ):
668+ """
669+ Evaluate a Mul at the location of `func`.
670+
671+ Two modes:
672+
673+ - `interp_mode='direct'` (default): per-arg evaluation; each factor is
674+ independently evaluated at `func`'s location via
675+ `Differentiable._eval_at`.
676+
677+ - `interp_mode='symmetric'`: when every Differentiable factor has a
678+ staggering different from `func`'s, apply the `I * (a * I^T * b)`
679+ form:
680+
681+ 1. Pick a `block` location -- the highest-priority factor's
682+ staggering (NODE is the highest priority, so coefficient-like
683+ NODE factors win, as in the `I * C * I^T` elastic stiffness
684+ pattern). Each factor not at the block is brought there via
685+ `I^T` (an explicit 0-order FD interpolation operator).
686+ Derivatives additionally set `x0` on their own derivative
687+ dimensions to `func`'s indices.
688+ 2. The product is formed at `block`'s location.
689+ 3. The whole product is interpolated to `func` via `I` (an
690+ explicit 0-order FD operator).
691+
692+ When the trigger does not hold (e.g. some factor already matches
693+ `func`'s staggering), we fall back to `direct`.
694+ """
695+ if interp_mode != 'symmetric' :
696+ return super ()._eval_at (func , ** kwargs )
697+
698+ diff , other = split (self .args , lambda a : isinstance (a , Differentiable ))
699+
700+ # Symmetric form requires every Differentiable factor to differ from
701+ # func; otherwise direct evaluation is cleaner and equivalent.
702+ if len (diff ) < 2 or \
703+ any (a .staggered == func .staggered for a in diff ):
704+ return super ()._eval_at (func , ** kwargs )
705+
706+ block_indices = highest_priority (self ).indices_ref
707+
708+ # Bring each factor to block's location (I^T where needed)
709+ new_factors = list (other )
710+ for a in diff :
711+ if isinstance (a , sympy .Derivative ):
712+ source = post_x0_indices (a , func )
713+ a = a ._rebuild (x0 = {dim : func .indices_ref [dim ] for dim in a .dims
714+ if dim in func .indices_ref .getters })
715+ else :
716+ source = a .indices_ref
717+ new_factors .append (interp_at (a , source , block_indices ,
718+ self .interp_order ))
719+
720+ # Final I from block's location to func
721+ return interp_at (self .func (* new_factors ), block_indices ,
722+ func .indices_ref , self .interp_order )
723+
672724
673725class Pow (DifferentiableOp , sympy .Pow ):
674726 _fd_priority = 0
@@ -1020,7 +1072,7 @@ def _subs(self, old, new, **hints):
10201072
10211073class DiffDerivative (IndexDerivative , DifferentiableOp ):
10221074
1023- def _eval_at (self , func ):
1075+ def _eval_at (self , func , ** kwargs ):
10241076 # Like EvalDerivative, a DiffDerivative must have already been evaluated
10251077 # at a valid x0 and should not be re-evaluated at a different location
10261078 return self
@@ -1074,7 +1126,7 @@ def _new_rawargs(self, *args, **kwargs):
10741126 kwargs .pop ('is_commutative' , None )
10751127 return self .func (* args , ** kwargs )
10761128
1077- def _eval_at (self , func ):
1129+ def _eval_at (self , func , ** kwargs ):
10781130 # An EvalDerivative must have already been evaluated at a valid x0
10791131 # and should not be re-evaluated at a different location
10801132 return self
@@ -1092,7 +1144,7 @@ class diffify:
10921144
10931145 Notes
10941146 -----
1095- The name "diffify" stems from SymPy's "simpify ", which has an analogous task --
1147+ The name "diffify" stems from SymPy's "simplify ", which has an analogous task --
10961148 converting all arguments into SymPy core objects.
10971149 """
10981150
@@ -1240,7 +1292,7 @@ def _(expr, x0, **kwargs):
12401292@interp_for_fd .register (AbstractFunction )
12411293def _ (expr , x0 , ** kwargs ):
12421294 x0_expr = {d : v for d , v in x0 .items () if v .has (d )
1243- and expr .indices [ d ] is not v }
1295+ and expr .indices . get ( d , v ) is not v }
12441296 if x0_expr :
12451297 return expr .subs ({expr .indices [d ]: v for d , v in x0_expr .items ()})
12461298 else :
0 commit comments