Skip to content

Commit 3f4168c

Browse files
Merge pull request #1566 from OceanParcels/fieldset_from_directory
Creating a new `Fieldset.from_modulefile()` method
2 parents c2bf42c + 846520b commit 3f4168c

4 files changed

Lines changed: 85 additions & 0 deletions

File tree

parcels/fieldset.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import importlib.util
2+
import sys
13
from copy import deepcopy
24
from glob import glob
35
from os import path
@@ -1062,6 +1064,32 @@ def from_xarray_dataset(cls, ds, variables, dimensions, mesh='spherical', allow_
10621064
v = fields.pop('V', None)
10631065
return cls(u, v, fields=fields)
10641066

1067+
@classmethod
1068+
def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
1069+
"""Initialises FieldSet data from a file containing a python module file with a create_fieldset() function.
1070+
1071+
Parameters
1072+
----------
1073+
filename: path to a python file containing at least a function which returns a FieldSet object.
1074+
modulename: name of the function in the python file that returns a FieldSet object. Default is "create_fieldset".
1075+
"""
1076+
# check if filename exists
1077+
if not path.exists(filename):
1078+
raise IOError(f"FieldSet module file {filename} does not exist")
1079+
1080+
# Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly)
1081+
spec = importlib.util.spec_from_file_location(modulename, filename)
1082+
fieldset_module = importlib.util.module_from_spec(spec)
1083+
sys.modules[modulename] = fieldset_module
1084+
spec.loader.exec_module(fieldset_module)
1085+
1086+
if not hasattr(fieldset_module, modulename):
1087+
raise IOError(f"{filename} does not contain a {modulename} function")
1088+
fieldset = getattr(fieldset_module, modulename)(**kwargs)
1089+
if not isinstance(fieldset, FieldSet):
1090+
raise IOError(f"Module {filename}.{modulename} does not return a FieldSet object")
1091+
return fieldset
1092+
10651093
def get_fields(self):
10661094
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
10671095
objects associated with this FieldSet.

tests/test_data/fieldset_nemo.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from os import path
2+
3+
import parcels
4+
5+
6+
def create_fieldset(indices=None):
7+
data_path = path.join(path.dirname(__file__))
8+
9+
filenames = {'U': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
10+
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
11+
'data': path.join(data_path, 'Uu_eastward_nemo_cross_180lon.nc')},
12+
'V': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
13+
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
14+
'data': path.join(data_path, 'Vv_eastward_nemo_cross_180lon.nc')}}
15+
variables = {'U': 'U', 'V': 'V'}
16+
dimensions = {'lon': 'glamf', 'lat': 'gphif', 'time': 'time_counter'}
17+
indices = indices or {}
18+
return parcels.FieldSet.from_nemo(filenames, variables, dimensions, indices=indices)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from os import path
2+
3+
import parcels
4+
5+
6+
def random_function_name():
7+
data_path = path.join(path.dirname(__file__))
8+
9+
filenames = {'U': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
10+
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
11+
'data': path.join(data_path, 'Uu_eastward_nemo_cross_180lon.nc')},
12+
'V': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
13+
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
14+
'data': path.join(data_path, 'Vv_eastward_nemo_cross_180lon.nc')}}
15+
variables = {'U': 'U', 'V': 'V'}
16+
dimensions = {'lon': 'glamf', 'lat': 'gphif', 'time': 'time_counter'}
17+
return parcels.FieldSet.from_nemo(filenames, variables, dimensions)
18+
19+
20+
def none_returning_function():
21+
return None

tests/test_fieldset.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,24 @@ def test_field_from_netcdf(with_timestamps):
200200
Field.from_netcdf(filenames, variable, dimensions, interp_method='cgrid_velocity')
201201

202202

203+
def test_fieldset_from_modulefile():
204+
data_path = path.join(path.dirname(__file__), 'test_data/')
205+
fieldset = FieldSet.from_modulefile(data_path + 'fieldset_nemo.py')
206+
assert fieldset.U.creation_log == 'from_nemo'
207+
208+
indices = {'lon': range(6, 10)}
209+
fieldset = FieldSet.from_modulefile(data_path + 'fieldset_nemo.py', indices=indices)
210+
assert fieldset.U.grid.lon.shape[1] == 4
211+
212+
with pytest.raises(IOError):
213+
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py')
214+
215+
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='random_function_name')
216+
217+
with pytest.raises(IOError):
218+
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='none_returning_function')
219+
220+
203221
def test_field_from_netcdf_fieldtypes():
204222
data_path = path.join(path.dirname(__file__), 'test_data/')
205223

0 commit comments

Comments
 (0)