Skip to content

Commit 9a5c0ee

Browse files
Merge pull request #938 from OceanParcels/bugfix_add_vectorfield
Fixing bug when FieldSet.add_vector_field() called without adding individual Fields
2 parents b1773bd + 70e0cc5 commit 9a5c0ee

4 files changed

Lines changed: 61 additions & 4 deletions

File tree

parcels/examples/tutorial_NestedFields.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
"metadata": {},
180180
"outputs": [],
181181
"source": [
182+
"fieldset = FieldSet(U, V) # Need to redefine fieldset because FieldSets need to be constructed before ParticleSets\n",
182183
"F1 = Field('F1', np.ones((U1.grid.ydim, U1.grid.xdim), dtype=np.float32), grid=U1.grid)\n",
183184
"F2 = Field('F2', 2*np.ones((U2.grid.ydim, U2.grid.xdim), dtype=np.float32), grid=U2.grid)\n",
184185
"F = NestedField('F', [F1, F2])\n",

parcels/fieldset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class FieldSet(object):
3434
"""
3535
def __init__(self, U, V, fields=None):
3636
self.gridset = GridSet()
37+
self.completed = False
3738
if U:
3839
self.add_field(U, 'U')
3940
self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin
@@ -140,6 +141,8 @@ def add_field(self, field, name=None):
140141
* `Unit converters <https://nbviewer.jupyter.org/github/OceanParcels/parcels/blob/master/parcels/examples/tutorial_unitconverters.ipynb>`_
141142
142143
"""
144+
if self.completed:
145+
raise RuntimeError("FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?")
143146
name = field.name if name is None else name
144147
if hasattr(self, name): # check if Field with same name already exists when adding new Field
145148
raise RuntimeError("FieldSet already has a Field with name '%s'" % name)
@@ -229,6 +232,9 @@ def add_vector_field(self, vfield):
229232
:param vfield: :class:`parcels.field.VectorField` object to be added
230233
"""
231234
setattr(self, vfield.name, vfield)
235+
for v in vfield.__dict__.values():
236+
if isinstance(v, Field) and (v not in self.get_fields()):
237+
self.add_field(v)
232238
vfield.fieldset = self
233239
if isinstance(vfield, NestedField):
234240
for f in vfield:
@@ -314,6 +320,7 @@ def check_velocityfields(U, V, W):
314320
if not f.grid.defer_load:
315321
depth_data = f.grid.depth_field.data
316322
f.grid.depth = depth_data if isinstance(depth_data, np.ndarray) else np.array(depth_data)
323+
self.completed = True
317324

318325
@classmethod
319326
def parse_wildcards(cls, paths, filenames, var):

tests/test_fieldset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ def test_add_duplicate_field(dupobject):
253253
assert error_thrown
254254

255255

256+
@pytest.mark.parametrize('fieldtype', ['normal', 'vector'])
257+
def test_add_field_after_pset(fieldtype):
258+
data, dimensions = generate_fieldset(100, 100)
259+
fieldset = FieldSet.from_data(data, dimensions)
260+
pset = ParticleSet(fieldset, ScipyParticle, lon=0, lat=0) # noqa ; to trigger fieldset.check_complete
261+
field1 = Field('field1', fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
262+
field2 = Field('field2', fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
263+
vfield = VectorField('vfield', field1, field2)
264+
error_thrown = False
265+
try:
266+
if fieldtype == 'normal':
267+
fieldset.add_field(field1)
268+
elif fieldtype == 'vector':
269+
fieldset.add_vector_field(vfield)
270+
except RuntimeError:
271+
error_thrown = True
272+
273+
assert error_thrown
274+
275+
256276
def test_fieldset_samegrids_from_file(tmpdir, filename='test_subsets'):
257277
""" Test for subsetting fieldset from file using indices dict. """
258278
data, dimensions = generate_fieldset(100, 100)
@@ -463,6 +483,36 @@ def test_vector_fields(mode, swapUV):
463483
assert abs(pset.lat[0] - .5) < 1e-9
464484

465485

486+
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
487+
def test_add_second_vector_field(mode):
488+
lon = np.linspace(0., 10., 12, dtype=np.float32)
489+
lat = np.linspace(0., 10., 10, dtype=np.float32)
490+
U = np.ones((10, 12), dtype=np.float32)
491+
V = np.zeros((10, 12), dtype=np.float32)
492+
data = {'U': U, 'V': V}
493+
dimensions = {'U': {'lat': lat, 'lon': lon},
494+
'V': {'lat': lat, 'lon': lon}}
495+
fieldset = FieldSet.from_data(data, dimensions, mesh='flat')
496+
497+
data2 = {'U2': U, 'V2': V}
498+
dimensions2 = {'lon': [ln + 0.1 for ln in lon], 'lat': [lt - 0.1 for lt in lat]}
499+
fieldset2 = FieldSet.from_data(data2, dimensions2, mesh='flat')
500+
501+
UV2 = VectorField('UV2', fieldset2.U2, fieldset2.V2)
502+
fieldset.add_vector_field(UV2)
503+
504+
def SampleUV2(particle, fieldset, time):
505+
u, v = fieldset.UV2[time, particle.depth, particle.lat, particle.lon]
506+
particle.lon += u * particle.dt
507+
particle.lat += v * particle.dt
508+
509+
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0.5, lat=0.5)
510+
pset.execute(AdvectionRK4+pset.Kernel(SampleUV2), dt=1, runtime=1)
511+
512+
assert abs(pset.lon[0] - 2.5) < 1e-9
513+
assert abs(pset.lat[0] - .5) < 1e-9
514+
515+
466516
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
467517
@pytest.mark.parametrize('time_periodic', [4*86400.0, False])
468518
@pytest.mark.parametrize('field_chunksize', [False, 'auto', (1, 32, 32)])

tests/test_fieldset_sampling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,17 +603,16 @@ def test_multiple_grid_addlater_error():
603603
lat=np.linspace(0., 1., ydim, dtype=np.float32))
604604
fieldset = FieldSet(U, V)
605605

606-
pset = ParticleSet(fieldset, pclass=pclass('jit'), lon=[0.8], lat=[0.9])
606+
pset = ParticleSet(fieldset, pclass=pclass('jit'), lon=[0.8], lat=[0.9]) # noqa ; to trigger fieldset.check_complete
607607

608608
P = Field('P', np.zeros((ydim*10, xdim*10), dtype=np.float32),
609609
lon=np.linspace(0., 1., xdim*10, dtype=np.float32),
610610
lat=np.linspace(0., 1., ydim*10, dtype=np.float32))
611-
fieldset.add_field(P)
612611

613612
fail = False
614613
try:
615-
pset.execute(AdvectionRK4, runtime=10, dt=1)
616-
except:
614+
fieldset.add_field(P)
615+
except RuntimeError:
617616
fail = True
618617
assert fail
619618

0 commit comments

Comments
 (0)