@@ -1065,27 +1065,30 @@ def from_xarray_dataset(cls, ds, variables, dimensions, mesh='spherical', allow_
10651065 return cls (u , v , fields = fields )
10661066
10671067 @classmethod
1068- def from_modulefile (cls , filename , ** kwargs ):
1068+ def from_modulefile (cls , filename , modulename = "create_fieldset" , ** kwargs ):
10691069 """Initialises FieldSet data from a file containing a python module file with a create_fieldset() function.
10701070
10711071 Parameters
10721072 ----------
1073- filename: path to a python file containing at least a create_fieldset() function,
1074- which returns a FieldSet object.
1073+ filename: path to a python file containing at least a function which returns a FieldSet object.
1074+ modulename: name of the function in the python file that returns a FieldSet object. Default is "create_fieldset" .
10751075 """
10761076 # check if filename exists
10771077 if not path .exists (filename ):
10781078 raise IOError (f"FieldSet module file { filename } does not exist" )
10791079
10801080 # Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly)
1081- spec = importlib .util .spec_from_file_location ("create_fieldset" , filename )
1081+ spec = importlib .util .spec_from_file_location (modulename , filename )
10821082 fieldset_module = importlib .util .module_from_spec (spec )
1083- sys .modules ['create_fieldset' ] = fieldset_module
1083+ sys .modules [modulename ] = fieldset_module
10841084 spec .loader .exec_module (fieldset_module )
10851085
1086- if not hasattr (fieldset_module , 'create_fieldset' ):
1087- raise IOError (f"FieldSet module { filename } does not contain a `create_fieldset` function" )
1088- return fieldset_module .create_fieldset (** kwargs )
1086+ if not hasattr (fieldset_module , modulename ):
1087+ raise IOError (f"{ filename } does not contain a { modulename } function" )
1088+ fieldset = getattr (fieldset_module , modulename )(** kwargs )
1089+ if not isinstance (fieldset , FieldSet ):
1090+ raise IOError (f"Module { filename } .{ modulename } does not return a FieldSet object" )
1091+ return fieldset
10891092
10901093 def get_fields (self ):
10911094 """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
0 commit comments