Skip to content

Commit a2dec03

Browse files
committed
Merge branch 'master' into new_xarray
2 parents e4f2605 + f97623b commit a2dec03

4 files changed

Lines changed: 24 additions & 16 deletions

File tree

parcels/field.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(self, name, data, lon=None, lat=None, depth=None, time=None, grid=N
119119
self.dimensions = kwargs.pop('dimensions', None)
120120
self.indices = kwargs.pop('indices', None)
121121
self.dataFiles = kwargs.pop('dataFiles', None)
122+
self.netcdf_engine = kwargs.pop('netcdf_engine', 'netcdf4')
122123
self.loaded_time_indices = []
123124

124125
@classmethod
@@ -151,6 +152,8 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
151152
It is advised not to fully load the data, since in that case Parcels deals with
152153
a better memory management during particle set execution.
153154
full_load is however sometimes necessary for plotting the fields.
155+
:param netcdf_engine: engine to use for netcdf reading in xarray. Default is 'netcdf',
156+
but in cases where this doesn't work, setting netcdf_engine='scipy' could help
154157
"""
155158

156159
if not isinstance(filenames, Iterable) or isinstance(filenames, str):
@@ -174,15 +177,16 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
174177
depth_filename = filenames[0]
175178

176179
indices = {} if indices is None else indices.copy()
177-
with NetcdfFileBuffer(lonlat_filename, dimensions, indices) as filebuffer:
180+
netcdf_engine = kwargs.pop('netcdf_engine', 'netcdf4')
181+
with NetcdfFileBuffer(lonlat_filename, dimensions, indices, netcdf_engine) as filebuffer:
178182
lon, lat = filebuffer.read_lonlat
179183
indices = filebuffer.indices
180184
# Check if parcels_mesh has been explicitly set in file
181185
if 'parcels_mesh' in filebuffer.dataset.attrs:
182186
mesh = filebuffer.dataset.attrs['parcels_mesh']
183187

184188
if 'depth' in dimensions:
185-
with NetcdfFileBuffer(depth_filename, dimensions, indices) as filebuffer:
189+
with NetcdfFileBuffer(depth_filename, dimensions, indices, netcdf_engine) as filebuffer:
186190
depth = filebuffer.read_depth
187191
else:
188192
indices['depth'] = [0]
@@ -197,7 +201,7 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
197201
timeslices = []
198202
dataFiles = []
199203
for fname in data_filenames:
200-
with NetcdfFileBuffer(fname, dimensions, indices) as filebuffer:
204+
with NetcdfFileBuffer(fname, dimensions, indices, netcdf_engine) as filebuffer:
201205
ftime = filebuffer.time
202206
timeslices.append(ftime)
203207
dataFiles.append([fname] * len(ftime))
@@ -231,7 +235,7 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
231235
data = np.empty((grid.tdim, grid.zdim, grid.ydim, grid.xdim), dtype=np.float32)
232236
ti = 0
233237
for tslice, fname in zip(grid.timeslices, data_filenames):
234-
with NetcdfFileBuffer(fname, dimensions, indices) as filebuffer:
238+
with NetcdfFileBuffer(fname, dimensions, indices, netcdf_engine) as filebuffer:
235239
# If Field.from_netcdf is called directly, it may not have a 'data' dimension
236240
# In that case, assume that 'name' is the data dimension
237241
filebuffer.name = filebuffer.parse_name(dimensions, variable)
@@ -257,6 +261,7 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
257261
kwargs['dimensions'] = dimensions.copy()
258262
kwargs['indices'] = indices
259263
kwargs['time_periodic'] = time_periodic
264+
kwargs['netcdf_engine'] = netcdf_engine
260265

261266
return cls(variable, data, grid=grid,
262267
allow_time_extrapolation=allow_time_extrapolation, **kwargs)
@@ -858,7 +863,7 @@ def advancetime(self, field_new, advanceForward):
858863

859864
def computeTimeChunk(self, data, tindex):
860865
g = self.grid
861-
with NetcdfFileBuffer(self.dataFiles[g.ti+tindex], self.dimensions, self.indices) as filebuffer:
866+
with NetcdfFileBuffer(self.dataFiles[g.ti+tindex], self.dimensions, self.indices, self.netcdf_engine) as filebuffer:
862867
filebuffer.name = filebuffer.parse_name(self.dimensions, self.name)
863868
time_data = filebuffer.time
864869
time_data = g.time_origin.reltime(time_data)
@@ -1130,20 +1135,21 @@ def __getitem__(self, key):
11301135
class NetcdfFileBuffer(object):
11311136
""" Class that encapsulates and manages deferred access to file data. """
11321137

1133-
def __init__(self, filename, dimensions, indices):
1138+
def __init__(self, filename, dimensions, indices, netcdf_engine):
11341139
self.filename = filename
11351140
self.dimensions = dimensions # Dict with dimension keyes for file data
11361141
self.indices = indices
11371142
self.dataset = None
1143+
self.netcdf_engine = netcdf_engine
11381144

11391145
def __enter__(self):
11401146
try:
1141-
self.dataset = xr.open_dataset(str(self.filename), decode_cf=True)
1147+
self.dataset = xr.open_dataset(str(self.filename), decode_cf=True, engine=self.netcdf_engine)
11421148
self.dataset['decoded'] = True
11431149
except:
11441150
logger.warning_once("File %s could not be decoded properly by xarray (version %s).\n It will be opened with no decoding. Filling values might be wrongly parsed."
11451151
% (self.filename, xr.__version__))
1146-
self.dataset = xr.open_dataset(str(self.filename), decode_cf=False)
1152+
self.dataset = xr.open_dataset(str(self.filename), decode_cf=False, engine=self.netcdf_engine)
11471153
self.dataset['decoded'] = False
11481154
for inds in self.indices.values():
11491155
if type(inds) not in [list, range]:
@@ -1166,8 +1172,8 @@ def parse_name(self, dimensions, variable):
11661172

11671173
@property
11681174
def read_lonlat(self):
1169-
lon = getattr(self.dataset, self.dimensions['lon'])
1170-
lat = getattr(self.dataset, self.dimensions['lat'])
1175+
lon = self.dataset[self.dimensions['lon']]
1176+
lat = self.dataset[self.dimensions['lat']]
11711177
xdim = lon.size if len(lon.shape) == 1 else lon.shape[-1]
11721178
ydim = lat.size if len(lat.shape) == 1 else lat.shape[-2]
11731179
self.indices['lon'] = self.indices['lon'] if 'lon' in self.indices else range(xdim)
@@ -1195,7 +1201,7 @@ def read_lonlat(self):
11951201
@property
11961202
def read_depth(self):
11971203
if 'depth' in self.dimensions:
1198-
depth = getattr(self.dataset, self.dimensions['depth'])
1204+
depth = self.dataset[self.dimensions['depth']]
11991205
depthsize = depth.size if len(depth.shape) == 1 else depth.shape[-3]
12001206
self.indices['depth'] = self.indices['depth'] if 'depth' in self.indices else range(depthsize)
12011207
if len(depth.shape) == 1:
@@ -1211,7 +1217,7 @@ def read_depth(self):
12111217

12121218
@property
12131219
def data(self):
1214-
data = getattr(self.dataset, self.name)
1220+
data = self.dataset[self.name]
12151221
if len(data.shape) == 2:
12161222
data = data[self.indices['lat'], self.indices['lon']]
12171223
elif len(data.shape) == 3:
@@ -1229,15 +1235,15 @@ def data(self):
12291235
@property
12301236
def time(self):
12311237
try:
1232-
time_da = getattr(self.dataset, self.dimensions['time'])
1238+
time_da = self.dataset[self.dimensions['time']]
12331239
if self.dataset['decoded'] and 'Unit' not in time_da.attrs:
12341240
time = np.array([time_da]) if len(time_da.shape) == 0 else np.array(time_da)
12351241
else:
12361242
if 'units' not in time_da.attrs and 'Unit' in time_da.attrs:
12371243
time_da.attrs['units'] = time_da.attrs['Unit']
12381244
ds = xr.Dataset({self.dimensions['time']: time_da})
12391245
ds = xr.decode_cf(ds)
1240-
da = getattr(ds, self.dimensions['time'])
1246+
da = ds[self.dimensions['time']]
12411247
time = np.array([da]) if len(da.shape) == 0 else np.array(da)
12421248
if isinstance(time[0], datetime.datetime):
12431249
raise NotImplementedError('Parcels currently only parses dates ranging from 1678 AD to 2262 AD, which are stored by xarray as np.datetime64. If you need a wider date range, please open an Issue on the parcels github page.')

parcels/fieldset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def from_netcdf(cls, filenames, variables, dimensions, indices=None,
189189
It is advised not to fully load the data, since in that case Parcels deals with
190190
a better memory management during particle set execution.
191191
full_load is however sometimes necessary for plotting the fields.
192+
:param netcdf_engine: engine to use for netcdf reading in xarray. Default is 'netcdf',
193+
but in cases where this doesn't work, setting netcdf_engine='scipy' could help
192194
"""
193195

194196
fields = {}

parcels/scripts/plottrajectoriesfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def plotTrajectoriesFile(filename, mode='2d', tracerfile=None, tracerfield='P',
100100
else:
101101
scat = ax.scatter(lon[b], lat[b], s=20, color='k')
102102
ttl = ax.set_title('Particles' + titlestr + ' at time ' + str(plottimes[0]))
103-
frames = np.arange(1, len(plottimes))
103+
frames = np.arange(0, len(plottimes))
104104

105105
def animate(t):
106106
b = time == plottimes[t]

tests/test_particle_sets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def IncrLon(particle, fieldset, time, dt):
161161
for k in range(samplevar.shape[1]):
162162
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k)
163163
filesize = os.path.getsize(str(outfilepath+".nc"))
164-
assert filesize < 1024 * 60 # test that chunking leads to filesize less than 60KB
164+
assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
165165

166166

167167
def test_pset_repeatdt_check_dt(fieldset):

0 commit comments

Comments
 (0)