@@ -166,18 +166,20 @@ class Injection(UnevaluatedSparseOperation):
166166
167167 __rargs__ = ('field' , 'expr' , 'implicit_dims' ) + UnevaluatedSparseOperation .__rargs__
168168
169- def __new__ (cls , field , expr , implicit_dims , interpolator ):
169+ def __new__ (cls , field , expr , increment , implicit_dims , interpolator ):
170170 obj = super ().__new__ (cls , interpolator )
171171
172172 # TODO: unused now, but will be necessary to compute the adjoint
173173 obj .field = field
174174 obj .expr = expr
175+ obj .increment = increment
175176 obj .implicit_dims = implicit_dims
176177
177178 return obj
178179
179180 def operation (self , ** kwargs ):
180181 return self .interpolator ._inject (expr = self .expr , field = self .field ,
182+ increment = self .increment ,
181183 implicit_dims = self .implicit_dims )
182184
183185 def __repr__ (self ):
@@ -372,7 +374,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
372374
373375 @check_radius
374376 @check_coords
375- def inject (self , field , expr , implicit_dims = None ):
377+ def inject (self , field , expr , increment = True , implicit_dims = None ):
376378 """
377379 Generate equations injecting an arbitrary expression into a field.
378380
@@ -387,7 +389,7 @@ def inject(self, field, expr, implicit_dims=None):
387389 injection expression, but that should be honored when constructing
388390 the operator.
389391 """
390- return Injection (field , expr , implicit_dims , self )
392+ return Injection (field , expr , increment , implicit_dims , self )
391393
392394 def _interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
393395 """
@@ -439,7 +441,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
439441
440442 return temps + summands + last
441443
442- def _inject (self , field , expr , implicit_dims = None ):
444+ def _inject (self , field , expr , increment = True , implicit_dims = None ):
443445 """
444446 Generate equations injecting an arbitrary expression into a field.
445447
@@ -489,9 +491,10 @@ def _inject(self, field, expr, implicit_dims=None):
489491 pos_only = variables , subdomain = subdomain )
490492
491493 # Substitute coordinate base symbols into the interpolation coefficients
492- eqns = [Inc (_field .xreplace (idx_subs ),
493- (self ._weights (subdomain = subdomain ) * _expr ).xreplace (idx_subs ),
494- implicit_dims = implicit_dims )
494+ ecls = Inc if increment else Eq
495+ eqns = [ecls (_field .xreplace (idx_subs ),
496+ (self ._weights (subdomain = subdomain ) * _expr ).xreplace (idx_subs ),
497+ implicit_dims = implicit_dims )
495498 for (_field , _expr ) in zip (fields , _exprs , strict = True )]
496499
497500 return temps + eqns
0 commit comments