11from abc import ABC , abstractmethod
2+ from contextlib import suppress
23from functools import cached_property , wraps
4+ from itertools import groupby
35
46import numpy as np
57import sympy
@@ -67,12 +69,9 @@ def _extract_subdomain(variables):
6769 """
6870 sdms = set ()
6971 for v in variables :
70- try :
72+ with suppress ( AttributeError ) :
7173 if v .grid .is_SubDomain :
7274 sdms .add (v .grid )
73- except AttributeError :
74- # Variable not on a grid (Indexed for example)
75- pass
7675
7776 if len (sdms ) > 1 :
7877 raise NotImplementedError ("Sparse operation on multiple Functions defined on"
@@ -230,7 +229,7 @@ def r(self):
230229 return self .sfunction .r
231230
232231 @memoized_meth
233- def _weights (self , subdomain = None ):
232+ def _weights (self , subdomain = None , shifts = None ):
234233 raise NotImplementedError
235234
236235 @property
@@ -243,8 +242,22 @@ def _cdim(self):
243242 dims = [self .sfunction ._crdim (d ) for d in self ._gdims ]
244243 return dims
245244
245+ def _field_shifts (self , field ):
246+ """
247+ Per-grid-Dimension half-cell shift induced by `field`'s staggering
248+ (e.g. `h_x/2` for a field staggered in `x`). Returns None for
249+ unstaggered fields. SubDomain-induced origin offsets are deliberately
250+ ignored — they are not staggering.
251+ """
252+ staggered = field .staggered
253+ if not staggered or staggered .on_node :
254+ return ()
255+ return tuple ((d .spacing / 2 ) if s else 0
256+ for d , s in zip (field .dimensions , staggered , strict = True )
257+ if d .is_Space )
258+
246259 @memoized_meth
247- def _rdim (self , subdomain = None ):
260+ def _rdim (self , subdomain = None , shifts = None ):
248261 # If the interpolation operation is limited to a SubDomain,
249262 # use the SubDimensions of that SubDomain
250263 if subdomain :
@@ -254,7 +267,7 @@ def _rdim(self, subdomain=None):
254267
255268 # Make radius dimension conditional to avoid OOB
256269 rdims = []
257- pos = self .sfunction ._position_map .values ()
270+ pos = self .sfunction ._position_map ( shifts = shifts ) .values ()
258271
259272 for (d , rd , p ) in zip (gdims , self ._cdim , pos , strict = True ):
260273 # Add conditional to avoid OOB
@@ -279,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
279292 # dimensions of that SubDomain from any extra dimensions found
280293 edims = []
281294 for v in extras :
282- try :
295+ with suppress ( AttributeError ) :
283296 if v .grid .is_SubDomain :
284297 edims .extend ([d for d in v .grid .dimensions
285298 if d .is_Sub and d .root in self ._gdims ])
286- except AttributeError :
287- pass
288299
289300 gdims = filter_ordered (edims + list (self ._gdims ))
290301 extra = filter_ordered ([i for v in extras for i in v .dimensions
@@ -300,27 +311,34 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
300311 idims = extra + as_tuple (implicit_dims ) + self .sfunction .dimensions
301312 return tuple (idims )
302313
303- def _coeff_temps (self , implicit_dims ):
314+ def _coeff_temps (self , implicit_dims , shifts = None ):
304315 return []
305316
306- def _positions (self , implicit_dims ):
317+ def _positions (self , implicit_dims , shifts = None ):
307318 return [Eq (v , INT (floor (k )), implicit_dims = implicit_dims )
308- for k , v in self .sfunction ._position_map .items ()]
319+ for k , v in self .sfunction ._position_map ( shifts = shifts ) .items ()]
309320
310- def _interp_idx (self , variables , implicit_dims = None , subdomain = None ):
321+ def _interp_idx (self , variables , implicit_dims = None , subdomain = None ,
322+ shifts = None ):
311323 """
312- Generate interpolation indices for the DiscreteFunctions in ``variables``.
324+ Generate interpolation indices for the DiscreteFunctions in `variables`.
325+
326+ `shifts` is a per-Dimension physical offset for the target field's
327+ origin: it only affects the integer position symbol via the position
328+ map (`pos = floor((c - o - shift)/h)`). The index substitution itself
329+ is unchanged — any staggered offset in a field's own symbolic access is
330+ absorbed by Devito's normal indexing.
313331 """
314- pos = self .sfunction ._position_map .values ()
332+ pos = self .sfunction ._position_map ( shifts = shifts ) .values ()
315333
316334 # Temporaries for the position
317- temps = self ._positions (implicit_dims )
335+ temps = self ._positions (implicit_dims , shifts = shifts )
318336
319337 # Coefficient symbol expression
320- temps .extend (self ._coeff_temps (implicit_dims ))
338+ temps .extend (self ._coeff_temps (implicit_dims , shifts = shifts ))
321339
322340 # Substitution mapper for variables
323- mapper = self ._rdim (subdomain = subdomain ).getters
341+ mapper = self ._rdim (subdomain = subdomain , shifts = shifts ).getters
324342
325343 # Index substitution to make in variables
326344 subs = {
@@ -337,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
337355 @check_coords
338356 def interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
339357 """
340- Generate equations interpolating an arbitrary expression into `` self` `.
358+ Generate equations interpolating an arbitrary expression into `self`.
341359
342360 Parameters
343361 ----------
@@ -375,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None):
375393
376394 def _interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
377395 """
378- Generate equations interpolating an arbitrary expression into `` self` `.
396+ Generate equations interpolating an arbitrary expression into `self`.
379397
380398 Parameters
381399 ----------
@@ -389,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
389407 the operator.
390408 """
391409 # Derivatives must be evaluated before the introduction of indirect accesses
392- try :
393- _expr = expr .evaluate
394- except AttributeError :
395- # E.g., a generic SymPy expression or a number
396- _expr = expr
410+ with suppress (AttributeError ):
411+ expr = expr ._eval_at (self .sfunction ).evaluate
397412
398413 if self_subs is None :
399414 self_subs = {}
400415
401- variables = list (retrieve_function_carriers (_expr ))
416+ variables = list (retrieve_function_carriers (expr ))
402417 subdomain = _extract_subdomain (variables )
403418
404419 # Implicit dimensions
@@ -413,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
413428 summands = [Eq (rhs , 0. , implicit_dims = implicit_dims )]
414429 # Substitute coordinate base symbols into the interpolation coefficients
415430 weights = self ._weights (subdomain = subdomain )
416- summands .extend ([Inc (rhs , (weights * _expr ).xreplace (idx_subs ),
431+ summands .extend ([Inc (rhs , (weights * expr ).xreplace (idx_subs ),
417432 implicit_dims = implicit_dims )])
418433
419434 # Write/Incr `self`
@@ -451,35 +466,48 @@ def _inject(self, field, expr, implicit_dims=None):
451466
452467 subdomain = _extract_subdomain (fields )
453468
454- # Derivatives must be evaluated before the introduction of indirect accesses
455- try :
456- _exprs = tuple (e .evaluate for e in exprs )
457- except AttributeError :
458- # E.g., a generic SymPy expression or a number
459- _exprs = exprs
460-
461- variables = list (v for e in _exprs for v in retrieve_function_carriers (e ))
462-
463- # Implicit dimensions
464- implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
465- # Move all temporaries inside inner loop to improve parallelism
466- # Can only be done for inject as interpolation need a temporary
467- # summing temp that wouldn't allow collapsing
468- implicit_dims = implicit_dims + tuple (r .parent for r in
469- self ._rdim (subdomain = subdomain ))
470-
471- # List of indirection indices for all adjacent grid points
472- finterp = fields + as_tuple (variables )
473- idx_subs , temps = self ._interp_idx (finterp , implicit_dims = implicit_dims ,
474- subdomain = subdomain )
475-
476- # Substitute coordinate base symbols into the interpolation coefficients
477- eqns = [Inc (_field .xreplace (idx_subs ),
478- (self ._weights (subdomain = subdomain ) * _expr ).xreplace (idx_subs ),
479- implicit_dims = implicit_dims )
480- for (_field , _expr ) in zip (fields , _exprs , strict = True )]
481-
482- return temps + eqns
469+ # Derivatives must be evaluated before the introduction of indirect
470+ # accesses. Variables are sampled at their own grid location; the
471+ # position map for the target field carries the staggering so the
472+ # field's stencil neighbors land on the right indices.
473+ with suppress (AttributeError ):
474+ exprs = tuple (e ._eval_at (f ).evaluate
475+ for e , f in zip (exprs , fields , strict = True ))
476+
477+ eqns = []
478+ temps = []
479+ # We need to create one set of equations (temps and and coeffs) per staggering
480+ # field in which we inject as the reference index depends on the field's origin
481+ for _ , g in groupby (zip (fields , exprs , strict = True ), lambda f : f [0 ].staggered ):
482+ g_fields , g_exprs = zip (* g , strict = True )
483+ variables = list (v for e in g_exprs for v in retrieve_function_carriers (e ))
484+
485+ implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
486+
487+ # All fields in a single injection share the same staggering by
488+ # construction (they are written together at the same indices), so
489+ # derive shifts from the first field.
490+ shifts = self ._field_shifts (g_fields [0 ])
491+
492+ # Move all temporaries inside inner loop to improve parallelism
493+ # Can only be done for inject as interpolation needs a summing temp
494+ # that wouldn't allow collapsing
495+ implicit_dims = implicit_dims + tuple (r .parent for r in
496+ self ._rdim (subdomain = subdomain ,
497+ shifts = shifts ))
498+
499+ # List of indirection indices for all adjacent grid points
500+ idx_subs , _temps = self ._interp_idx (list (g_fields ) + variables ,
501+ implicit_dims = implicit_dims ,
502+ subdomain = subdomain , shifts = shifts )
503+
504+ w = self ._weights (subdomain = subdomain , shifts = shifts )
505+ temps .extend (_temps )
506+ eqns .extend ([Inc (f .xreplace (idx_subs ), (w * e ).xreplace (idx_subs ),
507+ implicit_dims = implicit_dims )
508+ for f , e in zip (g_fields , g_exprs , strict = True )])
509+
510+ return filter_ordered (temps ) + eqns
483511
484512
485513class LinearInterpolator (WeightedInterpolator ):
@@ -495,24 +523,30 @@ class LinearInterpolator(WeightedInterpolator):
495523 _name = 'linear'
496524
497525 @memoized_meth
498- def _weights (self , subdomain = None ):
499- rdim = self ._rdim (subdomain = subdomain )
526+ def _weights (self , subdomain = None , shifts = None ):
527+ rdim = self ._rdim (subdomain = subdomain , shifts = shifts )
500528 c = [(1 - p ) * (1 - r ) + p * r
501- for (p , d , r ) in zip (self ._point_symbols , self ._gdims , rdim , strict = True )]
529+ for (p , d , r ) in zip (self ._point_symbols (shifts ), self ._gdims , rdim ,
530+ strict = True )]
502531 return Mul (* c )
503532
504- @cached_property
505- def _point_symbols (self ):
533+ @memoized_meth
534+ def _point_symbols (self , shifts = None ):
506535 """Symbol for coordinate value in each Dimension of the point."""
507536 dtype = self .sfunction .coordinates .dtype
508- return DimensionTuple (* (Symbol (name = f'p{ d } ' , dtype = dtype )
509- for d in self .grid .dimensions ),
510- getters = self .grid .dimensions )
537+ symbols = []
538+ for d in self .grid .dimensions :
539+ if shifts and shifts [self .grid .dimensions .index (d )] != 0 :
540+ symbols .append (Symbol (name = f'p{ d } _s1' , dtype = dtype ))
541+ else :
542+ symbols .append (Symbol (name = f'p{ d } ' , dtype = dtype ))
543+ return DimensionTuple (* symbols , getters = self .grid .dimensions )
511544
512- def _coeff_temps (self , implicit_dims ):
545+ def _coeff_temps (self , implicit_dims , shifts = None ):
513546 # Positions
514- pmap = self .sfunction ._position_map
515- poseq = [Eq (self ._point_symbols [d ], pos - floor (pos ),
547+ pmap = self .sfunction ._position_map (shifts = shifts )
548+ psyms = self ._point_symbols (shifts )
549+ poseq = [Eq (psyms [d ], pos - floor (pos ),
516550 implicit_dims = implicit_dims )
517551 for (d , pos ) in zip (self ._gdims , pmap .keys (), strict = True )]
518552 return poseq
@@ -531,23 +565,24 @@ class PrecomputedInterpolator(WeightedInterpolator):
531565
532566 _name = 'precomp'
533567
534- def _positions (self , implicit_dims ):
568+ def _positions (self , implicit_dims , shifts = None ):
535569 if self .sfunction .gridpoints_data is None :
536- return super ()._positions (implicit_dims )
570+ return super ()._positions (implicit_dims , shifts = shifts )
537571 else :
538572 # No position temp as we have directly the gridpoints
539573 return [Eq (p , k , implicit_dims = implicit_dims )
540- for (k , p ) in self .sfunction ._position_map .items ()]
574+ for (k , p ) in self .sfunction ._position_map ( shifts = shifts ) .items ()]
541575
542576 @property
543577 def interpolation_coeffs (self ):
544578 return self .sfunction .interpolation_coeffs
545579
546580 @memoized_meth
547- def _weights (self , subdomain = None ):
581+ def _weights (self , subdomain = None , shifts = None ):
548582 ddim , cdim = self .interpolation_coeffs .dimensions [1 :]
549583 mappers = [{ddim : ri , cdim : rd - rd .parent .symbolic_min }
550- for (ri , rd ) in enumerate (self ._rdim (subdomain = subdomain ))]
584+ for (ri , rd ) in enumerate (self ._rdim (subdomain = subdomain ,
585+ shifts = shifts ))]
551586 return Mul (* [self .interpolation_coeffs .subs (mapper )
552587 for mapper in mappers ])
553588
@@ -592,8 +627,8 @@ def interpolation_coeffs(self):
592627 return tuple (coeffs )
593628
594629 @memoized_meth
595- def _weights (self , subdomain = None ):
596- rdims = self ._rdim (subdomain = subdomain )
630+ def _weights (self , subdomain = None , shifts = None ):
631+ rdims = self ._rdim (subdomain = subdomain , shifts = shifts )
597632 return Mul (* [
598633 w ._subs (rd , rd - rd .parent .symbolic_min )
599634 for (rd , w ) in zip (rdims , self .interpolation_coeffs , strict = True )
0 commit comments