Skip to content

Commit 23c4fe4

Browse files
Merge pull request #976 from OceanParcels/dask_updates
Fixing autochunking.
2 parents dc6262e + 19de6ec commit 23c4fe4

6 files changed

Lines changed: 180 additions & 30 deletions

File tree

parcels/examples/example_dask_chunk_OCMs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_nemo_3D(mode, chunk_mode):
248248
assert (len(field_set.U.grid.load_chunk) != 1)
249249
assert (len(field_set.U.grid.load_chunk) == (1 * int(math.ceil(75.0/25.0)) * int(math.ceil(201.0/201.0)) * int(math.ceil(151.0/151.0))))
250250
assert (len(field_set.V.grid.load_chunk) != 1)
251-
assert (len(field_set.V.grid.load_chunk) == (1 * int(math.ceil(75.0/1.0)) * int(math.ceil(201.0/8.0)) * int(math.ceil(151.0/8.0))))
251+
assert (len(field_set.V.grid.load_chunk) == (1 * int(math.ceil(75.0/75.0)) * int(math.ceil(201.0/8.0)) * int(math.ceil(151.0/8.0))))
252252

253253

254254
@pytest.mark.parametrize('mode', ['jit'])
@@ -318,7 +318,7 @@ def test_pop(mode, chunk_mode):
318318
assert (len(field_set.U.grid.load_chunk) != 1)
319319
assert (len(field_set.V.grid.load_chunk) != 1)
320320
assert (len(field_set.W.grid.load_chunk) != 1)
321-
assert (len(field_set.U.grid.load_chunk) == (int(math.ceil(21.0/8.0)) * int(math.ceil(60.0/8.0)) * int(math.ceil(60.0/8.0))))
321+
assert (len(field_set.U.grid.load_chunk) == (int(math.ceil(21.0/3.0)) * int(math.ceil(60.0/8.0)) * int(math.ceil(60.0/8.0))))
322322

323323

324324
@pytest.mark.parametrize('mode', ['jit'])

parcels/field.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,17 +1146,14 @@ def chunk_setup(self):
11461146
def chunk_data(self):
11471147
if not self.chunk_set:
11481148
self.chunk_setup()
1149-
# self.grid.load_chunk code:
1150-
# 0: not loaded
1151-
# 1: was asked to load by kernel in JIT
1152-
# 2: is loaded and was touched last C call
1153-
# 3: is loaded
1149+
g = self.grid
11541150
if isinstance(self.data, da.core.Array):
11551151
for block_id in range(len(self.grid.load_chunk)):
1156-
if self.grid.load_chunk[block_id] == 1 or self.grid.load_chunk[block_id] > 1 and self.data_chunks[block_id] is None:
1152+
if g.load_chunk[block_id] == g.chunk_loading_requested \
1153+
or g.load_chunk[block_id] in g.chunk_loaded and self.data_chunks[block_id] is None:
11571154
block = self.get_block(block_id)
11581155
self.data_chunks[block_id] = np.array(self.data.blocks[(slice(self.grid.tdim),) + block])
1159-
elif self.grid.load_chunk[block_id] == 0:
1156+
elif g.load_chunk[block_id] == g.chunk_not_loaded:
11601157
if isinstance(self.data_chunks, list):
11611158
self.data_chunks[block_id] = None
11621159
else:
@@ -1168,7 +1165,7 @@ def chunk_data(self):
11681165
else:
11691166
self.data_chunks[0, :] = None
11701167
self.c_data_chunks[0] = None
1171-
self.grid.load_chunk[0] = 2
1168+
self.grid.load_chunk[0] = g.chunk_loaded_touched
11721169
self.data_chunks[0] = np.array(self.data)
11731170

11741171
@property
@@ -1189,9 +1186,9 @@ class CField(Structure):
11891186
allow_time_extrapolation = 1 if self.allow_time_extrapolation else 0
11901187
time_periodic = 1 if self.time_periodic else 0
11911188
for i in range(len(self.grid.load_chunk)):
1192-
if self.grid.load_chunk[i] == 1:
1189+
if self.grid.load_chunk[i] == self.grid.chunk_loading_requested:
11931190
raise ValueError('data_chunks should have been loaded by now if requested. grid.load_chunk[bid] cannot be 1')
1194-
if self.grid.load_chunk[i] > 1:
1191+
if self.grid.load_chunk[i] in self.grid.chunk_loaded:
11951192
if not self.data_chunks[i].flags.c_contiguous:
11961193
self.data_chunks[i] = self.data_chunks[i].copy()
11971194
self.c_data_chunks[i] = self.data_chunks[i].ctypes.data_as(POINTER(POINTER(c_float)))
@@ -1357,12 +1354,13 @@ def computeTimeChunk(self, data, tindex):
13571354
ti = g.ti + tindex
13581355
timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]
13591356

1357+
rechunk_callback_fields = self.chunk_setup if isinstance(tindex, list) else None
13601358
filebuffer = self._field_fb_class(self.dataFiles[g.ti + tindex], self.dimensions, self.indices,
13611359
netcdf_engine=self.netcdf_engine, timestamp=timestamp,
13621360
interp_method=self.interp_method,
13631361
data_full_zdim=self.data_full_zdim,
13641362
chunksize=self.chunksize,
1365-
rechunk_callback_fields=self.chunk_setup,
1363+
rechunk_callback_fields=rechunk_callback_fields,
13661364
chunkdims_name_map=self.netcdf_chunkdims_name_map)
13671365
filebuffer.__enter__()
13681366
time_data = filebuffer.time

parcels/fieldfilebuffer.py

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dask import utils as da_utils
44
import numpy as np
55
import xarray as xr
6+
from netCDF4 import Dataset as ncDataset
67

78
import datetime
89
import math
@@ -188,7 +189,12 @@ def __init__(self, *args, **kwargs):
188189

189190

190191
class DaskFileBuffer(NetcdfFileBuffer):
191-
_static_name_map = ['time', 'depth', 'lat', 'lon']
192+
_static_name_maps = {'time': ['time', 'time_count', 'time_counter', 'timer_count', 't'],
193+
'depth': ['depth', 'depthu', 'depthv', 'depthw', 'depths', 'deptht', 'depthx', 'depthy',
194+
'depthz', 'z', 'z_u', 'z_v', 'z_w', 'd', 'k', 'w_dep', 'w_deps', 'Z', 'Zp1',
195+
'Zl', 'Zu', 'level'],
196+
'lat': ['lat', 'nav_lat', 'y', 'latitude', 'la', 'lt', 'j', 'YC', 'YG'],
197+
'lon': ['lon', 'nav_lon', 'x', 'longitude', 'lo', 'ln', 'i', 'XC', 'XG']}
192198
_min_dim_chunksize = 16
193199

194200
""" Class that encapsulates and manages deferred access to file data. """
@@ -267,6 +273,45 @@ def close(self):
267273
self.chunking_finalized = False
268274
self.chunk_mapping = None
269275

276+
@classmethod
277+
def add_to_dimension_name_map_global(self, name_map):
278+
"""
279+
[externally callable]
280+
This function adds entries to the name map from parcels_dim -> netcdf_dim. This is required if you want to
281+
use auto-chunking on large fields whose map parameters are not defined. This function must be called before
282+
entering the filebuffer object. Example:
283+
DaskFileBuffer.add_to_dimension_name_map_global({'lat': 'nydim',
284+
'lon': 'nxdim',
285+
'time': 'ntdim',
286+
'depth': 'nddim'})
287+
fieldset = FieldSet(..., chunksize='auto')
288+
[...]
289+
Note that not all parcels dimensions need to be present in 'name_map'.
290+
"""
291+
assert isinstance(name_map, dict)
292+
for pcls_dim_name in name_map.keys():
293+
if isinstance(name_map[pcls_dim_name], list):
294+
for nc_dim_name in name_map[pcls_dim_name]:
295+
self._static_name_maps[pcls_dim_name].append(nc_dim_name)
296+
elif isinstance(name_map[pcls_dim_name], str):
297+
self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
298+
299+
def add_to_dimension_name_map(self, name_map):
300+
"""
301+
[externally callable]
302+
This function adds entries to the name map from parcels_dim -> netcdf_dim. This is required if you want to
303+
use auto-chunking on large fields whose map parameters are not defined. This function must be called after
304+
constructing an filebuffer object and before entering the filebuffer. Example:
305+
fb = DaskFileBuffer(...)
306+
fb.add_to_dimension_name_map({'lat': 'nydim', 'lon': 'nxdim', 'time': 'ntdim', 'depth': 'nddim'})
307+
with fb:
308+
[do_stuff}
309+
Note that not all parcels dimensions need to be present in 'name_map'.
310+
"""
311+
assert isinstance(name_map, dict)
312+
for pcls_dim_name in name_map.keys():
313+
self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
314+
270315
def _get_available_dims_indices_by_request(self):
271316
"""
272317
[private function - not to be called from outside the class]
@@ -278,7 +323,7 @@ def _get_available_dims_indices_by_request(self):
278323
neg_offset = 0
279324
tpl_offset = 0
280325
for name in ['time', 'depth', 'lat', 'lon']:
281-
i = self._static_name_map.index(name)
326+
i = list(self._static_name_maps.keys()).index(name)
282327
if (name not in self.dimensions):
283328
result[name] = None
284329
tpl_offset += 1
@@ -300,7 +345,32 @@ def _get_available_dims_indices_by_namemap(self):
300345
"""
301346
result = {}
302347
for name in ['time', 'depth', 'lat', 'lon']:
303-
result[name] = self._static_name_map.index(name)
348+
result[name] = list(self._static_name_maps.keys()).index(name)
349+
return result
350+
351+
def _get_available_dims_indices_by_netcdf_file(self):
352+
"""
353+
[private function - not to be called from outside the class]
354+
[File needs to be open (i.e. self.dataset is not None) for this to work - otherwise generating an error]
355+
Returns a dict mapping 'parcels_dimname' -> [None, int32_index_data_array].
356+
This dictionary is based on the information provided by the requested dimensions.
357+
Example: {'time': 0, 'depth': 5, 'lat': 3, 'lon': 1}
358+
for NetCDF with dimensions:
359+
timer: 1
360+
x: [0 4000]
361+
xr: [0 3999]
362+
y: [0 2140]
363+
yr: [0 2139]
364+
z: [0 75]
365+
"""
366+
if self.dataset is None:
367+
raise IOError("Trying to parse NetCDF header information before opening the file.")
368+
result = {}
369+
for pcls_dimname in ['time', 'depth', 'lat', 'lon']:
370+
for nc_dimname in self._static_name_maps[pcls_dimname]:
371+
if nc_dimname not in self.dataset.dims.keys():
372+
continue
373+
result[pcls_dimname] = list(self.dataset.dims.keys()).index(nc_dimname)
304374
return result
305375

306376
def _is_dimension_available(self, dimension_name):
@@ -346,6 +416,14 @@ def _is_dimension_in_dataset(self, parcels_dimension_name, netcdf_dimension_name
346416
if netcdf_dimension_name is not None and netcdf_dimension_name in self.dataset.dims.keys():
347417
value = self.dataset.dims[netcdf_dimension_name]
348418
k, dname, dvalue = i, netcdf_dimension_name, value
419+
elif self.dimensions is None or self.dataset is None:
420+
return k, dname, dvalue
421+
else:
422+
for name in self._static_name_maps[dimension_name]:
423+
if name in self.dataset.dims:
424+
value = self.dataset.dims[name]
425+
k, dname, dvalue = i, name, value
426+
break
349427
return k, dname, dvalue
350428

351429
def _is_dimension_in_chunksize_request(self, parcels_dimension_name):
@@ -467,6 +545,53 @@ def _get_initial_chunk_dictionary_by_dict_(self):
467545
self.chunksize.pop('lon')
468546
return chunk_dict, chunk_index_map
469547

548+
def _failsafe_parse_(self):
549+
"""
550+
[private function - not to be called from outside the class]
551+
['name' need to be initialised]
552+
"""
553+
# ==== fail - open it as a normal array and deduce the dimensions from the variable-function names ==== #
554+
# ==== done by parsing ALL variables in the NetCDF, and comparing their call-parameters with the ==== #
555+
# ==== name map available here. ==== #
556+
init_chunk_dict = {}
557+
self.dataset = ncDataset(str(self.filename))
558+
refdims = self.dataset.dimensions.keys()
559+
max_field = ""
560+
max_dim_names = ()
561+
max_coincide_dims = 0
562+
for vname in self.dataset.variables:
563+
var = self.dataset.variables[vname]
564+
coincide_dims = []
565+
for vdname in var.dimensions:
566+
if vdname in refdims:
567+
coincide_dims.append(vdname)
568+
n_coincide_dims = len(coincide_dims)
569+
if n_coincide_dims > max_coincide_dims:
570+
max_field = vname
571+
max_dim_names = tuple(coincide_dims)
572+
max_coincide_dims = n_coincide_dims
573+
self.name = max_field
574+
for nc_dname in max_dim_names:
575+
pcls_dname = None
576+
for dname in self._static_name_maps.keys():
577+
if nc_dname in self._static_name_maps[dname]:
578+
pcls_dname = dname
579+
break
580+
nc_dimsize = None
581+
pcls_dim_chunksize = None
582+
if pcls_dname is not None and pcls_dname in self.dimensions:
583+
pcls_dim_chunksize = self._min_dim_chunksize
584+
if isinstance(self.chunksize, dict) and pcls_dname is not None:
585+
nc_dimsize = self.dataset.dimensions[nc_dname].size
586+
if pcls_dname in self.chunksize.keys():
587+
pcls_dim_chunksize = self.chunksize[pcls_dname][1]
588+
if pcls_dname is not None and nc_dname is not None and nc_dimsize is not None and pcls_dim_chunksize is not None:
589+
init_chunk_dict[nc_dname] = pcls_dim_chunksize
590+
591+
# ==== because in this case it has shown that the requested chunksize setup cannot be used, ==== #
592+
# ==== replace the requested chunksize with this auto-derived version. ==== #
593+
return init_chunk_dict
594+
470595
def _get_initial_chunk_dictionary(self):
471596
"""
472597
[private function - not to be called from outside the class]
@@ -532,8 +657,10 @@ def _get_initial_chunk_dictionary(self):
532657
except:
533658
logger.warning("Chunking with init_chunk_dict = {} failed - Executing Dask chunking 'failsafe'...".format(init_chunk_dict))
534659
self.autochunkingfailed = True
535-
self.dataset.close()
536-
raise DaskChunkingError(self.__class__.__name__, "No correct mapping found between Parcels- and NetCDF dimensions! Please correct the 'FieldSet(..., chunksize={...})' parameter and try again.")
660+
if not self.autochunkingfailed:
661+
init_chunk_dict = self._failsafe_parse_()
662+
if isinstance(self.chunksize, dict):
663+
self.chunksize = init_chunk_dict
537664
finally:
538665
self.dataset.close()
539666
self.chunk_mapping = init_chunk_map
@@ -572,8 +699,6 @@ def data_access(self):
572699
self.rechunk_callback_fields()
573700
self.chunking_finalized = True
574701
else:
575-
if not self.autochunkingfailed:
576-
data = data.rechunk(self.chunk_mapping)
577702
self.chunking_finalized = True
578703
else:
579704
da_data = da.from_array(data, chunks=self.chunksize)

parcels/fieldset.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,9 +1069,11 @@ def computeTimeChunk(self, time, dt):
10691069
f.data = f.reshape(data)
10701070
if not f.chunk_set:
10711071
f.chunk_setup()
1072-
if len(g.load_chunk) > 0:
1073-
g.load_chunk = np.where(g.load_chunk == 2, 1, g.load_chunk)
1074-
g.load_chunk = np.where(g.load_chunk == 3, 0, g.load_chunk)
1072+
if len(g.load_chunk) > g.chunk_not_loaded:
1073+
g.load_chunk = np.where(g.load_chunk == g.chunk_loaded_touched,
1074+
g.chunk_loading_requested, g.load_chunk)
1075+
g.load_chunk = np.where(g.load_chunk == g.chunk_deprecated,
1076+
g.chunk_not_loaded, g.load_chunk)
10751077

10761078
elif g.update_status == 'updated':
10771079
lib = np if isinstance(f.data, np.ndarray) else da
@@ -1119,11 +1121,14 @@ def computeTimeChunk(self, time, dt):
11191121
f.data[2, :] = None
11201122
f.data[1:, :] = f.data[:2, :]
11211123
f.data[0, :] = data
1122-
g.load_chunk = np.where(g.load_chunk == 3, 0, g.load_chunk)
1124+
g.load_chunk = np.where(g.load_chunk == g.chunk_loaded_touched,
1125+
g.chunk_loading_requested, g.load_chunk)
1126+
g.load_chunk = np.where(g.load_chunk == g.chunk_deprecated,
1127+
g.chunk_not_loaded, g.load_chunk)
11231128
if isinstance(f.data, da.core.Array) and len(g.load_chunk) > 0:
11241129
if signdt >= 0:
11251130
for block_id in range(len(g.load_chunk)):
1126-
if g.load_chunk[block_id] == 2:
1131+
if g.load_chunk[block_id] == g.chunk_loaded_touched:
11271132
if f.data_chunks[block_id] is None:
11281133
# file chunks were never loaded.
11291134
# happens when field not called by kernel, but shares a grid with another field called by kernel
@@ -1134,7 +1139,7 @@ def computeTimeChunk(self, time, dt):
11341139
f.data_chunks[block_id][2] = np.array(f.data.blocks[(slice(3),)+block][2])
11351140
else:
11361141
for block_id in range(len(g.load_chunk)):
1137-
if g.load_chunk[block_id] == 2:
1142+
if g.load_chunk[block_id] == g.chunk_loaded_touched:
11381143
if f.data_chunks[block_id] is None:
11391144
# file chunks were never loaded.
11401145
# happens when field not called by kernel, but shares a grid with another field called by kernel

parcels/grid.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,26 @@ def computeTimeChunk(self, f, time, signdt):
233233
nextTime_loc = self.time[0] + periods*(self.time_full[-1]-self.time_full[0])
234234
return nextTime_loc
235235

236+
@property
237+
def chunk_not_loaded(self):
238+
return 0
239+
240+
@property
241+
def chunk_loading_requested(self):
242+
return 1
243+
244+
@property
245+
def chunk_loaded_touched(self):
246+
return 2
247+
248+
@property
249+
def chunk_deprecated(self):
250+
return 3
251+
252+
@property
253+
def chunk_loaded(self):
254+
return [2, 3]
255+
236256

237257
class RectilinearGrid(Grid):
238258
"""Rectilinear Grid

parcels/kernel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,9 @@ def execute_jit(self, pset, endtime, dt):
252252
f.c_data_chunks[block_id] = None
253253

254254
for g in pset.fieldset.gridset.grids:
255-
g.load_chunk = np.where(g.load_chunk == 1, 2, g.load_chunk)
256-
if len(g.load_chunk) > 0: # not the case if a field in not called in the kernel
255+
g.load_chunk = np.where(g.load_chunk == g.chunk_loading_requested,
256+
g.chunk_loaded_touched, g.load_chunk)
257+
if len(g.load_chunk) > g.chunk_not_loaded: # not the case if a field in not called in the kernel
257258
if not g.load_chunk.flags.c_contiguous:
258259
g.load_chunk = g.load_chunk.copy()
259260
if not g.depth.flags.c_contiguous:
@@ -410,8 +411,9 @@ def execute(self, pset, endtime, dt, recovery=None, output_file=None, execute_on
410411

411412
if pset.fieldset is not None:
412413
for g in pset.fieldset.gridset.grids:
413-
if len(g.load_chunk) > 0: # not the case if a field in not called in the kernel
414-
g.load_chunk = np.where(g.load_chunk == 2, 3, g.load_chunk)
414+
if len(g.load_chunk) > g.chunk_not_loaded: # not the case if a field in not called in the kernel
415+
g.load_chunk = np.where(g.load_chunk == g.chunk_loaded_touched,
416+
g.chunk_deprecated, g.load_chunk)
415417

416418
# Execute the kernel over the particle set
417419
if self.ptype.uses_jit:

0 commit comments

Comments
 (0)