Skip to content

Commit 237d85b

Browse files
Merge pull request #978 from OceanParcels/deferred_load_with_two_snapshots
Deferred load with two snapshots
2 parents 23c4fe4 + 323adc6 commit 237d85b

5 files changed

Lines changed: 76 additions & 50 deletions

File tree

parcels/examples/example_dask_chunk_OCMs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def compute_swash_particle_advection(field_set, mode, lonp, latp, depthp):
216216

217217
def compute_ofam_particle_advection(field_set, mode, lonp, latp, depthp):
218218
pset = ParticleSet(field_set, pclass=ptype[mode], lon=lonp, lat=latp, depth=depthp)
219-
pfile = ParticleFile("ofam_particles_chunk", pset, outputdt=delta(minutes=10))
219+
pfile = ParticleFile("ofam_particles_chunk", pset, outputdt=delta(days=1))
220220
pset.execute(AdvectionRK4, runtime=delta(days=10), dt=delta(minutes=5), output_file=pfile)
221221
return pset
222222

@@ -255,7 +255,7 @@ def test_nemo_3D(mode, chunk_mode):
255255
@pytest.mark.parametrize('chunk_mode', [False, 'auto', 'specific', 'failsafe'])
256256
def test_globcurrent_2D(mode, chunk_mode):
257257
if chunk_mode in ['auto', ]:
258-
dask.config.set({'array.chunk-size': '32KiB'})
258+
dask.config.set({'array.chunk-size': '16KiB'})
259259
else:
260260
dask.config.set({'array.chunk-size': '128MiB'})
261261
field_set = fieldset_from_globcurrent(chunk_mode)
@@ -325,7 +325,7 @@ def test_pop(mode, chunk_mode):
325325
@pytest.mark.parametrize('chunk_mode', [False, 'auto', 'specific', 'failsafe'])
326326
def test_swash(mode, chunk_mode):
327327
if chunk_mode in ['auto', ]:
328-
dask.config.set({'array.chunk-size': '64KiB'})
328+
dask.config.set({'array.chunk-size': '32KiB'})
329329
else:
330330
dask.config.set({'array.chunk-size': '128MiB'})
331331
field_set = fieldset_from_swash(chunk_mode)

parcels/field.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(self, name, data, lon=None, lat=None, depth=None, time=None, grid=N
200200
self.c_data_chunks = []
201201
self.nchunks = []
202202
self.chunk_set = False
203-
self.filebuffers = [None] * 3
203+
self.filebuffers = [None] * 2
204204
if len(kwargs) > 0:
205205
raise SyntaxError('Field received an unexpected keyword argument "%s"' % list(kwargs.keys())[0])
206206

@@ -392,7 +392,7 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
392392
if 'full_load' in kwargs: # for backward compatibility with Parcels < v2.0.0
393393
deferred_load = not kwargs['full_load']
394394

395-
if grid.time.size <= 3 or deferred_load is False:
395+
if grid.time.size <= 2 or deferred_load is False:
396396
deferred_load = False
397397

398398
if chunksize not in [False, None]:
@@ -1324,11 +1324,9 @@ def data_concatenate(self, data, data_to_concat, tindex):
13241324
if tindex == 0:
13251325
data = lib.concatenate([data_to_concat, data[tindex+1:, :]], axis=0)
13261326
elif tindex == 1:
1327-
data = lib.concatenate([data[:tindex, :], data_to_concat, data[tindex+1:, :]], axis=0)
1328-
elif tindex == 2:
13291327
data = lib.concatenate([data[:tindex, :], data_to_concat], axis=0)
13301328
else:
1331-
raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1, 2]")
1329+
raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1]")
13321330
return data
13331331

13341332
def advancetime(self, field_new, advanceForward):

parcels/fieldset.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def advancetime(self, fieldset_new):
993993
:param fieldset_new: FieldSet snapshot with which the oldest time has to be replaced"""
994994

995995
logger.warning_once("Fieldset.advancetime() is deprecated.\n \
996-
Parcels deals automatically with loading only 3 time steps simustaneously\
996+
Parcels deals automatically with loading only 2 time steps simultaneously\
997997
such that the total allocated memory remains limited.")
998998

999999
advance = 0
@@ -1055,7 +1055,7 @@ def computeTimeChunk(self, time, dt):
10551055
else:
10561056
zd = g.zdim
10571057
data = lib.empty((g.tdim, zd, g.ydim-2*g.meridional_halo, g.xdim-2*g.zonal_halo), dtype=np.float32)
1058-
f.loaded_time_indices = range(3)
1058+
f.loaded_time_indices = range(2)
10591059
for tind in f.loaded_time_indices:
10601060
for fb in f.filebuffers:
10611061
if fb is not None:
@@ -1083,43 +1083,43 @@ def computeTimeChunk(self, time, dt):
10831083
zd = g.zdim
10841084
data = lib.empty((g.tdim, zd, g.ydim-2*g.meridional_halo, g.xdim-2*g.zonal_halo), dtype=np.float32)
10851085
if signdt >= 0:
1086-
f.loaded_time_indices = [2]
1086+
f.loaded_time_indices = [1]
10871087
if f.filebuffers[0] is not None:
10881088
f.filebuffers[0].close()
10891089
f.filebuffers[0] = None
1090-
f.filebuffers[:2] = f.filebuffers[1:]
1091-
data = f.computeTimeChunk(data, 2)
1090+
f.filebuffers[0] = f.filebuffers[1]
1091+
data = f.computeTimeChunk(data, 1)
10921092
else:
10931093
f.loaded_time_indices = [0]
1094-
if f.filebuffers[2] is not None:
1095-
f.filebuffers[2].close()
1096-
f.filebuffers[2] = None
1097-
f.filebuffers[1:] = f.filebuffers[:2]
1094+
if f.filebuffers[1] is not None:
1095+
f.filebuffers[1].close()
1096+
f.filebuffers[1] = None
1097+
f.filebuffers[1] = f.filebuffers[0]
10981098
data = f.computeTimeChunk(data, 0)
10991099
data = f.rescale_and_set_minmax(data)
11001100
if signdt >= 0:
1101-
data = f.reshape(data)[2:, :]
1101+
data = f.reshape(data)[1, :]
11021102
if lib is da:
1103-
f.data = lib.concatenate([f.data[1:, :], data], axis=0)
1103+
f.data = lib.stack([f.data[1, :], data], axis=0)
11041104
else:
11051105
if not isinstance(f.data, DeferredArray):
11061106
if isinstance(f.data, list):
11071107
del f.data[0, :]
11081108
else:
11091109
f.data[0, :] = None
1110-
f.data[:2, :] = f.data[1:, :]
1111-
f.data[2, :] = data
1110+
f.data[0, :] = f.data[1, :]
1111+
f.data[1, :] = data
11121112
else:
1113-
data = f.reshape(data)[0:1, :]
1113+
data = f.reshape(data)[0, :]
11141114
if lib is da:
1115-
f.data = lib.concatenate([data, f.data[:2, :]], axis=0)
1115+
f.data = lib.stack([data, f.data[0, :]], axis=0)
11161116
else:
11171117
if not isinstance(f.data, DeferredArray):
11181118
if isinstance(f.data, list):
1119-
del f.data[2, :]
1119+
del f.data[1, :]
11201120
else:
1121-
f.data[2, :] = None
1122-
f.data[1:, :] = f.data[:2, :]
1121+
f.data[1, :] = None
1122+
f.data[1, :] = f.data[0, :]
11231123
f.data[0, :] = data
11241124
g.load_chunk = np.where(g.load_chunk == g.chunk_loaded_touched,
11251125
g.chunk_loading_requested, g.load_chunk)
@@ -1135,8 +1135,7 @@ def computeTimeChunk(self, time, dt):
11351135
break
11361136
block = f.get_block(block_id)
11371137
f.data_chunks[block_id][0] = None
1138-
f.data_chunks[block_id][:2] = f.data_chunks[block_id][1:]
1139-
f.data_chunks[block_id][2] = np.array(f.data.blocks[(slice(3),)+block][2])
1138+
f.data_chunks[block_id][1] = np.array(f.data.blocks[(slice(2),)+block][1])
11401139
else:
11411140
for block_id in range(len(g.load_chunk)):
11421141
if g.load_chunk[block_id] == g.chunk_loaded_touched:
@@ -1145,9 +1144,8 @@ def computeTimeChunk(self, time, dt):
11451144
# happens when field not called by kernel, but shares a grid with another field called by kernel
11461145
break
11471146
block = f.get_block(block_id)
1148-
f.data_chunks[block_id][2] = None
1149-
f.data_chunks[block_id][1:] = f.data_chunks[block_id][:2]
1150-
f.data_chunks[block_id][0] = np.array(f.data.blocks[(slice(3),)+block][0])
1147+
f.data_chunks[block_id][1] = None
1148+
f.data_chunks[block_id][0] = np.array(f.data.blocks[(slice(2),)+block][0])
11511149
# do user-defined computations on fieldset data
11521150
if self.compute_on_defer:
11531151
self.compute_on_defer(self)

parcels/grid.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,45 +191,45 @@ def computeTimeChunk(self, f, time, signdt):
191191
prev_time_indices = self.time
192192
if self.update_status == 'not_updated':
193193
if self.ti >= 0:
194-
if (time - periods*(self.time_full[-1]-self.time_full[0]) < self.time[0] or time - periods*(self.time_full[-1]-self.time_full[0]) > self.time[2]):
194+
if time - periods*(self.time_full[-1]-self.time_full[0]) < self.time[0] or time - periods*(self.time_full[-1]-self.time_full[0]) > self.time[1]:
195195
self.ti = -1 # reset
196196
elif signdt >= 0 and (time - periods*(self.time_full[-1]-self.time_full[0]) < self.time_full[0] or time - periods*(self.time_full[-1]-self.time_full[0]) >= self.time_full[-1]):
197197
self.ti = -1 # reset
198198
elif signdt < 0 and (time - periods*(self.time_full[-1]-self.time_full[0]) <= self.time_full[0] or time - periods*(self.time_full[-1]-self.time_full[0]) > self.time_full[-1]):
199199
self.ti = -1 # reset
200-
elif signdt >= 0 and time - periods*(self.time_full[-1]-self.time_full[0]) >= self.time[1] and self.ti < len(self.time_full)-3:
200+
elif signdt >= 0 and time - periods*(self.time_full[-1]-self.time_full[0]) >= self.time[1] and self.ti < len(self.time_full)-2:
201201
self.ti += 1
202-
self.time = self.time_full[self.ti:self.ti+3]
202+
self.time = self.time_full[self.ti:self.ti+2]
203203
self.update_status = 'updated'
204-
elif signdt == -1 and time - periods*(self.time_full[-1]-self.time_full[0]) < self.time[1] and self.ti > 0:
204+
elif signdt < 0 and time - periods*(self.time_full[-1]-self.time_full[0]) <= self.time[0] and self.ti > 0:
205205
self.ti -= 1
206-
self.time = self.time_full[self.ti:self.ti+3]
206+
self.time = self.time_full[self.ti:self.ti+2]
207207
self.update_status = 'updated'
208208
if self.ti == -1:
209209
self.time = self.time_full
210210
self.ti, _ = f.time_index(time)
211211
periods = self.periods.value if isinstance(self.periods, c_int) else self.periods
212212
if signdt == -1 and self.ti == 0 and (time - periods*(self.time_full[-1]-self.time_full[0])) == self.time[0] and f.time_periodic:
213-
self.ti = len(self.time)-2
213+
self.ti = len(self.time)-1
214214
periods -= 1
215215
if signdt == -1 and self.ti > 0:
216216
self.ti -= 1
217-
if self.ti >= len(self.time_full) - 2:
218-
self.ti = len(self.time_full) - 3
217+
if self.ti >= len(self.time_full) - 1:
218+
self.ti = len(self.time_full) - 2
219219

220-
self.time = self.time_full[self.ti:self.ti+3]
221-
self.tdim = 3
222-
if prev_time_indices is None or len(prev_time_indices) != 3 or len(prev_time_indices) != len(self.time):
220+
self.time = self.time_full[self.ti:self.ti+2]
221+
self.tdim = 2
222+
if prev_time_indices is None or len(prev_time_indices) != 2 or len(prev_time_indices) != len(self.time):
223223
self.update_status = 'first_updated'
224224
elif functools.reduce(lambda i, j: i and j, map(lambda m, k: m == k, self.time, prev_time_indices), True) and len(prev_time_indices) == len(self.time):
225225
self.update_status = 'not_updated'
226-
elif functools.reduce(lambda i, j: i and j, map(lambda m, k: m == k, self.time[:2], prev_time_indices[:2]), True) and len(prev_time_indices) == len(self.time):
226+
elif functools.reduce(lambda i, j: i and j, map(lambda m, k: m == k, self.time[:1], prev_time_indices[:1]), True) and len(prev_time_indices) == len(self.time):
227227
self.update_status = 'updated'
228228
else:
229229
self.update_status = 'first_updated'
230-
if signdt >= 0 and (self.ti < len(self.time_full)-3 or not f.allow_time_extrapolation):
231-
nextTime_loc = self.time[2] + periods*(self.time_full[-1]-self.time_full[0])
232-
elif signdt == -1 and (self.ti > 0 or not f.allow_time_extrapolation):
230+
if signdt >= 0 and (self.ti < len(self.time_full)-2 or not f.allow_time_extrapolation):
231+
nextTime_loc = self.time[1] + periods*(self.time_full[-1]-self.time_full[0])
232+
elif signdt < 0 and (self.ti > 0 or not f.allow_time_extrapolation):
233233
nextTime_loc = self.time[0] + periods*(self.time_full[-1]-self.time_full[0])
234234
return nextTime_loc
235235

tests/test_fieldset.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -521,10 +521,13 @@ def SampleUV2(particle, fieldset, time):
521521

522522
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
523523
@pytest.mark.parametrize('time_periodic', [4*86400.0, False])
524+
@pytest.mark.parametrize('dt', [-3600, 3600])
524525
@pytest.mark.parametrize('chunksize', [False, 'auto', {'time': ('time_counter', 1), 'lat': ('y', 32), 'lon': ('x', 32)}])
525526
@pytest.mark.parametrize('with_GC', [False, True])
526527
@pytest.mark.skipif(sys.platform.startswith("win"), reason="skipping windows test as windows memory leaks (#787)")
527-
def test_from_netcdf_memory_containment(mode, time_periodic, chunksize, with_GC):
528+
def test_from_netcdf_memory_containment(mode, time_periodic, dt, chunksize, with_GC):
529+
if time_periodic and dt < 0:
530+
return True # time_periodic does not work in backward-time mode
528531
if chunksize == 'auto':
529532
dask.config.set({'array.chunk-size': '2MiB'})
530533
else:
@@ -575,7 +578,7 @@ def periodicBoundaryConditions(particle, fieldset, time):
575578
mem_0 = process.memory_info().rss
576579
mem_exhausted = False
577580
try:
578-
pset.execute(pset.Kernel(AdvectionRK4)+periodicBoundaryConditions, dt=delta(hours=1), runtime=delta(days=7), postIterationCallbacks=postProcessFuncs, callbackdt=delta(hours=12))
581+
pset.execute(pset.Kernel(AdvectionRK4)+periodicBoundaryConditions, dt=dt, runtime=delta(days=7), postIterationCallbacks=postProcessFuncs, callbackdt=delta(hours=12))
579582
except MemoryError:
580583
mem_exhausted = True
581584
mem_steps_np = np.array(perflog.memory_steps)
@@ -768,7 +771,7 @@ def DoNothing(particle, fieldset, time):
768771
assert np.allclose(fieldset.U.data, scale_fac*(zdim-1.)/zdim)
769772

770773

771-
@pytest.mark.parametrize('time2', [2, 7])
774+
@pytest.mark.parametrize('time2', [1, 7])
772775
def test_fieldset_initialisation_kernel_dask(time2, tmpdir, filename='test_parcels_defer_loading'):
773776
filepath = tmpdir.join(filename)
774777
data0, dims0 = generate_fieldset(3, 3, 4, 10)
@@ -789,7 +792,7 @@ class SampleParticle(JITParticle):
789792
pset = ParticleSet(fieldset, pclass=SampleParticle, time=[0, time2],
790793
lon=[0.5, 0.5], lat=[0.5, 0.5], depth=[0.5, 0.5])
791794

792-
if time2 > 2:
795+
if time2 > 1:
793796
failed = False
794797
try:
795798
pset.execute(SampleField, dt=0.)
@@ -905,3 +908,30 @@ def test_fieldset_from_data_gridtypes(xdim=20, ydim=10, zdim=4):
905908
pset.execute(AdvectionRK4, runtime=1, dt=.5)
906909
assert np.allclose(plon, pset.lon)
907910
assert np.allclose(plat, pset.lat)
911+
912+
913+
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
914+
@pytest.mark.parametrize('direction', [1, -1])
915+
@pytest.mark.parametrize('time_extrapolation', [True, False])
916+
def test_deferredload_simplefield(mode, direction, time_extrapolation, tmpdir, tdim=10):
917+
filename = tmpdir.join("simplefield_deferredload.nc")
918+
data = np.zeros((tdim, 2, 2))
919+
for ti in range(tdim):
920+
data[ti, :, :] = ti if direction == 1 else tdim-ti-1
921+
ds = xr.Dataset({"U": (("t", "y", "x"), data), "V": (("t", "y", "x"), data)},
922+
coords={"x": [0, 1], "y": [0, 1], "t": np.arange(tdim)})
923+
ds.to_netcdf(filename)
924+
925+
fieldset = FieldSet.from_netcdf(filename, {'U': 'U', 'V': 'V'}, {'lon': 'x', 'lat': 'y', 'time': 't'},
926+
deferred_load=True, mesh='flat', allow_time_extrapolation=time_extrapolation)
927+
928+
class SamplingParticle(ptype[mode]):
929+
p = Variable('p')
930+
pset = ParticleSet(fieldset, SamplingParticle, lon=0.5, lat=0.5)
931+
932+
def SampleU(particle, fieldset, time):
933+
particle.p = fieldset.U[particle]
934+
935+
runtime = tdim*2 if time_extrapolation else None
936+
pset.execute(SampleU, dt=direction, runtime=runtime)
937+
assert pset.p == tdim-1 if time_extrapolation else tdim-2

0 commit comments

Comments
 (0)