Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 6 additions & 46 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools

import numpy as np
import uxarray as ux
import xarray as xr

from parcels._typing import Mesh
Expand Down Expand Up @@ -41,30 +40,11 @@

"""

def __init__(self, datasets: list[Field | VectorField]):
self.datasets = datasets
def __init__(self, fields: list[Field | VectorField]):
# TODO Nick : Enforce fields to be list of Field or VectorField objects
self.fields = {f.name: f for f in fields}

self._fieldnames = []
time_origin = None
# Create pointers to each (Ux)DataArray
for ds in datasets:
for field in ds.data_vars:
if type(ds[field]) is ux.UxDataArray:
self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field)
else:
self.add_field(Field(field, ds[field]), field)
self._fieldnames.append(field)

if "time" in ds.coords:
if time_origin is None:
time_origin = ds.time.min().data
else:
time_origin = min(time_origin, ds.time.min().data)
else:
time_origin = 0.0

self.time_origin = time_origin
self._add_UVfield()
# TODO : Nick : Add _getattr_ magic method to allow access to fields by name

@property
def time_interval(self):
Expand Down Expand Up @@ -111,7 +91,7 @@

@property
def gridset_size(self):
return len(self._fieldnames)
return len(self.fields)

Check warning on line 94 in parcels/fieldset.py

View check run for this annotation

Codecov / codecov/patch

parcels/fieldset.py#L94

Added line #L94 was not covered by tests

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.
Expand Down Expand Up @@ -140,8 +120,7 @@
if hasattr(self, name): # check if Field with same name already exists when adding new Field
raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
else:
setattr(self, name, field)
Comment thread
VeckoTheGecko marked this conversation as resolved.
self._fieldnames.append(name)
self.fields[name] = field

Check warning on line 123 in parcels/fieldset.py

View check run for this annotation

Codecov / codecov/patch

parcels/fieldset.py#L123

Added line #L123 was not covered by tests

def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
"""Wrapper function to add a Field that is constant in space,
Expand Down Expand Up @@ -179,19 +158,6 @@
)
)

def add_vector_field(self, vfield):
"""Add a :class:`parcels.field.VectorField` object to the FieldSet.

Parameters
----------
vfield : parcels.VectorField
class:`parcels.FieldSet.VectorField` object to be added
"""
setattr(self, vfield.name, vfield)
for v in vfield.__dict__.values():
if isinstance(v, Field) and (v not in self.get_fields()):
self.add_field(v)

def get_fields(self) -> list[Field | VectorField]:
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
objects associated with this FieldSet.
Expand All @@ -203,12 +169,6 @@
fields.append(v)
return fields

def _add_UVfield(self):
if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
self.add_vector_field(VectorField("UV", self.U, self.V))
if not hasattr(self, "UVW") and hasattr(self, "W"):
self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))

def add_constant(self, name, value):
"""Add a constant to the FieldSet. Note that all constants are
stored as 32-bit floats.
Expand Down
4 changes: 1 addition & 3 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@
self._interaction_kernel = None

self.fieldset = fieldset
self.fieldset._check_complete()
self.time_origin = fieldset.time_origin
self._pclass = pclass

# ==== first: create a new subclass of the pclass that includes the required variables ==== #
Expand Down Expand Up @@ -962,7 +960,7 @@
if runtime is not None and endtime is not None:
raise RuntimeError("Only one of (endtime, runtime) can be specified")

mintime, maxtime = self.fieldset.dimrange("time")
mintime, maxtime = self.fieldset.dimrange("time") # TODO : change to fieldset.time_interval

Check warning on line 963 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L963

Added line #L963 was not covered by tests

default_release_time = mintime if dt >= 0 else maxtime
if np.any(np.isnan(self.particledata.data["time"])):
Expand Down
66 changes: 49 additions & 17 deletions tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import uxarray as ux

from parcels import (
Field,
FieldSet,
Particle,
ParticleSet,
UXPiecewiseConstantFace,
UXPiecewiseLinearNode,
VectorField,
download_example_dataset,
)

Expand All @@ -26,33 +28,63 @@
return ds


def test_fesom_fieldset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
@pytest.fixture
def uv_fesom_channel(ds_fesom_channel) -> VectorField:
UV = VectorField(
name="UV",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
)
return UV


@pytest.fixture
def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
UVW = VectorField(

Check warning on line 43 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L43

Added line #L43 was not covered by tests
name="UVW",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds_fesom_channel.W, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseLinearNode),
)
return UVW

Check warning on line 49 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L49

Added line #L49 was not covered by tests


def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()


def test_fesom_in_particleset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])

Check warning on line 61 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L61

Added line #L61 was not covered by tests
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()

Check warning on line 64 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L63-L64

Added lines #L63 - L64 were not covered by tests
pset = ParticleSet(fieldset, pclass=Particle)
assert pset.fieldset == fieldset


def test_set_interp_methods(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()

# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
fieldset.fields["U"].interp_method = UXPiecewiseConstantFace
fieldset.fields["V"].interp_method = UXPiecewiseConstantFace


def test_fesom_channel(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel):
fieldset = FieldSet([uvw_fesom_channel, uvw_fesom_channel.U, uvw_fesom_channel.V, uvw_fesom_channel.W])

Check warning on line 82 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L82

Added line #L82 was not covered by tests

# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()
assert (fieldset.fields["W"] == ds_fesom_channel.W).all()

Check warning on line 87 in tests/v4/test_uxarray_fieldset.py

View check run for this annotation

Codecov / codecov/patch

tests/v4/test_uxarray_fieldset.py#L85-L87

Added lines #L85 - L87 were not covered by tests

pset = ParticleSet(fieldset, pclass=Particle)
pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))
Loading