Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ out-*
*.pyc
**/*.zarr/*
.DS_Store
*.parquet

.vscode
.env
Expand Down
227 changes: 49 additions & 178 deletions src/parcels/_core/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@

from __future__ import annotations

import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

import cftime
import numpy as np
import xarray as xr
import zarr
from zarr.storage import DirectoryStore
import pyarrow as pa
import pyarrow.parquet as pq

import parcels
from parcels._core.particle import ParticleClass
from parcels._core.utils.time import timedelta_to_float
from parcels._reprs import particlefile_repr
from parcels._typing import PathLike

if TYPE_CHECKING:
from parcels._core.particle import Variable
Expand All @@ -25,20 +24,19 @@

__all__ = ["ParticleFile"]

# Per-dtype sentinel values used to mark "no data written" in the output
# arrays: NaN for floating-point dtypes, and an extreme representable
# integer for bool/integer dtypes (used both as the netCDF-style
# ``_FillValue`` attribute and to pre-fill freshly allocated arrays).
# NOTE(review): np.int64 maps to iinfo(...).min while every other integer
# dtype maps to .max — confirm this asymmetry is intentional.
_DATATYPES_TO_FILL_VALUES = {
    np.dtype(np.float16): np.nan,
    np.dtype(np.float32): np.nan,
    np.dtype(np.float64): np.nan,
    np.dtype(np.bool_): np.iinfo(np.int8).max,  # bools get an int8-range sentinel
    np.dtype(np.int8): np.iinfo(np.int8).max,
    np.dtype(np.int16): np.iinfo(np.int16).max,
    np.dtype(np.int32): np.iinfo(np.int32).max,
    np.dtype(np.int64): np.iinfo(np.int64).min,
    np.dtype(np.uint8): np.iinfo(np.uint8).max,
    np.dtype(np.uint16): np.iinfo(np.uint16).max,
    np.dtype(np.uint32): np.iinfo(np.uint32).max,
    np.dtype(np.uint64): np.iinfo(np.uint64).max,
}

def _get_schema(particle: parcels.ParticleClass, file_metadata: dict[Any, Any]) -> pa.Schema:
    """Construct the pyarrow schema for a ParticleFile output.

    One field per writable particle variable: each field carries the
    variable's name, its numpy dtype converted to an arrow type, and the
    variable's ``attrs`` as field-level metadata. A copy of
    ``file_metadata`` is attached as schema-level metadata.
    """
    fields = []
    for var in _get_vars_to_write(particle):
        arrow_type = pa.from_numpy_dtype(var.dtype)
        fields.append(pa.field(var.name, arrow_type, metadata=var.attrs))
    return pa.schema(fields, metadata=file_metadata.copy())


class ParticleFile:
Expand All @@ -54,53 +52,51 @@ class ParticleFile:
Interval which dictates the update frequency of file output
while ParticleFile is given as an argument of ParticleSet.execute()
It is either a numpy.timedelta64, a datetime.timedelta object or a positive float (in seconds).
chunks :
Tuple (trajs, obs) to control the size of chunks in the zarr output.
create_new_zarrfile : bool
Whether to create a new file. Default is True

Returns
-------
ParticleFile
ParticleFile object that can be used to write particle data to file
"""

def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True):
def __init__(self, path: PathLike, outputdt):
if not isinstance(outputdt, (np.timedelta64, timedelta, float)):
raise ValueError(
f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}"
)

outputdt = timedelta_to_float(outputdt)
path = Path(path)

if path.suffix != ".parquet":
raise ValueError(
f"ParticleFile data is stored in Parquet files - file extension must be '.parquet'. Got {path.suffix=!r}."
)

if outputdt <= 0:
raise ValueError(f"outputdt must be positive/non-zero. Got {outputdt=!r}")

self._outputdt = outputdt

_assert_valid_chunks_tuple(chunks)
self._chunks = chunks
self._path = path # TODO v4: Consider https://arrow.apache.org/docs/python/getstarted.html#working-with-large-data - though a significant question becomes how to partition, perhaps using a particle variable "partition"?
self._writer: pq.ParquetWriter | None = None
if path.exists():
# TODO: Add logic for recovering/appending to existing parquet file
raise ValueError(f"{path=!r} already exists. Either delete this file or use a path that doesn't exist.")
if not path.parent.exists():
raise ValueError(f"Folder location for {path=!r} does not exist. Create the folder location first.")

self._maxids = 0
self._pids_written = {}
self.metadata = {}
self._create_new_zarrfile = create_new_zarrfile

if not isinstance(store, zarr.storage.Store):
store = _get_store_from_pathlike(store)

self._store = store

# TODO v4: Enable once updating to zarr v3
# if store.read_only:
# raise ValueError(f"Store {store} is read-only. Please provide a writable store.")
self.extra_metadata = {}

# TODO v4: Add check that if create_new_zarrfile is False, the store already exists

def __repr__(self) -> str:
return particlefile_repr(self)

def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]):
self.metadata.update(
self.extra_metadata.update(
{
"feature_type": "trajectory",
"Conventions": "CF-1.6/CF-1.7",
Expand All @@ -115,31 +111,8 @@ def outputdt(self):
return self._outputdt

@property
def chunks(self):
return self._chunks

@property
def store(self):
return self._store

@property
def create_new_zarrfile(self):
return self._create_new_zarrfile

def _extend_zarr_dims(self, Z, store, dtype, axis): # noqa: N803
if axis == 1:
a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
obs = zarr.group(store=store, overwrite=False)["obs"]
if len(obs) == Z.shape[1]:
obs.append(np.arange(self.chunks[1]) + obs[-1] + 1)
else:
extra_trajs = self._maxids - Z.shape[0]
if len(Z.shape) == 2:
a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
else:
a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype)
Z.append(a, axis=axis)
zarr.consolidate_metadata(store)
def path(self):
return self._path

def write(self, pset: ParticleSet, time, indices=None):
"""Write all data from one time step to the zarr file,
Expand All @@ -156,124 +129,35 @@ def write(self, pset: ParticleSet, time, indices=None):
time_interval = pset.fieldset.time_interval
particle_data = pset._data

self._write_particle_data(
particle_data=particle_data, pclass=pclass, time_interval=time_interval, time=time, indices=indices
)
if self._writer is None:
assert not self.path.exists(), "If the file exists, the writer should already be set"
self._writer = pq.ParquetWriter(self.path, _get_schema(pclass, self.extra_metadata))

def _write_particle_data(self, *, particle_data, pclass, time_interval, time, indices=None):
# if pset._data._ncount == 0:
# warnings.warn(
# f"ParticleSet is empty on writing as array at time {time:g}",
# RuntimeWarning,
# stacklevel=2,
# )
# return
if isinstance(time, (np.timedelta64, np.datetime64)):
time = timedelta_to_float(time - time_interval.left)
nparticles = len(particle_data["trajectory"])
vars_to_write = _get_vars_to_write(pclass)
if indices is None:
indices_to_write = _to_write_particles(particle_data, time)
else:
indices_to_write = indices

if len(indices_to_write) == 0:
return

pids = particle_data["trajectory"][indices_to_write]
to_add = sorted(set(pids) - set(self._pids_written.keys()))
for i, pid in enumerate(to_add):
self._pids_written[pid] = self._maxids + i
ids = np.array([self._pids_written[p] for p in pids], dtype=int)
self._maxids = len(self._pids_written)

once_ids = np.where(particle_data["obs_written"][indices_to_write] == 0)[0]
if len(once_ids) > 0:
ids_once = ids[once_ids]
indices_to_write_once = indices_to_write[once_ids]

store = self.store
if self.create_new_zarrfile:
if self.chunks is None:
self._chunks = (nparticles, 1)
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
arrsize = (self._maxids, self.chunks[1])
else:
arrsize = (len(ids), self.chunks[1])
ds = xr.Dataset(
attrs=self.metadata,
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
)
attrs = _create_variables_attribute_dict(pclass, time_interval)
obs = np.zeros((self._maxids), dtype=np.int32)
for var in vars_to_write:
if var.name not in ["trajectory"]: # because 'trajectory' is written as coordinate
if var.to_write == "once":
data = np.full(
(arrsize[0],),
_DATATYPES_TO_FILL_VALUES[var.dtype],
dtype=var.dtype,
)
data[ids_once] = particle_data[var.name][indices_to_write_once]
dims = ["trajectory"]
else:
data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype)
data[ids, 0] = particle_data[var.name][indices_to_write]
dims = ["trajectory", "obs"]
ds[var.name] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name])
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks
ds.to_zarr(store, mode="w")
self._create_new_zarrfile = False
else:
Z = zarr.group(store=store, overwrite=False)
obs = particle_data["obs_written"][indices_to_write]
for var in vars_to_write:
if self._maxids > Z[var.name].shape[0]:
self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=0)
if var.to_write == "once":
if len(once_ids) > 0:
Z[var.name].vindex[ids_once] = particle_data[var.name][indices_to_write_once]
else:
if max(obs) >= Z[var.name].shape[1]:
self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=1)
Z[var.name].vindex[ids, obs] = particle_data[var.name][indices_to_write]

particle_data["obs_written"][indices_to_write] = obs + 1

self._writer.write_table(
pa.table({v.name: pa.array(particle_data[v.name][indices_to_write]) for v in vars_to_write}),
)

def _get_store_from_pathlike(path: Path | str) -> DirectoryStore:
path = str(Path(path)) # Ensure valid path, and convert to string
extension = os.path.splitext(path)[1]
if extension != ".zarr":
raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.")
# if len(indices_to_write) == 0: # TODO: Remove this?
# return

return DirectoryStore(path)
def close(self):
    """Close the underlying ParquetWriter, if one was opened.

    Safe to call multiple times: after the first call ``self._writer`` is
    reset to ``None`` and subsequent calls are no-ops.
    """
    writer = self._writer
    if writer is None:
        return
    writer.close()
    self._writer = None


def _get_vars_to_write(particle: ParticleClass) -> list[Variable]:
return [v for v in particle.variables if v.to_write is not False]


def _create_variables_attribute_dict(particle: ParticleClass, time_interval: TimeInterval) -> dict:
    """Build the per-variable attribute dictionaries for file output.

    Every writable variable (``to_write`` is not ``False``) gets its own
    ``attrs`` plus a ``_FillValue`` entry derived from its dtype; the
    ``time`` entry additionally gains calendar/units metadata taken from
    ``time_interval``. Assumes a variable named ``time`` is present.

    Notes
    -----
    For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden.
    """
    attrs = {
        var.name: {**var.attrs, "_FillValue": _DATATYPES_TO_FILL_VALUES[var.dtype]}
        for var in particle.variables
        if var.to_write is not False
    }
    attrs["time"].update(_get_calendar_and_units(time_interval))
    return attrs


def _to_write_particles(particle_data, time):
"""Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)"""
return np.where(
Expand All @@ -298,7 +182,7 @@ def _to_write_particles(particle_data, time):
)[0]


def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]:
def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: # TODO: Remove?
calendar = None
units = "seconds"
if time_interval:
Expand All @@ -315,16 +199,3 @@ def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]:
attrs["calendar"] = calendar

return attrs


def _assert_valid_chunks_tuple(chunks):
e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.")
if chunks is None:
return

if not isinstance(chunks, tuple):
raise e
if len(chunks) != 2:
raise e
if not all(isinstance(c, int) for c in chunks):
raise e
9 changes: 6 additions & 3 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from parcels._core.warnings import ParticleSetWarning
from parcels._logger import logger
from parcels._reprs import _format_zarr_output_location, particleset_repr
from parcels._reprs import particleset_repr

__all__ = ["ParticleSet"]

Expand Down Expand Up @@ -395,7 +395,7 @@ def execute(

if output_file is not None:
output_file.set_metadata(self.fieldset.gridset[0]._mesh)
output_file.metadata["parcels_kernels"] = self._kernel.funcname
output_file.extra_metadata["parcels_kernels"] = self._kernel.funcname

dt, sign_dt = _convert_dt_to_float(dt)
self._data["dt"][:] = dt
Expand All @@ -415,7 +415,7 @@ def execute(

# Set up pbar
if output_file:
logger.info(f"Output files are stored in {_format_zarr_output_location(output_file.store)}")
logger.info(f"Output files are stored in {output_file.path}")

if verbose_progress:
pbar = tqdm(total=end_time - start_time, file=sys.stdout)
Expand Down Expand Up @@ -451,6 +451,9 @@ def execute(

time = next_time

if output_file is not None:
output_file.close()

if verbose_progress:
pbar.close()

Expand Down
Loading
Loading