Skip to content

Commit 13abea2

Browse files
Implementing fieldset.UV.eval for SummedFields too
This fixes #1172 (comment)
1 parent ad1ef0c commit 13abea2

3 files changed

Lines changed: 53 additions & 7 deletions

File tree

parcels/compilation/codegenerator.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,29 @@ def __init__(self, fields, args, var):
118118

119119

120120
class SummedVectorFieldNode(IntrinsicNode):
121+
def __getattr__(self, attr):
122+
if attr == "eval":
123+
return SummedVectorFieldEvalCallNode(self)
124+
121125
def __getitem__(self, attr):
122126
return SummedVectorFieldEvalNode(self.obj, attr)
123127

124128

129+
class SummedVectorFieldEvalCallNode(IntrinsicNode):
130+
def __init__(self, field):
131+
self.field = field
132+
self.obj = field.obj
133+
self.ccode = ""
134+
135+
125136
class SummedVectorFieldEvalNode(IntrinsicNode):
126-
def __init__(self, fields, args, var, var2, var3):
137+
def __init__(self, fields, args, var, var2, var3, convert=True):
127138
self.fields = fields
128139
self.args = args
129140
self.var = var # the variable in which the interpolated field is written
130141
self.var2 = var2 # second variable for UV interpolation
131142
self.var3 = var3 # third variable for UVW interpolation
143+
self.convert = convert # whether to convert the result (like field.applyConversion)
132144

133145

134146
class NestedFieldNode(IntrinsicNode):
@@ -446,6 +458,28 @@ def visit_Call(self, node):
446458
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
447459
else:
448460
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2)], ast.Load())
461+
462+
elif isinstance(node.func, SummedVectorFieldEvalCallNode):
463+
# get a temporary value to assign result to
464+
tmp = [self.get_tmp() for _ in range(len(node.func.obj))]
465+
tmp2 = [self.get_tmp() for _ in range(len(node.func.obj))]
466+
tmp3 = [self.get_tmp() if list.__getitem__(node.func.obj, 0).vector_type == '3D' else None for _ in range(len(node.func.obj))]
467+
# whether to convert
468+
convert = True
469+
if "applyConversion" in node.keywords:
470+
k = node.keywords["applyConversion"]
471+
if isinstance(k, ast.NameConstant):
472+
convert = k.value
473+
474+
# convert args to Index(Tuple(*args))
475+
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))
476+
477+
self.stmt_stack += [SummedVectorFieldEvalNode(node.func.field, args, tmp, tmp2, tmp3, convert)]
478+
if all(tmp3):
479+
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2)), ast.Name(id='+'.join(tmp3))], ast.Load())
480+
else:
481+
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2))], ast.Load())
482+
449483
return node
450484

451485

@@ -975,14 +1009,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
9751009
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
9761010
ccode_eval = fld.ccode_eval_array(var, var2, var3,
9771011
fld.U, fld.V, fld.W, *args)
978-
if fld.U.interp_method != 'cgrid_velocity':
1012+
if node.convert and fld.U.interp_method != 'cgrid_velocity':
9791013
ccode_conv1 = fld.U.ccode_convert(*args)
9801014
ccode_conv2 = fld.V.ccode_convert(*args)
9811015
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
9821016
c.Statement("%s *= %s" % (var2, ccode_conv2))]
9831017
else:
9841018
statements = []
985-
if fld.vector_type == '3D':
1019+
if node.convert and fld.vector_type == '3D':
9861020
ccode_conv3 = fld.W.ccode_convert(*args)
9871021
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
9881022
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]
@@ -1125,14 +1159,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
11251159
args = self._check_FieldSamplingArguments(node.args.ccode)
11261160
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
11271161
ccode_eval = fld.ccode_eval_object(var, var2, var3, fld.U, fld.V, fld.W, *args)
1128-
if fld.U.interp_method != 'cgrid_velocity':
1162+
if node.convert and fld.U.interp_method != 'cgrid_velocity':
11291163
ccode_conv1 = fld.U.ccode_convert(*args)
11301164
ccode_conv2 = fld.V.ccode_convert(*args)
11311165
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
11321166
c.Statement("%s *= %s" % (var2, ccode_conv2))]
11331167
else:
11341168
statements = []
1135-
if fld.vector_type == '3D':
1169+
if node.convert and fld.vector_type == '3D':
11361170
ccode_conv3 = fld.W.ccode_convert(*args)
11371171
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
11381172
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]

parcels/field.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,14 @@ def __init__(self, name, F, V=None, W=None):
18871887
self.append(VectorField(name+'_%d' % i, Fi, Vi, Wi))
18881888
self.name = name
18891889

1890+
def eval(self, time, z, y, x, particle=None, applyConversion=True):
1891+
vals = []
1892+
val = None
1893+
for iField in range(len(self)):
1894+
val = list.__getitem__(self, iField).eval(time, z, y, x, applyConversion=applyConversion)
1895+
vals.append(val)
1896+
return tuple(np.sum(vals, 0)) if isinstance(val, tuple) else np.sum(vals)
1897+
18901898
def __getitem__(self, key):
18911899
if isinstance(key, int):
18921900
return list.__getitem__(self, key)

tests/test_fieldset_sampling.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,16 +814,20 @@ def test_summedfields(pset_mode, mode, with_W, k_sample_p, mesh):
814814
fieldsetS.add_field((P1+P4)+(P2+P3), name='P')
815815
assert np.allclose(fieldsetS.P[0, 0, 0, 0], 60)
816816

817+
def sample_UV_noconvert(particle, fieldset, time):
818+
(particle.u, particle.v) = fieldset.UV.eval(time, particle.depth, particle.lat, particle.lon, applyConversion=False) # noqa
819+
817820
if with_W:
818821
W1 = Field('W', 2*np.ones((zdim * gf, ydim * gf, xdim * gf), dtype=np.float32), grid=U1.grid)
819822
W2 = Field('W', np.ones((zdim, ydim, xdim), dtype=np.float32), grid=U2.grid)
820823
fieldsetS.add_field(W1+W2, name='W')
821824
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
822-
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p), runtime=2, dt=1)
825+
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
823826
assert np.isclose(pset.depth[0], 6)
824827
else:
825828
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
826-
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p), runtime=2, dt=1)
829+
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
830+
assert np.isclose(pset.u[0], 0.3)
827831
assert np.isclose(pset.p[0], 60)
828832
assert np.isclose(pset.lon[0]*conv, 0.6, atol=1e-3)
829833
assert np.isclose(pset.lat[0], 0.9)

0 commit comments

Comments
 (0)