Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools

import numpy as np
import uxarray as ux
import xarray as xr

from parcels._typing import Mesh
Expand Down Expand Up @@ -41,30 +40,16 @@ class FieldSet:

"""

def __init__(self, datasets: list[Field | VectorField]):
self.datasets = datasets
def __init__(self, fields: list[Field | VectorField]):
# TODO Nick : Enforce fields to be list of Field or VectorField objects
self.fields = {f.name: f for f in fields}

self._fieldnames = []
time_origin = None
# Create pointers to each (Ux)DataArray
for ds in datasets:
for field in ds.data_vars:
if type(ds[field]) is ux.UxDataArray:
self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field)
else:
self.add_field(Field(field, ds[field]), field)
self._fieldnames.append(field)

if "time" in ds.coords:
if time_origin is None:
time_origin = ds.time.min().data
else:
time_origin = min(time_origin, ds.time.min().data)
else:
time_origin = 0.0

self.time_origin = time_origin
self._add_UVfield()
# Add components of vector fields as individual fields
for field in fields:
if isinstance(field, VectorField):
self.add_vector_field(field)

# TODO : Nick : Add _getattr_ magic method to allow access to fields by name

@property
def time_interval(self):
Expand Down Expand Up @@ -111,7 +96,7 @@ def dimrange(self, dim):

@property
def gridset_size(self):
return len(self._fieldnames)
return len(self.fields)

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.
Expand Down Expand Up @@ -140,8 +125,7 @@ def add_field(self, field: Field, name: str | None = None):
if hasattr(self, name): # check if Field with same name already exists when adding new Field
raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
else:
setattr(self, name, field)
Comment thread
VeckoTheGecko marked this conversation as resolved.
self._fieldnames.append(name)
self.fields[name] = field

def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
"""Wrapper function to add a Field that is constant in space,
Expand Down Expand Up @@ -187,7 +171,11 @@ def add_vector_field(self, vfield):
vfield : parcels.VectorField
class:`parcels.FieldSet.VectorField` object to be added
"""
setattr(self, vfield.name, vfield)
# If the vector field is not already in the fieldset, add it
if vfield.name not in self.fields.keys():
self.fields[vfield.name] = vfield

# Add the vector field components as fields to the fieldset
for v in vfield.__dict__.values():
if isinstance(v, Field) and (v not in self.get_fields()):
self.add_field(v)
Expand Down
4 changes: 1 addition & 3 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def __init__(
self._interaction_kernel = None

self.fieldset = fieldset
self.fieldset._check_complete()
self.time_origin = fieldset.time_origin
self._pclass = pclass

# ==== first: create a new subclass of the pclass that includes the required variables ==== #
Expand Down Expand Up @@ -962,7 +960,7 @@ def execute(
if runtime is not None and endtime is not None:
raise RuntimeError("Only one of (endtime, runtime) can be specified")

mintime, maxtime = self.fieldset.dimrange("time")
mintime, maxtime = self.fieldset.dimrange("time") # TODO : change to fieldset.time_interval

default_release_time = mintime if dt >= 0 else maxtime
if np.any(np.isnan(self.particledata.data["time"])):
Expand Down
64 changes: 47 additions & 17 deletions tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import uxarray as ux

from parcels import (
Field,
FieldSet,
Particle,
ParticleSet,
UXPiecewiseConstantFace,
UXPiecewiseLinearNode,
VectorField,
download_example_dataset,
)

Expand All @@ -26,33 +28,61 @@ def ds_fesom_channel() -> ux.UxDataset:
return ds


def test_fesom_fieldset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
@pytest.fixture
def uv_fesom_channel(ds_fesom_channel) -> VectorField:
UV = VectorField(
name="UV",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
)
return UV


@pytest.fixture
def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
UVW = VectorField(
name="UVW",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds_fesom_channel.W, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseLinearNode),
)
return UVW


def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel])
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()


def test_fesom_in_particleset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel])
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()
pset = ParticleSet(fieldset, pclass=Particle)
assert pset.fieldset == fieldset


def test_set_interp_methods(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel])
# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()

# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
fieldset.fields["U"].interp_method = UXPiecewiseConstantFace
fieldset.fields["V"].interp_method = UXPiecewiseConstantFace


def test_fesom_channel(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel):
fieldset = FieldSet([uvw_fesom_channel])

# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()
assert (fieldset.fields["W"] == ds_fesom_channel.W).all()

pset = ParticleSet(fieldset, pclass=Particle)
pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))
Loading