Skip to content

Commit 133426f

Browse files
Generalising fieldset.from_modulefile to also allow custom modulenames
1 parent a5d5d02 commit 133426f

3 files changed

Lines changed: 20 additions & 11 deletions

File tree

parcels/fieldset.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,27 +1065,30 @@ def from_xarray_dataset(cls, ds, variables, dimensions, mesh='spherical', allow_
10651065
return cls(u, v, fields=fields)
10661066

10671067
@classmethod
1068-
def from_modulefile(cls, filename, **kwargs):
1068+
def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
10691069
"""Initialises FieldSet data from a file containing a python module file with a create_fieldset() function.
10701070
10711071
Parameters
10721072
----------
1073-
filename: path to a python file containing at least a create_fieldset() function,
1074-
which returns a FieldSet object.
1073+
filename: path to a python file containing at least a function which returns a FieldSet object.
1074+
modulename: name of the function in the python file that returns a FieldSet object. Default is "create_fieldset".
10751075
"""
10761076
# check if filename exists
10771077
if not path.exists(filename):
10781078
raise IOError(f"FieldSet module file {filename} does not exist")
10791079

10801080
# Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly)
1081-
spec = importlib.util.spec_from_file_location("create_fieldset", filename)
1081+
spec = importlib.util.spec_from_file_location(modulename, filename)
10821082
fieldset_module = importlib.util.module_from_spec(spec)
1083-
sys.modules['create_fieldset'] = fieldset_module
1083+
sys.modules[modulename] = fieldset_module
10841084
spec.loader.exec_module(fieldset_module)
10851085

1086-
if not hasattr(fieldset_module, 'create_fieldset'):
1087-
raise IOError(f"FieldSet module {filename} does not contain a `create_fieldset` function")
1088-
return fieldset_module.create_fieldset(**kwargs)
1086+
if not hasattr(fieldset_module, modulename):
1087+
raise IOError(f"{filename} does not contain a {modulename} function")
1088+
fieldset = getattr(fieldset_module, modulename)(**kwargs)
1089+
if not isinstance(fieldset, FieldSet):
1090+
raise IOError(f"Module {filename}.{modulename} does not return a FieldSet object")
1091+
return fieldset
10891092

10901093
def get_fields(self):
10911094
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`

tests/test_data/fieldset_nemo_error.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ def random_function_name():
1515
variables = {'U': 'U', 'V': 'V'}
1616
dimensions = {'lon': 'glamf', 'lat': 'gphif', 'time': 'time_counter'}
1717
return parcels.FieldSet.from_nemo(filenames, variables, dimensions)
18+
19+
20+
def none_returning_function():
21+
return None

tests/test_fieldset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,14 @@ def test_fieldset_from_modulefile():
209209
fieldset = FieldSet.from_modulefile(data_path + 'fieldset_nemo.py', indices=indices)
210210
assert fieldset.U.grid.lon.shape[1] == 4
211211

212-
213-
def test_fieldset_from_modulefile_error():
214-
data_path = path.join(path.dirname(__file__), 'test_data/')
215212
with pytest.raises(IOError):
216213
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py')
217214

215+
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='random_function_name')
216+
217+
with pytest.raises(IOError):
218+
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='none_returning_function')
219+
218220

219221
def test_field_from_netcdf_fieldtypes():
220222
data_path = path.join(path.dirname(__file__), 'test_data/')

0 commit comments

Comments
 (0)