11from abc import ABC , abstractmethod
2+ from contextlib import suppress
23from functools import cached_property , wraps
4+ from itertools import groupby
35
46import numpy as np
57import sympy
@@ -67,12 +69,9 @@ def _extract_subdomain(variables):
6769 """
6870 sdms = set ()
6971 for v in variables :
70- try :
72+ with suppress ( AttributeError ) :
7173 if v .grid .is_SubDomain :
7274 sdms .add (v .grid )
73- except AttributeError :
74- # Variable not on a grid (Indexed for example)
75- pass
7675
7776 if len (sdms ) > 1 :
7877 raise NotImplementedError ("Sparse operation on multiple Functions defined on"
@@ -245,19 +244,17 @@ def _cdim(self):
245244
246245 def _field_shifts (self , field ):
247246 """
248- Per-grid-Dimension half-cell shift induced by `` field` `'s staggering
249- (e.g. `` h_x/2`` for a field staggered in ``x` `). Returns None for
247+ Per-grid-Dimension half-cell shift induced by `field`'s staggering
248+ (e.g. `h_x/2` for a field staggered in `x `). Returns None for
250249 unstaggered fields. SubDomain-induced origin offsets are deliberately
251250 ignored — they are not staggering.
252251 """
253- try :
254- staggered = field .staggered
255- except AttributeError :
256- return None
257- if not staggered or all (s == 0 for s in staggered ):
258- return None
252+ staggered = field .staggered
253+ if not staggered or staggered .on_node :
254+ return ()
259255 return tuple ((d .spacing / 2 ) if s else 0
260- for d , s in zip (self ._gdims , staggered , strict = True ))
256+ for d , s in zip (field .dimensions , staggered , strict = True )
257+ if d .is_Space )
261258
262259 @memoized_meth
263260 def _rdim (self , subdomain = None , shifts = None ):
@@ -295,12 +292,10 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
295292 # dimensions of that SubDomain from any extra dimensions found
296293 edims = []
297294 for v in extras :
298- try :
295+ with suppress ( AttributeError ) :
299296 if v .grid .is_SubDomain :
300297 edims .extend ([d for d in v .grid .dimensions
301298 if d .is_Sub and d .root in self ._gdims ])
302- except AttributeError :
303- pass
304299
305300 gdims = filter_ordered (edims + list (self ._gdims ))
306301 extra = filter_ordered ([i for v in extras for i in v .dimensions
@@ -326,11 +321,11 @@ def _positions(self, implicit_dims, shifts=None):
326321 def _interp_idx (self , variables , implicit_dims = None , subdomain = None ,
327322 shifts = None ):
328323 """
329- Generate interpolation indices for the DiscreteFunctions in `` variables` `.
324+ Generate interpolation indices for the DiscreteFunctions in `variables`.
330325
331- `` shifts` ` is a per-Dimension physical offset for the target field's
326+ `shifts` is a per-Dimension physical offset for the target field's
332327 origin: it only affects the integer position symbol via the position
333- map (`` pos = floor((c - o - shift)/h)` `). The index substitution itself
328+ map (`pos = floor((c - o - shift)/h)`). The index substitution itself
334329 is unchanged — any staggered offset in a field's own symbolic access is
335330 absorbed by Devito's normal indexing.
336331 """
@@ -360,7 +355,7 @@ def _interp_idx(self, variables, implicit_dims=None, subdomain=None,
360355 @check_coords
361356 def interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
362357 """
363- Generate equations interpolating an arbitrary expression into `` self` `.
358+ Generate equations interpolating an arbitrary expression into `self`.
364359
365360 Parameters
366361 ----------
@@ -398,7 +393,7 @@ def inject(self, field, expr, implicit_dims=None):
398393
399394 def _interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
400395 """
401- Generate equations interpolating an arbitrary expression into `` self` `.
396+ Generate equations interpolating an arbitrary expression into `self`.
402397
403398 Parameters
404399 ----------
@@ -412,16 +407,13 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
412407 the operator.
413408 """
414409 # Derivatives must be evaluated before the introduction of indirect accesses
415- try :
416- _expr = expr ._eval_at (self .sfunction ).evaluate
417- except AttributeError :
418- # E.g., a generic SymPy expression or a number
419- _expr = expr
410+ with suppress (AttributeError ):
411+ expr = expr ._eval_at (self .sfunction ).evaluate
420412
421413 if self_subs is None :
422414 self_subs = {}
423415
424- variables = list (retrieve_function_carriers (_expr ))
416+ variables = list (retrieve_function_carriers (expr ))
425417 subdomain = _extract_subdomain (variables )
426418
427419 # Implicit dimensions
@@ -436,7 +428,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
436428 summands = [Eq (rhs , 0. , implicit_dims = implicit_dims )]
437429 # Substitute coordinate base symbols into the interpolation coefficients
438430 weights = self ._weights (subdomain = subdomain )
439- summands .extend ([Inc (rhs , (weights * _expr ).xreplace (idx_subs ),
431+ summands .extend ([Inc (rhs , (weights * expr ).xreplace (idx_subs ),
440432 implicit_dims = implicit_dims )])
441433
442434 # Write/Incr `self`
@@ -478,42 +470,44 @@ def _inject(self, field, expr, implicit_dims=None):
478470 # accesses. Variables are sampled at their own grid location; the
479471 # position map for the target field carries the staggering so the
480472 # field's stencil neighbors land on the right indices.
481- try :
482- _exprs = tuple (e ._eval_at (f ).evaluate
483- for e , f in zip (exprs , fields , strict = True ))
484- except AttributeError :
485- # E.g., a generic SymPy expression or a number
486- _exprs = exprs
487-
488- variables = list (v for e in _exprs for v in retrieve_function_carriers (e ))
489-
490- # Implicit dimensions
491- implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
492-
493- # All fields in a single injection share the same staggering by
494- # construction (they are written together at the same indices), so
495- # derive shifts from the first field.
496- shifts = self ._field_shifts (fields [0 ])
497-
498- # Move all temporaries inside inner loop to improve parallelism
499- # Can only be done for inject as interpolation needs a summing temp
500- # that wouldn't allow collapsing
501- implicit_dims = implicit_dims + tuple (r .parent for r in
502- self ._rdim (subdomain = subdomain ,
503- shifts = shifts ))
504-
505- # List of indirection indices for all adjacent grid points
506- idx_subs , temps = self ._interp_idx (list (fields ) + variables ,
507- implicit_dims = implicit_dims ,
508- subdomain = subdomain , shifts = shifts )
509-
510- eqns = [Inc (_field .xreplace (idx_subs ),
511- (self ._weights (subdomain = subdomain , shifts = shifts )
512- * _expr ).xreplace (idx_subs ),
513- implicit_dims = implicit_dims )
514- for _field , _expr in zip (fields , _exprs , strict = True )]
515-
516- return temps + eqns
473+ with suppress (AttributeError ):
474+ exprs = tuple (e ._eval_at (f ).evaluate
475+ for e , f in zip (exprs , fields , strict = True ))
476+
477+ eqns = []
478+ temps = []
479+ # We need to create one set of equations (temps and and coeffs) per staggering
480+ # field in which we inject as the reference index depends on the field's origin
481+ for _ , g in groupby (zip (fields , exprs , strict = True ), lambda f : f [0 ].staggered ):
482+ g_fields , g_exprs = zip (* g , strict = True )
483+ variables = list (v for e in g_exprs for v in retrieve_function_carriers (e ))
484+
485+ implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
486+
487+ # All fields in a single injection share the same staggering by
488+ # construction (they are written together at the same indices), so
489+ # derive shifts from the first field.
490+ shifts = self ._field_shifts (g_fields [0 ])
491+
492+ # Move all temporaries inside inner loop to improve parallelism
493+ # Can only be done for inject as interpolation needs a summing temp
494+ # that wouldn't allow collapsing
495+ implicit_dims = implicit_dims + tuple (r .parent for r in
496+ self ._rdim (subdomain = subdomain ,
497+ shifts = shifts ))
498+
499+ # List of indirection indices for all adjacent grid points
500+ idx_subs , _temps = self ._interp_idx (list (g_fields ) + variables ,
501+ implicit_dims = implicit_dims ,
502+ subdomain = subdomain , shifts = shifts )
503+
504+ w = self ._weights (subdomain = subdomain , shifts = shifts )
505+ temps .extend (_temps )
506+ eqns .extend ([Inc (f .xreplace (idx_subs ), (w * e ).xreplace (idx_subs ),
507+ implicit_dims = implicit_dims )
508+ for f , e in zip (g_fields , g_exprs , strict = True )])
509+
510+ return filter_ordered (temps ) + eqns
517511
518512
519513class LinearInterpolator (WeightedInterpolator ):
@@ -540,10 +534,13 @@ def _weights(self, subdomain=None, shifts=None):
540534 def _point_symbols (self , shifts = None ):
541535 """Symbol for coordinate value in each Dimension of the point."""
542536 dtype = self .sfunction .coordinates .dtype
543- suffix = self .sfunction ._shifts_suffix (shifts )
544- return DimensionTuple (* (Symbol (name = f'p{ d } { suffix } ' , dtype = dtype )
545- for d in self .grid .dimensions ),
546- getters = self .grid .dimensions )
537+ symbols = []
538+ for d in self .grid .dimensions :
539+ if shifts and shifts [self .grid .dimensions .index (d )] != 0 :
540+ symbols .append (Symbol (name = f'p{ d } _s1' , dtype = dtype ))
541+ else :
542+ symbols .append (Symbol (name = f'p{ d } ' , dtype = dtype ))
543+ return DimensionTuple (* symbols , getters = self .grid .dimensions )
547544
548545 def _coeff_temps (self , implicit_dims , shifts = None ):
549546 # Positions
0 commit comments