Skip to content

Commit e77866f

Browse files
cleaning up _load_timesteps
1 parent d8e6def commit e77866f

3 files changed

Lines changed: 22 additions & 5 deletions

File tree

parcels/field.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,19 @@ def _check_velocitysampling(self):
242242
stacklevel=2,
243243
)
244244

245+
def _load_timesteps(self, time):
246+
"""Load the appropriate timesteps of a field."""
247+
ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0
248+
if not hasattr(self, "data"):
249+
self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
250+
elif self.data_full.time.data[ti] == self.data.time.data[1]:
251+
self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time")
252+
elif self.data_full.time.data[ti] != self.data.time.data[0]:
253+
self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
254+
assert (
255+
len(self.data.time) == 2
256+
), f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}."
257+
245258
def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
246259
"""Interpolate field values in space and time.
247260

parcels/fieldset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ def time_interval(self):
8282
return None
8383
return functools.reduce(lambda x, y: x.intersection(y), time_intervals)
8484

85+
def _load_timesteps(self, time):
86+
"""Load the appropriate timesteps of all fields in the fieldset."""
87+
for fldname in self.fields:
88+
field = self.fields[fldname]
89+
if isinstance(field, Field):
90+
field._load_timesteps(time)
91+
8592
def add_field(self, field: Field, name: str | None = None):
8693
"""Add a :class:`parcels.field.Field` object to the FieldSet.
8794

parcels/particleset.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -805,11 +805,8 @@ def execute(
805805

806806
time = start_time
807807
while sign_dt * (time - end_time) < 0:
808-
809-
for fld in [self.fieldset.U, self.fieldset.V]: # TODO generalise to all fields
810-
ti = np.argmin(fld.data_full.time.data <= self._data["time_nextloop"][0]) - 1 # TODO also implement dt < 0
811-
if not hasattr(fld, "data") or fld.data_full.time.data[ti] != fld.data.time.data[0]:
812-
fld.data = fld.data_full.isel({"time": slice(ti, ti + 2)}).load()
808+
# Load the appropriate timesteps of the fieldset
809+
self.fieldset._load_timesteps(self._data["time_nextloop"][0])
813810

814811
if sign_dt > 0:
815812
next_time = min(time + dt, end_time)

0 commit comments

Comments
 (0)