Skip to content

Commit dd3ba5b

Browse files
Merge pull request #1175 from OceanParcels/fieldset_velocitysampling_fix
Fixing sampling of fieldset.UV
2 parents 2906a9a + aa33131 commit dd3ba5b

16 files changed

Lines changed: 253 additions & 120 deletions

parcels/application_kernels/advectiondiffusion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@ def AdvectionDiffusionM1(particle, fieldset, time):
3333
Kxm1 = fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon - fieldset.dres]
3434
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
3535

36-
u = fieldset.U[time, particle.depth, particle.lat, particle.lon]
36+
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
3737
bx = math.sqrt(2 * fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon])
3838

3939
Kyp1 = fieldset.Kh_meridional[time, particle.depth, particle.lat + fieldset.dres, particle.lon]
4040
Kym1 = fieldset.Kh_meridional[time, particle.depth, particle.lat - fieldset.dres, particle.lon]
4141
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
4242

43-
v = fieldset.V[time, particle.depth, particle.lat, particle.lon]
4443
by = math.sqrt(2 * fieldset.Kh_meridional[time, particle.depth, particle.lat, particle.lon])
4544

4645
# Particle positions are updated only after evaluating all terms.
@@ -66,16 +65,18 @@ def AdvectionDiffusionEM(particle, fieldset, time):
6665
dWx = ParcelsRandom.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
6766
dWy = ParcelsRandom.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
6867

68+
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
69+
6970
Kxp1 = fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon + fieldset.dres]
7071
Kxm1 = fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon - fieldset.dres]
7172
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
72-
ax = fieldset.U[time, particle.depth, particle.lat, particle.lon] + dKdx
73+
ax = u + dKdx
7374
bx = math.sqrt(2 * fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon])
7475

7576
Kyp1 = fieldset.Kh_meridional[time, particle.depth, particle.lat + fieldset.dres, particle.lon]
7677
Kym1 = fieldset.Kh_meridional[time, particle.depth, particle.lat - fieldset.dres, particle.lon]
7778
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
78-
ay = fieldset.V[time, particle.depth, particle.lat, particle.lon] + dKdy
79+
ay = v + dKdy
7980
by = math.sqrt(2 * fieldset.Kh_meridional[time, particle.depth, particle.lat, particle.lon])
8081

8182
# Particle positions are updated only after evaluating all terms.

parcels/compilation/codegenerator.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,31 @@ def __init__(self, field, args, var, convert=True):
7878

7979

8080
class VectorFieldNode(IntrinsicNode):
81+
def __getattr__(self, attr):
82+
if attr == "eval":
83+
return VectorFieldEvalCallNode(self)
84+
else:
85+
raise NotImplementedError('Access to VectorField attributes are not (yet) implemented in JIT mode')
86+
8187
def __getitem__(self, attr):
8288
return VectorFieldEvalNode(self.obj, attr)
8389

8490

91+
class VectorFieldEvalCallNode(IntrinsicNode):
92+
def __init__(self, field):
93+
self.field = field
94+
self.obj = field.obj
95+
self.ccode = ""
96+
97+
8598
class VectorFieldEvalNode(IntrinsicNode):
86-
def __init__(self, field, args, var, var2, var3):
99+
def __init__(self, field, args, var, var2, var3, convert=True):
87100
self.field = field
88101
self.args = args
89102
self.var = var # the variable in which the interpolated field is written
90103
self.var2 = var2 # second variable for UV interpolation
91104
self.var3 = var3 # third variable for UVW interpolation
105+
self.convert = convert # whether to convert the result (like field.applyConversion)
92106

93107

94108
class SummedFieldNode(IntrinsicNode):
@@ -412,6 +426,26 @@ def visit_Call(self, node):
412426
self.stmt_stack += [FieldEvalNode(node.func.field, args, tmp, convert)]
413427
return ast.Name(id=tmp)
414428

429+
elif isinstance(node.func, VectorFieldEvalCallNode):
430+
# get a temporary value to assign result to
431+
tmp1 = self.get_tmp()
432+
tmp2 = self.get_tmp()
433+
tmp3 = self.get_tmp() if node.func.field.obj.vector_type == '3D' else None
434+
# whether to convert
435+
convert = True
436+
if "applyConversion" in node.keywords:
437+
k = node.keywords["applyConversion"]
438+
if isinstance(k, ast.NameConstant):
439+
convert = k.value
440+
441+
# convert args to Index(Tuple(*args))
442+
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))
443+
444+
self.stmt_stack += [VectorFieldEvalNode(node.func.field, args, tmp1, tmp2, tmp3, convert)]
445+
if tmp3:
446+
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
447+
else:
448+
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2)], ast.Load())
415449
return node
416450

417451

@@ -907,14 +941,14 @@ def visit_VectorFieldEvalNode(self, node):
907941
args = self._check_FieldSamplingArguments(node.args.ccode)
908942
ccode_eval = node.field.obj.ccode_eval_array(node.var, node.var2, node.var3,
909943
node.field.obj.U, node.field.obj.V, node.field.obj.W, *args)
910-
if node.field.obj.U.interp_method != 'cgrid_velocity':
944+
if node.convert and node.field.obj.U.interp_method != 'cgrid_velocity':
911945
ccode_conv1 = node.field.obj.U.ccode_convert(*args)
912946
ccode_conv2 = node.field.obj.V.ccode_convert(*args)
913947
statements = [c.Statement("%s *= %s" % (node.var, ccode_conv1)),
914948
c.Statement("%s *= %s" % (node.var2, ccode_conv2))]
915949
else:
916950
statements = []
917-
if node.field.obj.vector_type == '3D':
951+
if node.convert and node.field.obj.vector_type == '3D':
918952
ccode_conv3 = node.field.obj.W.ccode_convert(*args)
919953
statements.append(c.Statement("%s *= %s" % (node.var3, ccode_conv3)))
920954
conv_stat = c.Block(statements)
@@ -1058,14 +1092,14 @@ def visit_VectorFieldEvalNode(self, node):
10581092
self.visit(node.args)
10591093
args = self._check_FieldSamplingArguments(node.args.ccode)
10601094
ccode_eval = node.field.obj.ccode_eval_object(node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args)
1061-
if node.field.obj.U.interp_method != 'cgrid_velocity':
1095+
if node.convert and node.field.obj.U.interp_method != 'cgrid_velocity':
10621096
ccode_conv1 = node.field.obj.U.ccode_convert(*args)
10631097
ccode_conv2 = node.field.obj.V.ccode_convert(*args)
10641098
statements = [c.Statement("%s *= %s" % (node.var, ccode_conv1)),
10651099
c.Statement("%s *= %s" % (node.var2, ccode_conv2))]
10661100
else:
10671101
statements = []
1068-
if node.field.obj.vector_type == '3D':
1102+
if node.convert and node.field.obj.vector_type == '3D':
10691103
ccode_conv3 = node.field.obj.W.ccode_convert(*args)
10701104
statements.append(c.Statement("%s *= %s" % (node.var3, ccode_conv3)))
10711105
conv_stat = c.Block(statements)

parcels/examples/example_globcurrent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ class MyParticle(ptype[mode]):
105105
pset = ParticleSet(fieldset, pclass=MyParticle, lon=25, lat=-35, time=fieldset.U.grid.time[0])
106106

107107
def SampleU(particle, fieldset, time):
108-
particle.sample_var += fieldset.U[time, particle.depth, particle.lat, particle.lon]
108+
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
109+
particle.sample_var += u
109110

110111
pset.execute(SampleU, runtime=delta(days=rundays), dt=delta(days=1))
111112
sample_var.append(pset[0].sample_var)

parcels/examples/tutorial_diffusion.ipynb

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,16 @@
400400
"def smagdiff(particle, fieldset, time):\n",
401401
" dx = 0.01\n",
402402
" # gradients are computed by using a local central difference.\n",
403-
" dudx = (fieldset.U[time, particle.depth, particle.lat, particle.lon+dx]-fieldset.U[time, particle.depth, particle.lat, particle.lon-dx]) / (2*dx)\n",
404-
" dudy = (fieldset.U[time, particle.depth, particle.lat+dx, particle.lon]-fieldset.U[time, particle.depth, particle.lat-dx, particle.lon]) / (2*dx)\n",
405-
" dvdx = (fieldset.V[time, particle.depth, particle.lat, particle.lon+dx]-fieldset.V[time, particle.depth, particle.lat, particle.lon-dx]) / (2*dx)\n",
406-
" dvdy = (fieldset.V[time, particle.depth, particle.lat+dx, particle.lon]-fieldset.V[time, particle.depth, particle.lat-dx, particle.lon]) / (2*dx)\n",
403+
" updx, vpdx = fieldset.UV[time, particle.depth, particle.lat, particle.lon+dx]\n",
404+
" umdx, vmdx = fieldset.UV[time, particle.depth, particle.lat, particle.lon-dx]\n",
405+
" updy, vpdy = fieldset.UV[time, particle.depth, particle.lat+dx, particle.lon]\n",
406+
" umdy, vmdy = fieldset.UV[time, particle.depth, particle.lat-dx, particle.lon]\n",
407+
"\n",
408+
" dudx = (updx - umdx) / (2*dx)\n",
409+
" dudy = (updy - umdy) / (2*dx)\n",
410+
" \n",
411+
" dvdx = (vpdx - vmdx) / (2*dx)\n",
412+
" dvdy = (vpdy - vmdy) / (2*dx)\n",
407413
"\n",
408414
" A = fieldset.cell_areas[time, 0, particle.lat, particle.lon]\n",
409415
" sq_deg_to_sq_m = (1852*60)**2*math.cos(particle.lat*math.pi/180)\n",
@@ -631,7 +637,7 @@
631637
],
632638
"metadata": {
633639
"kernelspec": {
634-
"display_name": "Python 3",
640+
"display_name": "Python 3 (ipykernel)",
635641
"language": "python",
636642
"name": "python3"
637643
},
@@ -645,7 +651,7 @@
645651
"name": "python",
646652
"nbconvert_exporter": "python",
647653
"pygments_lexer": "ipython3",
648-
"version": "3.8.8"
654+
"version": "3.8.13"
649655
}
650656
},
651657
"nbformat": 4,

0 commit comments

Comments
 (0)