Skip to content

Commit 7f66104

Browse files
Merge pull request #1200 from OceanParcels/summedvectorfield_eval_implementation
Implementing fieldset.UV.eval for SummedVectorFields too
2 parents 85a45b0 + ed36fdd commit 7f66104

3 files changed

Lines changed: 53 additions & 7 deletions

File tree

parcels/compilation/codegenerator.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,29 @@ def __init__(self, fields, args, var):
118118

119119

120120
class SummedVectorFieldNode(IntrinsicNode):
121+
def __getattr__(self, attr):
122+
if attr == "eval":
123+
return SummedVectorFieldEvalCallNode(self)
124+
121125
def __getitem__(self, attr):
122126
return SummedVectorFieldEvalNode(self.obj, attr)
123127

124128

129+
class SummedVectorFieldEvalCallNode(IntrinsicNode):
130+
def __init__(self, field):
131+
self.field = field
132+
self.obj = field.obj
133+
self.ccode = ""
134+
135+
125136
class SummedVectorFieldEvalNode(IntrinsicNode):
126-
def __init__(self, fields, args, var, var2, var3):
137+
def __init__(self, fields, args, var, var2, var3, convert=True):
127138
self.fields = fields
128139
self.args = args
129140
self.var = var # the variable in which the interpolated field is written
130141
self.var2 = var2 # second variable for UV interpolation
131142
self.var3 = var3 # third variable for UVW interpolation
143+
self.convert = convert # whether to convert the result (like field.applyConversion)
132144

133145

134146
class NestedFieldNode(IntrinsicNode):
@@ -450,6 +462,28 @@ def visit_Call(self, node):
450462
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
451463
else:
452464
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2)], ast.Load())
465+
466+
elif isinstance(node.func, SummedVectorFieldEvalCallNode):
467+
# get a temporary value to assign result to
468+
tmp = [self.get_tmp() for _ in range(len(node.func.obj))]
469+
tmp2 = [self.get_tmp() for _ in range(len(node.func.obj))]
470+
tmp3 = [self.get_tmp() if list.__getitem__(node.func.obj, 0).vector_type == '3D' else None for _ in range(len(node.func.obj))]
471+
# whether to convert
472+
convert = True
473+
if "applyConversion" in node.keywords:
474+
k = node.keywords["applyConversion"]
475+
if isinstance(k, ast.NameConstant):
476+
convert = k.value
477+
478+
# convert args to Index(Tuple(*args))
479+
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))
480+
481+
self.stmt_stack += [SummedVectorFieldEvalNode(node.func.field, args, tmp, tmp2, tmp3, convert)]
482+
if all(tmp3):
483+
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2)), ast.Name(id='+'.join(tmp3))], ast.Load())
484+
else:
485+
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2))], ast.Load())
486+
453487
return node
454488

455489

@@ -979,14 +1013,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
9791013
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
9801014
ccode_eval = fld.ccode_eval_array(var, var2, var3,
9811015
fld.U, fld.V, fld.W, *args)
982-
if fld.U.interp_method != 'cgrid_velocity':
1016+
if node.convert and fld.U.interp_method != 'cgrid_velocity':
9831017
ccode_conv1 = fld.U.ccode_convert(*args)
9841018
ccode_conv2 = fld.V.ccode_convert(*args)
9851019
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
9861020
c.Statement("%s *= %s" % (var2, ccode_conv2))]
9871021
else:
9881022
statements = []
989-
if fld.vector_type == '3D':
1023+
if node.convert and fld.vector_type == '3D':
9901024
ccode_conv3 = fld.W.ccode_convert(*args)
9911025
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
9921026
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]
@@ -1129,14 +1163,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
11291163
args = self._check_FieldSamplingArguments(node.args.ccode)
11301164
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
11311165
ccode_eval = fld.ccode_eval_object(var, var2, var3, fld.U, fld.V, fld.W, *args)
1132-
if fld.U.interp_method != 'cgrid_velocity':
1166+
if node.convert and fld.U.interp_method != 'cgrid_velocity':
11331167
ccode_conv1 = fld.U.ccode_convert(*args)
11341168
ccode_conv2 = fld.V.ccode_convert(*args)
11351169
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
11361170
c.Statement("%s *= %s" % (var2, ccode_conv2))]
11371171
else:
11381172
statements = []
1139-
if fld.vector_type == '3D':
1173+
if node.convert and fld.vector_type == '3D':
11401174
ccode_conv3 = fld.W.ccode_convert(*args)
11411175
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
11421176
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]

parcels/field.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,14 @@ def __init__(self, name, F, V=None, W=None):
18971897
self.append(VectorField(name+'_%d' % i, Fi, Vi, Wi))
18981898
self.name = name
18991899

1900+
def eval(self, time, z, y, x, particle=None, applyConversion=True):
1901+
vals = []
1902+
val = None
1903+
for iField in range(len(self)):
1904+
val = list.__getitem__(self, iField).eval(time, z, y, x, applyConversion=applyConversion)
1905+
vals.append(val)
1906+
return tuple(np.sum(vals, 0)) if isinstance(val, tuple) else np.sum(vals)
1907+
19001908
def __getitem__(self, key):
19011909
if isinstance(key, int):
19021910
return list.__getitem__(self, key)

tests/test_fieldset_sampling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,16 +814,20 @@ def test_summedfields(pset_mode, mode, with_W, k_sample_p, mesh):
814814
fieldsetS.add_field((P1+P4)+(P2+P3), name='P')
815815
assert np.allclose(fieldsetS.P[0, 0, 0, 0], 60)
816816

817+
def sample_UV_noconvert(particle, fieldset, time):
818+
(particle.u, particle.v) = fieldset.UV.eval(time, particle.depth, particle.lat, particle.lon, applyConversion=False) # noqa
819+
817820
if with_W:
818821
W1 = Field('W', 2*np.ones((zdim * gf, ydim * gf, xdim * gf), dtype=np.float32), grid=U1.grid)
819822
W2 = Field('W', np.ones((zdim, ydim, xdim), dtype=np.float32), grid=U2.grid)
820823
fieldsetS.add_field(W1+W2, name='W')
821824
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
822-
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p), runtime=2, dt=1)
825+
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
823826
assert np.isclose(pset.depth[0], 6)
824827
else:
825828
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
826-
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p), runtime=2, dt=1)
829+
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
830+
assert np.isclose(pset.u[0], 0.3)
827831
assert np.isclose(pset.p[0], 60)
828832
assert np.isclose(pset.lon[0]*conv, 0.6, atol=1e-3)
829833
assert np.isclose(pset.lat[0], 0.9)

0 commit comments

Comments
 (0)