@@ -164,19 +164,21 @@ class Injection(UnevaluatedSparseOperation):
164164
165165 __rargs__ = ('field' , 'expr' , 'implicit_dims' ) + UnevaluatedSparseOperation .__rargs__
166166
167- def __new__ (cls , field , expr , implicit_dims , interpolator ):
167+ def __new__ (cls , field , expr , implicit_dims , interpolator , interp_expr = False ):
168168 obj = super ().__new__ (cls , interpolator )
169169
170170 # TODO: unused now, but will be necessary to compute the adjoint
171171 obj .field = field
172172 obj .expr = expr
173173 obj .implicit_dims = implicit_dims
174+ obj .interp_expr = interp_expr
174175
175176 return obj
176177
177178 def operation (self , ** kwargs ):
178179 return self .interpolator ._inject (expr = self .expr , field = self .field ,
179- implicit_dims = self .implicit_dims )
180+ implicit_dims = self .implicit_dims ,
181+ interp_expr = self .interp_expr )
180182
181183 def __repr__ (self ):
182184 return f"Injection({ repr (self .expr )} into { repr (self .field )} )"
@@ -366,7 +368,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
366368
367369 @check_radius
368370 @check_coords
369- def inject (self , field , expr , implicit_dims = None ):
371+ def inject (self , field , expr , implicit_dims = None , interp_expr = False ):
370372 """
371373 Generate equations injecting an arbitrary expression into a field.
372374
@@ -381,7 +383,7 @@ def inject(self, field, expr, implicit_dims=None):
381383 injection expression, but that should be honored when constructing
382384 the operator.
383385 """
384- return Injection (field , expr , implicit_dims , self )
386+ return Injection (field , expr , implicit_dims , self , interp_expr = interp_expr )
385387
386388 def _interpolate (self , expr , increment = False , self_subs = None , implicit_dims = None ):
387389 """
@@ -433,7 +435,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
433435
434436 return temps + summands + last
435437
436- def _inject (self , field , expr , implicit_dims = None ):
438+ def _inject (self , field , expr , implicit_dims = None , interp_expr = False ):
437439 """
438440 Generate equations injecting an arbitrary expression into a field.
439441
@@ -479,8 +481,10 @@ def _inject(self, field, expr, implicit_dims=None):
479481 self ._rdim (subdomain = subdomain ))
480482
481483 # List of indirection indices for all adjacent grid points
482- idx_subs , temps = self ._interp_idx (fields , implicit_dims = implicit_dims ,
483- pos_only = variables , subdomain = subdomain )
484+ finterp = fields + as_tuple (variables ) if interp_expr else fields
485+ pos_only = () if interp_expr else variables
486+ idx_subs , temps = self ._interp_idx (finterp , implicit_dims = implicit_dims ,
487+ pos_only = pos_only , subdomain = subdomain )
484488
485489 # Substitute coordinate base symbols into the interpolation coefficients
486490 eqns = [Inc (_field .xreplace (idx_subs ),
0 commit comments