11from abc import ABC , abstractmethod
22from functools import cached_property , wraps
3+ from itertools import groupby
34
45import numpy as np
56import sympy
@@ -257,7 +258,8 @@ def _field_shifts(self, field):
257258 if not staggered or all (s == 0 for s in staggered ):
258259 return None
259260 return tuple ((d .spacing / 2 ) if s else 0
260- for d , s in zip (self ._gdims , staggered , strict = True ))
261+ for d , s in zip (field .dimensions , staggered , strict = True )
262+ if d .is_Space )
261263
262264 @memoized_meth
263265 def _rdim (self , subdomain = None , shifts = None ):
@@ -485,35 +487,43 @@ def _inject(self, field, expr, implicit_dims=None):
485487 # E.g., a generic SymPy expression or a number
486488 _exprs = exprs
487489
488- variables = list (v for e in _exprs for v in retrieve_function_carriers (e ))
489-
490- # Implicit dimensions
491- implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
492-
493- # All fields in a single injection share the same staggering by
494- # construction (they are written together at the same indices), so
495- # derive shifts from the first field.
496- shifts = self ._field_shifts (fields [0 ])
497-
498- # Move all temporaries inside inner loop to improve parallelism
499- # Can only be done for inject as interpolation needs a summing temp
500- # that wouldn't allow collapsing
501- implicit_dims = implicit_dims + tuple (r .parent for r in
502- self ._rdim (subdomain = subdomain ,
503- shifts = shifts ))
504-
505- # List of indirection indices for all adjacent grid points
506- idx_subs , temps = self ._interp_idx (list (fields ) + variables ,
507- implicit_dims = implicit_dims ,
508- subdomain = subdomain , shifts = shifts )
509-
510- eqns = [Inc (_field .xreplace (idx_subs ),
511- (self ._weights (subdomain = subdomain , shifts = shifts )
512- * _expr ).xreplace (idx_subs ),
513- implicit_dims = implicit_dims )
514- for _field , _expr in zip (fields , _exprs , strict = True )]
515-
516- return temps + eqns
490+ eqns = []
491+ temps = []
492+ pairs = zip (fields , _exprs , strict = True )
493+ # We need to create one set of equations (temps and and coeffs) per staggering
494+ # field in which we inject as the reference index depends on the field's origin
495+ for _ , g in groupby (pairs , lambda f : f [0 ].staggered ):
496+ g = list (g )
497+ g_fields = [f for f , _ in g ]
498+ g_exprs = [e for _ , e in g ]
499+ variables = list (v for e in g_exprs for v in retrieve_function_carriers (e ))
500+
501+ # Implicit dimensions
502+ implicit_dims = self ._augment_implicit_dims (implicit_dims , variables )
503+
504+ # All fields in a single injection share the same staggering by
505+ # construction (they are written together at the same indices), so
506+ # derive shifts from the first field.
507+ shifts = self ._field_shifts (g_fields [0 ])
508+
509+ # Move all temporaries inside inner loop to improve parallelism
510+ # Can only be done for inject as interpolation needs a summing temp
511+ # that wouldn't allow collapsing
512+ implicit_dims = implicit_dims + tuple (r .parent for r in
513+ self ._rdim (subdomain = subdomain ,
514+ shifts = shifts ))
515+
516+ # List of indirection indices for all adjacent grid points
517+ idx_subs , _temps = self ._interp_idx (g_fields + variables ,
518+ implicit_dims = implicit_dims ,
519+ subdomain = subdomain , shifts = shifts )
520+ w = self ._weights (subdomain = subdomain , shifts = shifts )
521+ temps .extend (_temps )
522+ eqns .extend ([Inc (_field .xreplace (idx_subs ), (w * _expr ).xreplace (idx_subs ),
523+ implicit_dims = implicit_dims )
524+ for _field , _expr in zip (g_fields , g_exprs , strict = True )])
525+
526+ return filter_ordered (temps ) + eqns
517527
518528
519529class LinearInterpolator (WeightedInterpolator ):
@@ -540,10 +550,13 @@ def _weights(self, subdomain=None, shifts=None):
540550 def _point_symbols (self , shifts = None ):
541551 """Symbol for coordinate value in each Dimension of the point."""
542552 dtype = self .sfunction .coordinates .dtype
543- suffix = self .sfunction ._shifts_suffix (shifts )
544- return DimensionTuple (* (Symbol (name = f'p{ d } { suffix } ' , dtype = dtype )
545- for d in self .grid .dimensions ),
546- getters = self .grid .dimensions )
553+ symbols = []
554+ for d in self .grid .dimensions :
555+ if shifts and shifts [self .grid .dimensions .index (d )] != 0 :
556+ symbols .append (Symbol (name = f'p{ d } _s1' , dtype = dtype ))
557+ else :
558+ symbols .append (Symbol (name = f'p{ d } ' , dtype = dtype ))
559+ return DimensionTuple (* symbols , getters = self .grid .dimensions )
547560
548561 def _coeff_temps (self , implicit_dims , shifts = None ):
549562 # Positions
0 commit comments