Skip to content

Commit 82fcb0c

Browse files
Merge branch 'v4-dev' into remove-time-extrapolation-option
2 parents d2417f7 + cc49c60 commit 82fcb0c

13 files changed

Lines changed: 296 additions & 141 deletions

File tree

parcels/_core/utils/time.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TypeVar
55

66
import cftime
7+
import numpy as np
78

89
T = TypeVar("T", datetime, cftime.datetime)
910

@@ -24,10 +25,10 @@ class TimeInterval:
2425
"""
2526

2627
def __init__(self, left: T, right: T) -> None:
27-
if not isinstance(left, (datetime, cftime.datetime)):
28-
raise ValueError(f"Expected left to be a datetime or cftime.datetime, got {type(left)}.")
29-
if not isinstance(right, (datetime, cftime.datetime)):
30-
raise ValueError(f"Expected right to be a datetime or cftime.datetime, got {type(right)}.")
28+
if not isinstance(left, (datetime, cftime.datetime, np.datetime64)):
29+
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(left)}.")
30+
if not isinstance(right, (datetime, cftime.datetime, np.datetime64)):
31+
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(right)}.")
3132
if left >= right:
3233
raise ValueError(f"Expected left to be strictly less than right, got left={left} and right={right}.")
3334
if not is_compatible(left, right):

parcels/_datasets/structured/generic.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
__all__ = ["N", "T", "datasets"]
77

88
N = 30
9-
T = 10
9+
T = 13
1010

1111

1212
def _rotated_curvilinear_grid():
@@ -22,8 +22,12 @@ def _rotated_curvilinear_grid():
2222

2323
return xr.Dataset(
2424
{
25-
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
26-
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
25+
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
26+
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
27+
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
28+
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
29+
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
30+
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
2731
},
2832
coords={
2933
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -41,7 +45,7 @@ def _rotated_curvilinear_grid():
4145
{"axis": "Z"},
4246
),
4347
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
44-
"time": (["time"], np.arange(T), {"axis": "T"}),
48+
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
4549
"lon": (
4650
["YG", "XG"],
4751
LON,
@@ -93,8 +97,12 @@ def _unrolled_cone_curvilinear_grid():
9397

9498
return xr.Dataset(
9599
{
96-
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
97-
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
100+
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
101+
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
102+
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
103+
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
104+
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
105+
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
98106
},
99107
coords={
100108
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
@@ -112,7 +120,7 @@ def _unrolled_cone_curvilinear_grid():
112120
{"axis": "Z"},
113121
),
114122
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
115-
"time": (["time"], np.arange(T), {"axis": "T"}),
123+
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
116124
"lon": (
117125
["YG", "XG"],
118126
LON,
@@ -133,6 +141,10 @@ def _unrolled_cone_curvilinear_grid():
133141
{
134142
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
135143
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
144+
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
145+
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
146+
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
147+
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
136148
},
137149
coords={
138150
"XG": (
@@ -164,7 +176,7 @@ def _unrolled_cone_curvilinear_grid():
164176
"lon": (["XG"], 2 * np.pi / N * np.arange(0, N)),
165177
"lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)),
166178
"depth": (["ZG"], np.arange(3 * N)),
167-
"time": (["time"], np.arange(T), {"axis": "T"}),
179+
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
168180
},
169181
),
170182
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),

parcels/_reprs.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Parcels reprs"""
2+
3+
from __future__ import annotations
4+
5+
import textwrap
6+
from typing import TYPE_CHECKING, Any
7+
8+
if TYPE_CHECKING:
9+
from parcels import Field, FieldSet, ParticleSet
10+
11+
12+
def field_repr(field: Field) -> str:
13+
"""Return a pretty repr for Field"""
14+
out = f"""<{type(field).__name__}>
15+
name : {field.name!r}
16+
data : {field.data!r}
17+
extrapolate time: {field.allow_time_extrapolation!r}
18+
"""
19+
return textwrap.dedent(out).strip()
20+
21+
22+
def _format_list_items_multiline(items: list[str], level: int = 1) -> str:
23+
"""Given a list of strings, formats them across multiple lines.
24+
25+
Uses indentation levels of 4 spaces provided by ``level``.
26+
27+
Example
28+
-------
29+
>>> output = _format_list_items_multiline(["item1", "item2", "item3"], 4)
30+
>>> f"my_items: {output}"
31+
my_items: [
32+
item1,
33+
item2,
34+
item3,
35+
]
36+
"""
37+
if len(items) == 0:
38+
return "[]"
39+
40+
assert level >= 1, "Indentation level >=1 supported"
41+
indentation_str = level * 4 * " "
42+
indentation_str_end = (level - 1) * 4 * " "
43+
44+
items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items])
45+
return f"[\n{items_str}\n{indentation_str_end}]"
46+
47+
48+
def particleset_repr(pset: ParticleSet) -> str:
49+
"""Return a pretty repr for ParticleSet"""
50+
if len(pset) < 10:
51+
particles = [repr(p) for p in pset]
52+
else:
53+
particles = [repr(pset[i]) for i in range(7)] + ["..."]
54+
55+
out = f"""<{type(pset).__name__}>
56+
fieldset :
57+
{textwrap.indent(repr(pset.fieldset), " " * 8)}
58+
pclass : {pset.pclass}
59+
repeatdt : {pset.repeatdt}
60+
# particles: {len(pset)}
61+
particles : {_format_list_items_multiline(particles, level=2)}
62+
"""
63+
return textwrap.dedent(out).strip()
64+
65+
66+
def fieldset_repr(fieldset: FieldSet) -> str:
67+
"""Return a pretty repr for FieldSet"""
68+
fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()])
69+
70+
out = f"""<{type(fieldset).__name__}>
71+
fields:
72+
{textwrap.indent(fields_repr, 8 * " ")}
73+
"""
74+
return textwrap.dedent(out).strip()
75+
76+
77+
def default_repr(obj: Any):
78+
if is_builtin_object(obj):
79+
return repr(obj)
80+
return object.__repr__(obj)
81+
82+
83+
def is_builtin_object(obj):
84+
return obj.__class__.__module__ == "builtins"

parcels/field.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1+
from __future__ import annotations
2+
13
import inspect
24
import warnings
35
from collections.abc import Callable
46
from datetime import datetime
57
from enum import IntEnum
6-
from typing import TYPE_CHECKING
78

89
import numpy as np
910
import uxarray as ux
1011
import xarray as xr
1112
from uxarray.grid.neighbors import _barycentric_coordinates
1213

14+
from parcels._core.utils.time import TimeInterval
1315
from parcels._core.utils.unstructured import get_vertical_location_from_dims
16+
from parcels._reprs import default_repr, field_repr
1417
from parcels._typing import (
1518
Mesh,
1619
VectorType,
1720
assert_valid_mesh,
1821
)
19-
from parcels.tools._helpers import default_repr, field_repr
2022
from parcels.tools.converters import (
2123
UnitConverter,
2224
unitconverters_map,
@@ -33,9 +35,6 @@
3335

3436
from ._index_search import _search_indices_rectilinear, _search_time_index
3537

36-
if TYPE_CHECKING:
37-
pass
38-
3938
__all__ = ["Field", "GridType", "VectorField"]
4039

4140

@@ -166,6 +165,7 @@ def __init__(
166165
self.name = name
167166
self.data = data
168167
self.grid = grid
168+
self.time_interval = get_time_interval(data)
169169

170170
# For compatibility with parts of the codebase that rely on v3 definition of Grid.
171171
# Should be worked to be removed in v4
@@ -184,7 +184,6 @@ def __init__(
184184
e.add_note(f"Error validating field {name!r}.")
185185
raise e
186186

187-
self._parent_mesh = data.attrs["mesh"]
188187
self._mesh_type = mesh_type
189188

190189
# Setting the interpolation method dynamically
@@ -665,3 +664,10 @@ def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux
665664
raise ValueError(
666665
f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}."
667666
)
667+
668+
669+
def get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None:
670+
if "time" not in data.dims:
671+
return None
672+
673+
return TimeInterval(data.time.values[0], data.time.values[-1])

parcels/fieldset.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import numpy as np
44
import xarray as xr
55

6+
from parcels._reprs import fieldset_repr
67
from parcels._typing import Mesh
78
from parcels.field import Field, VectorField
8-
from parcels.tools._helpers import fieldset_repr
9+
from parcels.v4.grid import Grid
910

1011
__all__ = ["FieldSet"]
1112

@@ -41,10 +42,21 @@ class FieldSet:
4142
"""
4243

4344
def __init__(self, fields: list[Field | VectorField]):
44-
# TODO Nick : Enforce fields to be list of Field or VectorField objects
45+
for field in fields:
46+
if not isinstance(field, (Field, VectorField)):
47+
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {field}")
48+
4549
self.fields = {f.name: f for f in fields}
50+
self.constants = {}
4651

47-
# TODO : Nick : Add _getattr_ magic method to allow access to fields by name
52+
def __getattr__(self, name):
53+
"""Get the field by name. If the field is not found, check if it's a constant."""
54+
if name in self.fields:
55+
return self.fields[name]
56+
elif name in self.constants:
57+
return self.constants[name]
58+
else:
59+
raise AttributeError(f"FieldSet has no attribute '{name}'")
4860

4961
@property
5062
def time_interval(self):
@@ -112,6 +124,9 @@ def add_field(self, field: Field, name: str | None = None):
112124
* `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None)
113125
114126
"""
127+
if not isinstance(field, (Field, VectorField)):
128+
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {type(field)}")
129+
115130
name = field.name if name is None else name
116131

117132
if name in self.fields:
@@ -137,19 +152,24 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
137152
correction for zonal velocity U near the poles.
138153
2. flat: No conversion, lat/lon are assumed to be in m.
139154
"""
140-
time = 0.0
141-
values = np.full((1, 1, 1, 1), value)
142-
data = xr.DataArray(
143-
data=values,
144-
name=name,
145-
dims="null",
146-
coords=[time, [0], [0], [0]],
147-
attrs=dict(description="null", units="null", location="node", mesh="constant", mesh_type=mesh),
155+
da = xr.DataArray(
156+
data=np.full((1, 1, 1, 1), value),
157+
dims=["T", "ZG", "YG", "XG"],
158+
coords={
159+
"ZG": (["ZG"], np.arange(1), {"axis": "Z"}),
160+
"YG": (["YG"], np.arange(1), {"axis": "Y"}),
161+
"XG": (["XG"], np.arange(1), {"axis": "X"}),
162+
"lon": (["XG"], np.arange(1), {"axis": "X"}),
163+
"lat": (["YG"], np.arange(1), {"axis": "Y"}),
164+
"depth": (["ZG"], np.arange(1), {"axis": "Z"}),
165+
},
148166
)
167+
grid = Grid(da)
149168
self.add_field(
150169
Field(
151170
name,
152-
data,
171+
da,
172+
grid,
153173
interp_method=None, # TODO : Need to define an interpolation method for constants
154174
)
155175
)
@@ -184,7 +204,10 @@ def add_constant(self, name, value):
184204
`Diffusion <../examples/tutorial_diffusion.ipynb>`__
185205
`Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__
186206
"""
187-
setattr(self, name, value)
207+
if name in self.constants:
208+
raise ValueError(f"FieldSet already has a constant with name '{name}'")
209+
210+
self.constants[name] = np.float32(value)
188211

189212
# def computeTimeChunk(self, time=0.0, dt=1):
190213
# """Load a chunk of three data time steps into the FieldSet.

parcels/particlefile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
import parcels
1212
from parcels._compat import MPI
13-
from parcels.tools._helpers import default_repr, timedelta_to_float
13+
from parcels._reprs import default_repr
14+
from parcels.tools._helpers import timedelta_to_float
1415
from parcels.tools.warnings import FileWarning
1516

1617
__all__ = ["ParticleFile"]

parcels/particleset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tqdm import tqdm
1212

1313
from parcels._compat import MPI
14+
from parcels._reprs import particleset_repr
1415
from parcels.application_kernels.advection import AdvectionRK4
1516
from parcels.field import Field
1617
from parcels.grid import GridType
@@ -25,7 +26,7 @@
2526
from parcels.particle import Particle, Variable
2627
from parcels.particledata import ParticleData, ParticleDataIterator
2728
from parcels.particlefile import ParticleFile
28-
from parcels.tools._helpers import particleset_repr, timedelta_to_float
29+
from parcels.tools._helpers import timedelta_to_float
2930
from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
3031
from parcels.tools.loggers import logger
3132
from parcels.tools.statuscodes import StatusCode

0 commit comments

Comments
 (0)