diff --git a/.gitignore b/.gitignore index 45b79435b..6f07f7d82 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ out-* *.pyc **/*.zarr/* .DS_Store +*.parquet .vscode .env diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 788c6e572..c9cdfa62a 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -2,21 +2,20 @@ from __future__ import annotations -import os from datetime import datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal import cftime import numpy as np -import xarray as xr -import zarr -from zarr.storage import DirectoryStore +import pyarrow as pa +import pyarrow.parquet as pq import parcels from parcels._core.particle import ParticleClass from parcels._core.utils.time import timedelta_to_float from parcels._reprs import particlefile_repr +from parcels._typing import PathLike if TYPE_CHECKING: from parcels._core.particle import Variable @@ -25,20 +24,19 @@ __all__ = ["ParticleFile"] -_DATATYPES_TO_FILL_VALUES = { - np.dtype(np.float16): np.nan, - np.dtype(np.float32): np.nan, - np.dtype(np.float64): np.nan, - np.dtype(np.bool_): np.iinfo(np.int8).max, - np.dtype(np.int8): np.iinfo(np.int8).max, - np.dtype(np.int16): np.iinfo(np.int16).max, - np.dtype(np.int32): np.iinfo(np.int32).max, - np.dtype(np.int64): np.iinfo(np.int64).min, - np.dtype(np.uint8): np.iinfo(np.uint8).max, - np.dtype(np.uint16): np.iinfo(np.uint16).max, - np.dtype(np.uint32): np.iinfo(np.uint32).max, - np.dtype(np.uint64): np.iinfo(np.uint64).max, -} + +def _get_schema(particle: parcels.ParticleClass, file_metadata: dict[Any, Any]) -> pa.Schema: + return pa.schema( + [ + pa.field( + v.name, + pa.from_numpy_dtype(v.dtype), + metadata=v.attrs, + ) + for v in _get_vars_to_write(particle) + ], + metadata=file_metadata.copy(), + ) class ParticleFile: @@ -54,10 +52,6 @@ class ParticleFile: Interval which dictates the update frequency of file output while ParticleFile is given as an argument of ParticleSet.execute() It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds). - chunks : - Tuple (trajs, obs) to control the size of chunks in the zarr output. - create_new_zarrfile : bool - Whether to create a new file. Default is True Returns ------- @@ -65,34 +59,36 @@ class ParticleFile: ParticleFile object that can be used to write particle data to file """ - def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): + def __init__(self, path: PathLike, outputdt): if not isinstance(outputdt, (np.timedelta64, timedelta, float)): raise ValueError( f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}" ) outputdt = timedelta_to_float(outputdt) + path = Path(path) + + if path.suffix != ".parquet": + raise ValueError( + f"ParticleFile data is stored in Parquet files - file extension must be '.parquet'. Got {path.suffix=!r}." + ) if outputdt <= 0: raise ValueError(f"outputdt must be positive/non-zero. Got {outputdt=!r}") self._outputdt = outputdt - _assert_valid_chunks_tuple(chunks) - self._chunks = chunks + self._path = path # TODO v4: Consider https://arrow.apache.org/docs/python/getstarted.html#working-with-large-data - though a significant question becomes how to partition, perhaps using a particle variable "partition"? + self._writer: pq.ParquetWriter | None = None + if path.exists(): + # TODO: Add logic for recovering/appending to existing parquet file + raise ValueError(f"{path=!r} already exists. Either delete this file or use a path that doesn't exist.") + if not path.parent.exists(): + raise ValueError(f"Folder location for {path=!r} does not exist. Create the folder location first.") + self._maxids = 0 self._pids_written = {} - self.metadata = {} - self._create_new_zarrfile = create_new_zarrfile - - if not isinstance(store, zarr.storage.Store): - store = _get_store_from_pathlike(store) - - self._store = store - - # TODO v4: Enable once updating to zarr v3 - # if store.read_only: - # raise ValueError(f"Store {store} is read-only. Please provide a writable store.") + self.extra_metadata = {} # TODO v4: Add check that if create_new_zarrfile is False, the store already exists @@ -100,7 +96,7 @@ def __repr__(self) -> str: return particlefile_repr(self) def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]): - self.metadata.update( + self.extra_metadata.update( { "feature_type": "trajectory", "Conventions": "CF-1.6/CF-1.7", @@ -115,31 +111,8 @@ def outputdt(self): return self._outputdt @property - def chunks(self): - return self._chunks - - @property - def store(self): - return self._store - - @property - def create_new_zarrfile(self): - return self._create_new_zarrfile - - def _extend_zarr_dims(self, Z, store, dtype, axis): # noqa: N803 - if axis == 1: - a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - obs = zarr.group(store=store, overwrite=False)["obs"] - if len(obs) == Z.shape[1]: - obs.append(np.arange(self.chunks[1]) + obs[-1] + 1) - else: - extra_trajs = self._maxids - Z.shape[0] - if len(Z.shape) == 2: - a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - else: - a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - Z.append(a, axis=axis) - zarr.consolidate_metadata(store) + def path(self): + return self._path def write(self, pset: ParticleSet, time, indices=None): """Write all data from one time step to the zarr file, @@ -156,124 +129,35 @@ def write(self, pset: ParticleSet, time, indices=None): time_interval = pset.fieldset.time_interval particle_data = pset._data - self._write_particle_data( - particle_data=particle_data, pclass=pclass, time_interval=time_interval, time=time, indices=indices - ) + if self._writer is None: + assert not self.path.exists(), "If the file exists, the writer should already be set" + self._writer = pq.ParquetWriter(self.path, _get_schema(pclass, self.extra_metadata)) - def _write_particle_data(self, *, particle_data, pclass, time_interval, time, indices=None): - # if pset._data._ncount == 0: - # warnings.warn( - # f"ParticleSet is empty on writing as array at time {time:g}", - # RuntimeWarning, - # stacklevel=2, - # ) - # return if isinstance(time, (np.timedelta64, np.datetime64)): time = timedelta_to_float(time - time_interval.left) - nparticles = len(particle_data["trajectory"]) vars_to_write = _get_vars_to_write(pclass) if indices is None: indices_to_write = _to_write_particles(particle_data, time) else: indices_to_write = indices - if len(indices_to_write) == 0: - return - - pids = particle_data["trajectory"][indices_to_write] - to_add = sorted(set(pids) - set(self._pids_written.keys())) - for i, pid in enumerate(to_add): - self._pids_written[pid] = self._maxids + i - ids = np.array([self._pids_written[p] for p in pids], dtype=int) - self._maxids = len(self._pids_written) - - once_ids = np.where(particle_data["obs_written"][indices_to_write] == 0)[0] - if len(once_ids) > 0: - ids_once = ids[once_ids] - indices_to_write_once = indices_to_write[once_ids] - - store = self.store - if self.create_new_zarrfile: - if self.chunks is None: - self._chunks = (nparticles, 1) - if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): - arrsize = (self._maxids, self.chunks[1]) - else: - arrsize = (len(ids), self.chunks[1]) - ds = xr.Dataset( - attrs=self.metadata, - coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, - ) - attrs = _create_variables_attribute_dict(pclass, time_interval) - obs = np.zeros((self._maxids), dtype=np.int32) - for var in vars_to_write: - if var.name not in ["trajectory"]: # because 'trajectory' is written as coordinate - if var.to_write == "once": - data = np.full( - (arrsize[0],), - _DATATYPES_TO_FILL_VALUES[var.dtype], - dtype=var.dtype, - ) - data[ids_once] = particle_data[var.name][indices_to_write_once] - dims = ["trajectory"] - else: - data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) - data[ids, 0] = particle_data[var.name][indices_to_write] - dims = ["trajectory", "obs"] - ds[var.name] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) - ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks - ds.to_zarr(store, mode="w") - self._create_new_zarrfile = False - else: - Z = zarr.group(store=store, overwrite=False) - obs = particle_data["obs_written"][indices_to_write] - for var in vars_to_write: - if self._maxids > Z[var.name].shape[0]: - self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=0) - if var.to_write == "once": - if len(once_ids) > 0: - Z[var.name].vindex[ids_once] = particle_data[var.name][indices_to_write_once] - else: - if max(obs) >= Z[var.name].shape[1]: - self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=1) - Z[var.name].vindex[ids, obs] = particle_data[var.name][indices_to_write] - - particle_data["obs_written"][indices_to_write] = obs + 1 - + self._writer.write_table( + pa.table({v.name: pa.array(particle_data[v.name][indices_to_write]) for v in vars_to_write}), + ) -def _get_store_from_pathlike(path: Path | str) -> DirectoryStore: - path = str(Path(path)) # Ensure valid path, and convert to string - extension = os.path.splitext(path)[1] - if extension != ".zarr": - raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.") + # if len(indices_to_write) == 0: # TODO: Remove this? + # return - return DirectoryStore(path) + def close(self): + if self._writer is not None: + self._writer.close() + self._writer = None def _get_vars_to_write(particle: ParticleClass) -> list[Variable]: return [v for v in particle.variables if v.to_write is not False] -def _create_variables_attribute_dict(particle: ParticleClass, time_interval: TimeInterval) -> dict: - """Creates the dictionary with variable attributes. - - Notes - ----- - For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. - """ - attrs = {} - - vars = [var for var in particle.variables if var.to_write is not False] - for var in vars: - fill_value = {"_FillValue": _DATATYPES_TO_FILL_VALUES[var.dtype]} - - attrs[var.name] = {**var.attrs, **fill_value} - - attrs["time"].update(_get_calendar_and_units(time_interval)) - - return attrs - - def _to_write_particles(particle_data, time): """Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)""" return np.where( @@ -298,7 +182,7 @@ def _to_write_particles(particle_data, time): )[0] -def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: +def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: # TODO: Remove? calendar = None units = "seconds" if time_interval: @@ -315,16 +199,3 @@ def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: attrs["calendar"] = calendar return attrs - - -def _assert_valid_chunks_tuple(chunks): - e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.") - if chunks is None: - return - - if not isinstance(chunks, tuple): - raise e - if len(chunks) != 2: - raise e - if not all(isinstance(c, int) for c in chunks): - raise e diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py index 5483ffbe4..f2e74112f 100644 --- a/src/parcels/_core/particleset.py +++ b/src/parcels/_core/particleset.py @@ -20,7 +20,7 @@ ) from parcels._core.warnings import ParticleSetWarning from parcels._logger import logger -from parcels._reprs import _format_zarr_output_location, particleset_repr +from parcels._reprs import particleset_repr __all__ = ["ParticleSet"] @@ -395,7 +395,7 @@ def execute( if output_file is not None: output_file.set_metadata(self.fieldset.gridset[0]._mesh) - output_file.metadata["parcels_kernels"] = self._kernel.funcname + output_file.extra_metadata["parcels_kernels"] = self._kernel.funcname dt, sign_dt = _convert_dt_to_float(dt) self._data["dt"][:] = dt @@ -415,7 +415,7 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {_format_zarr_output_location(output_file.store)}") + logger.info(f"Output files are stored in {output_file.path}") if verbose_progress: pbar = tqdm(total=end_time - start_time, file=sys.stdout) @@ -451,6 +451,9 @@ def execute( time = next_time + if output_file is not None: + output_file.close() + if verbose_progress: pbar.close() diff --git a/src/parcels/_reprs.py b/src/parcels/_reprs.py index ad6d0cca2..d27eee379 100644 --- a/src/parcels/_reprs.py +++ b/src/parcels/_reprs.py @@ -7,7 +7,6 @@ import numpy as np import xarray as xr -from zarr.storage import DirectoryStore if TYPE_CHECKING: from parcels import Field, FieldSet, ParticleSet @@ -128,7 +127,7 @@ def timeinterval_repr(ti: Any) -> str: def particlefile_repr(pfile: Any) -> str: """Return a pretty repr for ParticleFile""" out = f"""<{type(pfile).__name__}> - store : {_format_zarr_output_location(pfile.store)} + path : {pfile.path} outputdt : {pfile.outputdt!r} chunks : {pfile.chunks!r} create_new_zarrfile : {pfile.create_new_zarrfile!r} @@ -178,11 +177,5 @@ def _format_list_items_multiline(items: list[str] | dict, level: int = 1, with_b return "\n".join([textwrap.indent(e, indentation_str) for e in entries]) -def _format_zarr_output_location(zarr_obj): - if isinstance(zarr_obj, DirectoryStore): - return zarr_obj.path - return repr(zarr_obj) - - def is_builtin_object(obj): return obj.__class__.__module__ == "builtins" diff --git a/tests-v3/test_advection.py b/tests-v3/test_advection.py index 3d8f06bac..bdd7a4221 100644 --- a/tests-v3/test_advection.py +++ b/tests-v3/test_advection.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import pandas as pd import xarray as xr from parcels import ( @@ -79,7 +80,7 @@ def test_analyticalAgrid(): @pytest.mark.parametrize("v", [1, -0.3, 0, -1]) @pytest.mark.parametrize("w", [None, 1, -0.3, 0, -1]) @pytest.mark.parametrize("direction", [1, -1]) -def test_uniform_analytical(u, v, w, direction, tmp_zarrfile): +def test_uniform_analytical(u, v, w, direction, tmp_parquet): lon = np.arange(0, 15, dtype=np.float32) lat = np.arange(0, 15, dtype=np.float32) if w is not None: @@ -99,14 +100,14 @@ def test_uniform_analytical(u, v, w, direction, tmp_zarrfile): x0, y0, z0 = 6.1, 6.2, 20 pset = ParticleSet(fieldset, pclass=Particle, lon=x0, lat=y0, depth=z0) - outfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1, chunks=(1, 1)) + outfile = pset.ParticleFile(name=tmp_parquet, outputdt=1, chunks=(1, 1)) pset.execute(AdvectionAnalytical, runtime=4, dt=direction, output_file=outfile) assert np.abs(pset.lon - x0 - pset.time * u) < 1e-6 assert np.abs(pset.lat - y0 - pset.time * v) < 1e-6 if w is not None: assert np.abs(pset.depth - z0 - pset.time * w) < 1e-4 - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) times = (direction * ds["time"][:]).values.astype("timedelta64[s]")[0] timeref = np.arange(1, 5).astype("timedelta64[s]") assert np.allclose(times, timeref, atol=np.timedelta64(1, "ms")) diff --git a/tests-v3/test_fieldset_sampling.py b/tests-v3/test_fieldset_sampling.py index 291c27b88..176eedab1 100644 --- a/tests-v3/test_fieldset_sampling.py +++ b/tests-v3/test_fieldset_sampling.py @@ -3,6 +3,7 @@ from math import cos, pi import numpy as np +import pandas as pd import pytest import xarray as xr @@ -773,7 +774,7 @@ def test_multiple_grid_addlater_error(): assert fail -def test_fieldset_sampling_updating_order(tmp_zarrfile): +def test_fieldset_sampling_updating_order(tmp_parquet): def calc_p(t, y, x): return 10 * t + x + 0.2 * y @@ -805,10 +806,10 @@ def SampleP(particle, fieldset, time): # pragma: no cover kernels = [AdvectionRK4, SampleP] - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_parquet, outputdt=1) pset.execute(kernels, endtime=1, dt=1, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) for t in range(len(ds["obs"])): for i in range(len(ds["trajectory"])): assert np.isclose( diff --git a/tests-v3/test_particlesets.py b/tests-v3/test_particlesets.py index ed884f595..5c0f2495f 100644 --- a/tests-v3/test_particlesets.py +++ b/tests-v3/test_particlesets.py @@ -39,7 +39,7 @@ def test_pset_create_list_with_customvariable(fieldset): @pytest.mark.parametrize("restart", [True, False]) -def test_pset_create_fromparticlefile(fieldset, restart, tmp_zarrfile): +def test_pset_create_fromparticlefile(fieldset, restart, tmp_parquet): lon = np.linspace(0, 1, 10, dtype=np.float32) lat = np.linspace(1, 0, 10, dtype=np.float32) @@ -48,7 +48,7 @@ def test_pset_create_fromparticlefile(fieldset, restart, tmp_zarrfile): TestParticle = TestParticle.add_variable("p3", np.float64, to_write="once") pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4] * len(lon), pclass=TestParticle, p3=np.arange(len(lon))) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_parquet, outputdt=1) def Kernel(particle, fieldset, time): # pragma: no cover particle.p = 2.0 @@ -58,7 +58,7 @@ def Kernel(particle, fieldset, time): # pragma: no cover pset.execute(Kernel, runtime=2, dt=1, output_file=pfile) pset_new = ParticleSet.from_particlefile( - fieldset, pclass=TestParticle, filename=tmp_zarrfile, restart=restart, repeatdt=1 + fieldset, pclass=TestParticle, filename=tmp_parquet, restart=restart, repeatdt=1 ) for var in ["lon", "lat", "depth", "time", "p", "p2", "p3"]: diff --git a/tests/conftest.py b/tests/conftest.py index 82020c37e..0fd949880 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,6 @@ import pytest -from zarr.storage import MemoryStore - - -@pytest.fixture() -def tmp_zarrfile(tmp_path, request): - test_name = request.node.name - yield tmp_path / f"{test_name}-output.zarr" @pytest.fixture -def tmp_store(): - return MemoryStore() +def tmp_parquet(tmp_path): + return tmp_path / "tmp.parquet" diff --git a/tests/test_advection.py b/tests/test_advection.py index d8c6d2a45..95eca30f3 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pytest import xarray as xr @@ -60,7 +61,7 @@ def test_advection_zonal(mesh, npart=10): np.testing.assert_allclose(pset.lat, startlat, atol=1e-5) -def test_advection_zonal_with_particlefile(tmp_store): +def test_advection_zonal_with_particlefile(tmp_parquet): """Particles at high latitude move geographically faster due to the pole correction.""" npart = 10 ds = simple_UV_dataset(mesh="flat") @@ -68,12 +69,13 @@ def test_advection_zonal_with_particlefile(tmp_store): fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat") pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(30, "m")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(30, "m")) pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) assert (np.diff(pset.lon) < 1.0e-4).all() - ds = xr.open_zarr(tmp_store) - np.testing.assert_allclose(ds.isel(obs=-1).lon.values, pset.lon) + df = pd.read_parquet(tmp_parquet) + final_time = df["time"].max() + np.testing.assert_allclose(df[df["time"] == final_time]["lon"].values, pset.lon, atol=1e-5) def periodicBC(particles, fieldset): diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index b2b05d33f..6eeef20a6 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -3,6 +3,7 @@ import cf_xarray # noqa: F401 import cftime import numpy as np +import pandas as pd import pytest import xarray as xr @@ -95,7 +96,7 @@ def test_fieldset_gridset(fieldset): assert len(fieldset.gridset) == 2 -def test_fieldset_no_UV(tmp_zarrfile): +def test_fieldset_no_UV(tmp_parquet): grid = XGrid.from_dataset(ds, mesh="flat") fieldset = FieldSet([Field("P", ds["U_A_grid"], grid, interp_method=XLinear)]) @@ -103,11 +104,11 @@ def SampleP(particles, fieldset): particles.dlon += fieldset.P[particles] pset = ParticleSet(fieldset, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(SampleP, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds_out = xr.open_zarr(tmp_zarrfile) - assert ds_out["lon"].shape == (1, 2) + df = pd.read_parquet(tmp_parquet) + assert len(df["lon"]) == 2 @pytest.mark.parametrize("ds", [pytest.param(ds, id=k) for k, ds in datasets_structured.items()]) diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 84cb90ffa..f2a3c4553 100755 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -4,9 +4,11 @@ from datetime import datetime, timedelta import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq import pytest import xarray as xr -from zarr.storage import MemoryStore import parcels.tutorial from parcels import ( @@ -20,8 +22,9 @@ VectorField, XGrid, ) -from parcels._core.particle import Particle, create_particle_data, get_default_particle -from parcels._core.utils.time import TimeInterval, timedelta_to_float +from parcels._core.particle import Particle, get_default_particle +from parcels._core.particlefile import _get_schema +from parcels._core.utils.time import timedelta_to_float from parcels._datasets.structured.generated import peninsula_dataset from parcels._datasets.structured.generic import datasets from parcels.convert import copernicusmarine_to_sgrid @@ -44,35 +47,17 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov ) -def test_metadata(fieldset, tmp_zarrfile): +def test_metadata(fieldset, tmp_parquet): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds.attrs["parcels_kernels"].lower() == "DoNothing".lower() + tab = pq.read_table(tmp_parquet) + assert tab.schema.metadata[b"parcels_kernels"].decode().lower() == "DoNothing".lower() -def test_pfile_array_write_zarr_memorystore(fieldset): - """Check that writing to a Zarr MemoryStore works.""" - npart = 10 - zarr_store = MemoryStore() - pset = ParticleSet( - fieldset, - pclass=Particle, - lon=np.linspace(0, 1, npart), - lat=0.5 * np.ones(npart), - time=fieldset.time_interval.left, - ) - pfile = ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) - pfile.write(pset, time=fieldset.time_interval.left) - - ds = xr.open_zarr(zarr_store) - assert ds.sizes["trajectory"] == npart - - -def test_write_fieldset_without_time(tmp_zarrfile): +def test_write_fieldset_without_time(tmp_parquet): ds = peninsula_dataset() # DataSet without time assert "time" not in ds.dims grid = XGrid.from_dataset(ds, mesh="flat") @@ -80,14 +65,17 @@ def test_write_fieldset_without_time(tmp_zarrfile): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds.time.values[0, 1] == np.timedelta64(1, "s") + df = pd.read_parquet(tmp_parquet) + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + assert df["time"][1] == np.timedelta64(1, "s") -def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): +@pytest.mark.skip("Keep or remove? Introduced in 5d7dd6bba800baa0fe4bd38edfc17ca3e310062b ") +def test_pfile_array_remove_particles(fieldset, tmp_parquet): + """If a particle from the middle of a particleset is removed, that writing doesn't crash""" npart = 10 pset = ParticleSet( fieldset, @@ -96,20 +84,19 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset._data["time"][:] = 0 pfile.write(pset, time=fieldset.time_interval.left) pset.remove_indices(3) new_time = 86400 # s in a day pset._data["time"][:] = new_time pfile.write(pset, new_time) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) timearr = ds["time"][:] assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) -@pytest.mark.parametrize("chunks_obs", [1, None]) -def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): +def test_pfile_array_remove_all_particles(fieldset, tmp_parquet): npart = 10 pset = ParticleSet( fieldset, @@ -118,39 +105,20 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - chunks = (npart, chunks_obs) if chunks_obs else None - pfile = ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=0) for _ in range(npart): pset.remove_indices(-1) pfile.write(pset, fieldset.time_interval.left + np.timedelta64(1, "D")) pfile.write(pset, fieldset.time_interval.left + np.timedelta64(2, "D")) + pfile.close() - ds = xr.open_zarr(tmp_zarrfile) - np.testing.assert_allclose(ds["time"][:, 0] - fieldset.time_interval.left, np.timedelta64(0, "s")) - if chunks_obs is not None: - assert ds["time"][:].shape == chunks - else: - assert ds["time"][:].shape[0] == npart - assert np.all(np.isnan(ds["time"][:, 1:])) - - -def test_variable_write_double(fieldset, tmp_zarrfile): - def Update_lon(particles, fieldset): # pragma: no cover - particles.dlon += 0.1 - - dt = np.timedelta64(1, "s") - particle = get_default_particle(np.float64) - pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=dt) - pset.execute(Update_lon, runtime=np.timedelta64(10, "s"), dt=dt, output_file=ofile) + df = pd.read_parquet(tmp_parquet) + # np.testing.assert_allclose(ds["time"][:, 0] - fieldset.time_interval.left, np.timedelta64(0, "s")) # TODO: Need to figure out how times work with parquet output (#2386) + assert df["trajectory"].nunique() == npart - ds = xr.open_zarr(tmp_zarrfile) - lons = ds["lon"][:] - assert isinstance(lons.values[0, 0], np.float64) - -def test_write_dtypes_pfile(fieldset, tmp_zarrfile): +def test_write_dtypes_pfile(fieldset, tmp_parquet): dtypes = [ np.float32, np.float64, @@ -169,14 +137,13 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) + pfile.close() - ds = xr.open_zarr( - tmp_zarrfile, mask_and_scale=False - ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float + tab = pq.read_table(tmp_parquet) for d in dtypes: - assert ds[f"v_{d.__name__}"].dtype == d + assert tab[f"v_{d.__name__}"].type == pa.from_numpy_dtype(d) def test_variable_written_once(): @@ -187,7 +154,7 @@ def test_variable_written_once(): @pytest.mark.skip(reason="Pending ParticleFile refactor; see issue #2386") @pytest.mark.parametrize("dt", [-np.timedelta64(1, "s"), np.timedelta64(1, "s")]) @pytest.mark.parametrize("maxvar", [2, 4, 10]) -def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, dt, maxvar): +def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_parquet, dt, maxvar): """Tests that if particles are released and deleted based on age that resulting output file is correct.""" npart = 10 fieldset.add_constant("maxvar", maxvar) @@ -203,7 +170,7 @@ def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, d pclass=MyParticle, time=fieldset.time_interval.left + [np.timedelta64(i + 1, "s") for i in range(npart)], ) - pfile = ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) + pfile = ParticleFile(tmp_parquet, outputdt=abs(dt)) def IncrLon(particles, fieldset): # pragma: no cover particles.sample_var += 1.0 @@ -216,19 +183,19 @@ def IncrLon(particles, fieldset): # pragma: no cover for _ in range(npart): pset.execute(IncrLon, dt=dt, runtime=np.timedelta64(1, "s"), output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) samplevar = ds["sample_var"][:] assert samplevar.shape == (npart, min(maxvar, npart + 1)) # test whether samplevar[:, k] = k for k in range(samplevar.shape[1]): assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) - filesize = os.path.getsize(str(tmp_zarrfile)) + filesize = os.path.getsize(str(tmp_parquet)) assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB -def test_file_warnings(fieldset, tmp_zarrfile): +def test_file_warnings(fieldset, tmp_parquet): pset = ParticleSet(fieldset, lon=[0, 0], lat=[0, 0], time=[np.timedelta64(0, "s"), np.timedelta64(1, "s")]) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(2, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(2, "s")) with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"): pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile) @@ -244,32 +211,33 @@ def test_file_warnings(fieldset, tmp_zarrfile): (-np.timedelta64(5, "s"), pytest.raises(ValueError)), ], ) -def test_outputdt_types(outputdt, expectation, tmp_zarrfile): +def test_outputdt_types(outputdt, expectation, tmp_parquet): with expectation: - pfile = ParticleFile(tmp_zarrfile, outputdt=outputdt) + pfile = ParticleFile(tmp_parquet, outputdt=outputdt) assert pfile.outputdt == timedelta_to_float(outputdt) -def test_write_timebackward(fieldset, tmp_zarrfile): +def test_write_timebackward(fieldset, tmp_parquet): release_time = fieldset.time_interval.left + [np.timedelta64(i + 1, "s") for i in range(3)] pset = ParticleSet(fieldset, lat=[0, 1, 2], lon=[0, 0, 0], time=release_time) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(3, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - trajs = ds["trajectory"][:] - - output_time = ds["time"][:].values + df = pd.read_parquet(tmp_parquet) - assert trajs.values.dtype == "int64" - assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release - doutput_time = np.diff(output_time, axis=1) - assert np.all(doutput_time[~np.isnan(doutput_time)] < 0) # all times written in decreasing order + assert df["trajectory"].dtype == "int64" + assert bool( + df.groupby("trajectory") + .apply( + lambda x: (np.diff(x["time"]) < 0).all() # for each particle - set True if it has decreasing time + ) + .all() # ensure for all particles + ) @pytest.mark.xfail @pytest.mark.v4alpha -def test_write_xiyi(fieldset, tmp_zarrfile): +def test_write_xiyi(fieldset, tmp_parquet): fieldset.U.data[:] = 1 # set a non-zero zonal velocity fieldset.add_field( Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2], interp_method=XLinear) @@ -300,10 +268,10 @@ def SampleP(particles, fieldset): # pragma: no cover _ = fieldset.P[particles] # To trigger sampling of the P field pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1]) - pfile = ParticleFile(tmp_zarrfile, outputdt=dt) + pfile = ParticleFile(tmp_parquet, outputdt=dt) pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) pxi0 = ds["pxi0"][:].values.astype(np.int32) pxi1 = ds["pxi1"][:].values.astype(np.int32) lons = ds["lon"][:].values @@ -323,7 +291,7 @@ def SampleP(particles, fieldset): # pragma: no cover @pytest.mark.parametrize("outputdt", [np.timedelta64(1, "s"), np.timedelta64(2, "s"), np.timedelta64(3, "s")]) -def test_time_is_age(fieldset, tmp_zarrfile, outputdt): +def test_time_is_age(fieldset, tmp_parquet, outputdt): # Test that particle age is same as time - initial_time npart = 10 @@ -334,11 +302,12 @@ def IncreaseAge(particles, fieldset): # pragma: no cover time = fieldset.time_interval.left + np.arange(npart) * np.timedelta64(1, "s") pset = ParticleSet(fieldset, pclass=AgeParticle, lon=npart * [0], lat=npart * [0], time=time) - ofile = ParticleFile(tmp_zarrfile, outputdt=outputdt) + ofile = ParticleFile(tmp_parquet, outputdt=outputdt) pset.execute(IncreaseAge, runtime=np.timedelta64(npart * 2, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + ds = xr.open_zarr(tmp_parquet) age = ds["age"][:].values.astype("timedelta64[s]") ds_timediff = np.zeros_like(age) for i in range(npart): @@ -346,7 +315,7 @@ def IncreaseAge(particles, fieldset): # pragma: no cover np.testing.assert_equal(age, ds_timediff) -def test_reset_dt(fieldset, tmp_zarrfile): +def test_reset_dt(fieldset, tmp_parquet): # Assert that p.dt gets reset when a write_time is not a multiple of dt # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions dt = np.timedelta64(20, "s") @@ -356,13 +325,13 @@ def Update_lon(particles, fieldset): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(50, "s")) pset.execute(Update_lon, runtime=5 * dt, dt=dt, output_file=ofile) assert np.allclose(pset.lon, 0.6) -def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): +def test_correct_misaligned_outputdt_dt(fieldset, tmp_parquet): """Testing that outputdt does not need to be a multiple of dt.""" def Update_lon(particles, fieldset): # pragma: no cover @@ -370,12 +339,13 @@ def Update_lon(particles, fieldset): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(3, "s")) pset.execute(Update_lon, runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert np.allclose(ds.lon.values, [0, 3, 6, 9]) - assert np.allclose(timedelta_to_float(ds.time.values - ds.time.values[0, 0]), [0, 3, 6, 9]) + df = pd.read_parquet(tmp_parquet) + assert np.allclose(df["lon"].values, [0, 3, 6, 9]) + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + assert np.allclose(timedelta_to_float(df.time.values - df.time.values[0, 0]), [0, 3, 6, 9]) def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): @@ -389,13 +359,13 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg ) with tempfile.TemporaryDirectory() as dir: - name = f"{dir}/test.zarr" + name = f"{dir}/tmp.parquet" output_file = ParticleFile(name, outputdt=outputdt) pset.execute(DoNothing, output_file=output_file, **execute_kwargs) - ds = xr.open_zarr(name).load() + df = pd.read_parquet(name) - return ds + return df def test_pset_execute_outputdt_forwards(fieldset): @@ -404,9 +374,9 @@ def test_pset_execute_outputdt_forwards(fieldset): runtime = timedelta(hours=5) dt = timedelta(minutes=5) - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) - - assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) # noqa: F841 + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) # noqa: F821 def test_pset_execute_output_time_forwards(fieldset): @@ -415,11 +385,11 @@ def test_pset_execute_output_time_forwards(fieldset): runtime = np.timedelta64(5, "h") dt = np.timedelta64(5, "m") - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) - + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) # noqa: F841 + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") assert ( - ds.time[0, 0].values == fieldset.time_interval.left - and ds.time[0, -1].values == fieldset.time_interval.left + runtime + ds.time[0, 0].values == fieldset.time_interval.left # noqa: F821 + and ds.time[0, -1].values == fieldset.time_interval.left + runtime # noqa: F821 ) @@ -429,8 +399,9 @@ def test_pset_execute_outputdt_backwards(fieldset): runtime = timedelta(days=2) dt = -timedelta(minutes=5) - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) # noqa: F841 + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values # noqa: F821 assert np.all(file_outputdt == np.timedelta64(-outputdt)) @@ -448,61 +419,27 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): ds_fset = copernicusmarine_to_sgrid(fields=fields) fieldset = FieldSet.from_sgrid_conventions(ds_fset) - ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + df = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) # noqa: F841 + pytest.skip("# TODO: Need to figure out how times work with parquet output (#2386)") + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values # noqa: F821 assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) -def test_particlefile_init(tmp_store): - ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +def test_particlefile_init(tmp_parquet): + ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) -@pytest.mark.parametrize("name", ["store", "outputdt", "chunks", "create_new_zarrfile"]) -def test_particlefile_readonly_attrs(tmp_store, name): - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +@pytest.mark.parametrize("name", ["path", "outputdt"]) +def test_particlefile_readonly_attrs(tmp_parquet, name): + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) with pytest.raises(AttributeError, match="property .* of 'ParticleFile' object has no setter"): setattr(pfile, name, "something") -def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store - with pytest.raises(ValueError, match="chunks must be a tuple"): - ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) - - -def test_particlefile_write_particle_data(tmp_store): - nparticles = 100 - - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) - pclass = Particle - - left, right = np.datetime64("2019-05-30T12:00:00.000000000", "ns"), np.datetime64("2020-01-02", "ns") - time_interval = TimeInterval(left=left, right=right) - - initial_lon = np.linspace(0, 1, nparticles) - data = create_particle_data( - pclass=pclass, - nparticles=nparticles, - ngrids=4, - time_interval=time_interval, - initial={ - "time": np.full(nparticles, fill_value=0), - "lon": initial_lon, - "dt": np.full(nparticles, fill_value=1.0), - "trajectory": np.arange(nparticles), - }, - ) - np.testing.assert_array_equal(data["time"], 0) - pfile._write_particle_data( - particle_data=data, - pclass=pclass, - time_interval=time_interval, - time=left, - ) - ds = xr.open_zarr(tmp_store) - assert ds.time.dtype == "datetime64[ns]" - np.testing.assert_equal(ds["time"].isel(obs=0).values, left) - assert ds.sizes["trajectory"] == nparticles - np.testing.assert_allclose(ds["lon"].isel(obs=0).values, initial_lon) +def test_particlefile_init_invalid(tmp_path): + path = tmp_path / "file.not-parquet" + with pytest.raises(ValueError, match="file extension must be '.parquet'"): + ParticleFile(path, outputdt=np.timedelta64(1, "s")) def test_pfile_write_custom_particle(): @@ -514,19 +451,19 @@ def test_pfile_write_custom_particle(): @pytest.mark.xfail( reason="set_variable_write_status should be removed - with Particle writing defined on the particle level. GH2186" ) -def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): +def test_pfile_set_towrite_False(fieldset, tmp_parquet): npart = 10 pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) pset.set_variable_write_status("z", False) pset.set_variable_write_status("lat", False) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_parquet, outputdt=1) def Update_lon(particles, fieldset): # pragma: no cover particles.dlon += 0.1 pset.execute(Update_lon, runtime=10, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) assert "time" in ds assert "z" not in ds assert "lat" not in ds @@ -535,3 +472,47 @@ def Update_lon(particles, fieldset): # pragma: no cover # For pytest purposes, we need to reset to original status pset.set_variable_write_status("z", True) pset.set_variable_write_status("lat", True) + + +@pytest.mark.parametrize( + "particle", + [ + Particle, + parcels.ParticleClass( + variables=[ + Variable( + "lon", + dtype=np.float32, + attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + ), + Variable( + "lat", + dtype=np.float32, + attrs={"standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, + ), + Variable( + "z", + dtype=np.float32, + attrs={"standard_name": "vertical coordinate", "units": "m", "positive": "down"}, + ), + ] + ), + ], +) +def test_particle_schema(particle): + s = _get_schema(particle, {}) + + written_variables = [v for v in particle.variables if v.to_write] + + assert len(s.names) == len(written_variables), ( + "Number of particles in the output schema should be the same as the writable variables in the ParticleClass object." + ) + + for variable, pyarrow_field in zip( + written_variables, + s, + strict=False, + ): + assert variable.name == pyarrow_field.name + assert variable.attrs == {k.decode(): v.decode() for k, v in pyarrow_field.metadata.items()} + assert pa.from_numpy_dtype(variable.dtype) == pyarrow_field.type diff --git a/tests/test_uxadvection.py b/tests/test_uxadvection.py index 3f27536f8..d3db9aecd 100644 --- a/tests/test_uxadvection.py +++ b/tests/test_uxadvection.py @@ -1,6 +1,6 @@ import numpy as np +import pandas as pd import pytest -import xarray as xr import parcels from parcels._datasets.unstructured.generic import datasets as datasets_unstructured @@ -12,17 +12,17 @@ @pytest.mark.parametrize("integrator", [AdvectionEE, AdvectionRK2, AdvectionRK4]) -def test_ux_constant_flow_face_centered_2D(integrator, tmp_zarrfile): +def test_ux_constant_flow_face_centered_2D(integrator, tmp_parquet): ds = datasets_unstructured["ux_constant_flow_face_centered_2D"] T = np.timedelta64(3600, "s") dt = np.timedelta64(300, "s") fieldset = parcels.FieldSet.from_ugrid_conventions(ds, mesh="flat") pset = parcels.ParticleSet(fieldset, lon=[5.0], lat=[5.0]) - pfile = parcels.ParticleFile(store=tmp_zarrfile, outputdt=dt) + pfile = parcels.ParticleFile(path=tmp_parquet, outputdt=dt) pset.execute(integrator, runtime=T, dt=dt, output_file=pfile, verbose_progress=False) expected_lon = 8.6 np.testing.assert_allclose(pset.lon, expected_lon, atol=1e-5) - ds_out = xr.open_zarr(tmp_zarrfile) - np.testing.assert_allclose(ds_out["lon"][:, -1], expected_lon, atol=1e-5) + df = pd.read_parquet(tmp_parquet) + np.testing.assert_allclose(df["lon"].iloc[-1], expected_lon, atol=1e-5)