Skip to content

Commit c973693

Browse files
Merge remote-tracking branch 'origin/v4-dev' into 2031-uxgrid-vertical-search
2 parents ddfa6ba + b251d87 commit c973693

13 files changed

Lines changed: 39 additions & 202 deletions

File tree

docs/examples/example_globcurrent.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,6 @@ def set_globcurrent_fieldset(
3030
)
3131

3232

33-
@pytest.mark.v4remove(
34-
reason="indices keryword is not supported in v4. Subsetting should be done on xarray level."
35-
)
36-
@pytest.mark.parametrize(
37-
"use_xarray", [True, pytest.param(False, marks=pytest.mark.xfail)]
38-
)
39-
def test_globcurrent_fieldset(use_xarray):
40-
fieldset = set_globcurrent_fieldset()
41-
assert fieldset.U.lon.size == 81
42-
assert fieldset.U.lat.size == 41
43-
assert fieldset.V.lon.size == 81
44-
assert fieldset.V.lat.size == 41
45-
46-
if not use_xarray:
47-
indices = {"lon": [5], "lat": range(20, 30)}
48-
fieldsetsub = set_globcurrent_fieldset(indices=indices)
49-
assert np.allclose(fieldsetsub.U.lon, fieldset.U.lon[indices["lon"]])
50-
assert np.allclose(fieldsetsub.U.lat, fieldset.U.lat[indices["lat"]])
51-
assert np.allclose(fieldsetsub.V.lon, fieldset.V.lon[indices["lon"]])
52-
assert np.allclose(fieldsetsub.V.lat, fieldset.V.lat[indices["lat"]])
53-
54-
5533
@pytest.mark.parametrize(
5634
"dt, lonstart, latstart", [(3600.0, 25, -35), (-3600.0, 20, -39)]
5735
)
@@ -103,55 +81,6 @@ def test_globcurrent_particles():
10381
assert abs(pset[0].lat - -35.3) < 1
10482

10583

106-
@pytest.mark.v4remove
107-
@pytest.mark.xfail(
108-
reason="Can't patch metadata without using xarray. v4 will natively use xarray anyway. GH1919."
109-
)
110-
@pytest.mark.parametrize("dt", [-300, 300])
111-
def test_globcurrent_xarray_vs_netcdf(dt):
112-
fieldsetNetcdf = set_globcurrent_fieldset(use_xarray=False)
113-
fieldsetxarray = set_globcurrent_fieldset(use_xarray=True)
114-
lonstart, latstart, runtime = (25, -35, timedelta(days=7))
115-
116-
psetN = parcels.ParticleSet(
117-
fieldsetNetcdf, pclass=parcels.Particle, lon=lonstart, lat=latstart
118-
)
119-
psetN.execute(parcels.AdvectionRK4, runtime=runtime, dt=dt)
120-
121-
psetX = parcels.ParticleSet(
122-
fieldsetxarray, pclass=parcels.Particle, lon=lonstart, lat=latstart
123-
)
124-
psetX.execute(parcels.AdvectionRK4, runtime=runtime, dt=dt)
125-
126-
assert np.allclose(psetN[0].lon, psetX[0].lon)
127-
assert np.allclose(psetN[0].lat, psetX[0].lat)
128-
129-
130-
@pytest.mark.v4remove
131-
@pytest.mark.xfail(
132-
reason="Timeslices will be removed in v4, as users will be able to use xarray directly."
133-
)
134-
@pytest.mark.parametrize("dt", [-300, 300])
135-
def test_globcurrent_netcdf_timestamps(dt):
136-
fieldsetNetcdf = set_globcurrent_fieldset()
137-
timestamps = fieldsetNetcdf.U.grid.timeslices
138-
fieldsetTimestamps = set_globcurrent_fieldset(timestamps=timestamps)
139-
lonstart, latstart, runtime = (25, -35, timedelta(days=7))
140-
141-
psetN = parcels.ParticleSet(
142-
fieldsetNetcdf, pclass=parcels.Particle, lon=lonstart, lat=latstart
143-
)
144-
psetN.execute(parcels.AdvectionRK4, runtime=runtime, dt=dt)
145-
146-
psetT = parcels.ParticleSet(
147-
fieldsetTimestamps, pclass=parcels.Particle, lon=lonstart, lat=latstart
148-
)
149-
psetT.execute(parcels.AdvectionRK4, runtime=runtime, dt=dt)
150-
151-
assert np.allclose(psetN.lon[0], psetT.lon[0])
152-
assert np.allclose(psetN.lat[0], psetT.lat[0])
153-
154-
15584
def test__particles_init_time():
15685
fieldset = set_globcurrent_fieldset()
15786

parcels/_datasets/structured/generic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
__all__ = ["T", "X", "Y", "Z", "datasets"]
99

10+
TIME = xr.date_range("2000", "2001", T)
11+
1012

1113
def _rotated_curvilinear_grid():
1214
XG = np.arange(X)
@@ -44,7 +46,7 @@ def _rotated_curvilinear_grid():
4446
{"axis": "Z"},
4547
),
4648
"depth": (["ZG"], np.arange(Z), {"axis": "Z"}),
47-
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
49+
"time": (["time"], TIME, {"axis": "T"}),
4850
"lon": (
4951
["YG", "XG"],
5052
LON,
@@ -119,7 +121,7 @@ def _unrolled_cone_curvilinear_grid():
119121
{"axis": "Z"},
120122
),
121123
"depth": (["ZG"], np.arange(Z), {"axis": "Z"}),
122-
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
124+
"time": (["time"], TIME, {"axis": "T"}),
123125
"lon": (
124126
["YG", "XG"],
125127
LON,
@@ -175,7 +177,7 @@ def _unrolled_cone_curvilinear_grid():
175177
"lon": (["XG"], 2 * np.pi / X * np.arange(0, X)),
176178
"lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)),
177179
"depth": (["ZG"], np.arange(Z)),
178-
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
180+
"time": (["time"], TIME, {"axis": "T"}),
179181
},
180182
),
181183
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),

parcels/_datasets/unstructured/generic.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import math
2-
from datetime import datetime, timedelta
32

43
import numpy as np
5-
import pandas as pd
64
import uxarray as ux
5+
import xarray as xr
76

87
__all__ = ["Nx", "datasets"]
98

9+
T = 13
1010
Nx = 20
1111
vmax = 1.0
1212
delta = 0.1
13+
TIME = xr.date_range("2000", "2001", T)
1314

1415

1516
def _stommel_gyre_delaunay():
@@ -57,7 +58,7 @@ def _stommel_gyre_delaunay():
5758
uxgrid=uxgrid,
5859
dims=["time", "nz1", "n_node"],
5960
coords=dict(
60-
time=(["time"], pd.to_datetime(["2000-01-01"])),
61+
time=(["time"], [TIME[0]]),
6162
nz1=(["nz1"], [0]),
6263
),
6364
attrs=dict(
@@ -70,7 +71,7 @@ def _stommel_gyre_delaunay():
7071
uxgrid=uxgrid,
7172
dims=["time", "nz1", "n_node"],
7273
coords=dict(
73-
time=(["time"], pd.to_datetime(["2000-01-01"])),
74+
time=(["time"], [TIME[0]]),
7475
nz1=(["nz1"], [0]),
7576
),
7677
attrs=dict(
@@ -83,7 +84,7 @@ def _stommel_gyre_delaunay():
8384
uxgrid=uxgrid,
8485
dims=["time", "nz1", "n_node"],
8586
coords=dict(
86-
time=(["time"], pd.to_datetime(["2000-01-01"])),
87+
time=(["time"], [TIME[0]]),
8788
nz1=(["nz1"], [0]),
8889
),
8990
attrs=dict(description="pressure", units="N/m^2", location="node", mesh="delaunay", Conventions="UGRID-1.0"),
@@ -108,8 +109,6 @@ def _fesom2_square_delaunay_uniform_z_coordinate():
108109
zc = 0.5 * (zf[:-1] + zf[1:]) # Vertical element centers
109110
nz = zf.size
110111
nz1 = zc.size
111-
num_days = 5
112-
date_array = [datetime(2000, 1, 1) + timedelta(days=i) for i in range(num_days)]
113112

114113
# mask any point on one of the boundaries
115114
mask = (
@@ -127,25 +126,23 @@ def _fesom2_square_delaunay_uniform_z_coordinate():
127126

128127
# Define arrays U (zonal), V (meridional) and P (sea surface height)
129128
U = np.ones(
130-
(num_days, nz1, uxgrid.n_face), dtype=np.float64
129+
(T, nz1, uxgrid.n_face), dtype=np.float64
131130
) # Lateral velocity is on the element centers and face centers
132131
V = np.ones(
133-
(num_days, nz1, uxgrid.n_face), dtype=np.float64
132+
(T, nz1, uxgrid.n_face), dtype=np.float64
134133
) # Lateral velocity is on the element centers and face centers
135134
W = np.zeros(
136-
(num_days, nz, uxgrid.n_node), dtype=np.float64
135+
(T, nz, uxgrid.n_node), dtype=np.float64
137136
) # Vertical velocity is on the element faces and face vertices
138-
P = np.ones(
139-
(num_days, nz1, uxgrid.n_node), dtype=np.float64
140-
) # Pressure is on the element centers and face vertices
137+
P = np.ones((T, nz1, uxgrid.n_node), dtype=np.float64) # Pressure is on the element centers and face vertices
141138

142139
u = ux.UxDataArray(
143140
data=U,
144141
name="U",
145142
uxgrid=uxgrid,
146143
dims=["time", "nz1", "n_face"],
147144
coords=dict(
148-
time=(["time"], date_array),
145+
time=(["time"], TIME),
149146
nz1=(["nz1"], zc),
150147
),
151148
attrs=dict(
@@ -158,7 +155,7 @@ def _fesom2_square_delaunay_uniform_z_coordinate():
158155
uxgrid=uxgrid,
159156
dims=["time", "nz1", "n_face"],
160157
coords=dict(
161-
time=(["time"], date_array),
158+
time=(["time"], TIME),
162159
nz1=(["nz1"], zc),
163160
),
164161
attrs=dict(
@@ -171,7 +168,7 @@ def _fesom2_square_delaunay_uniform_z_coordinate():
171168
uxgrid=uxgrid,
172169
dims=["time", "nz", "n_node"],
173170
coords=dict(
174-
time=(["time"], date_array),
171+
time=(["time"], TIME),
175172
nz=(["nz"], zf),
176173
),
177174
attrs=dict(
@@ -184,7 +181,7 @@ def _fesom2_square_delaunay_uniform_z_coordinate():
184181
uxgrid=uxgrid,
185182
dims=["time", "nz1", "n_node"],
186183
coords=dict(
187-
time=(["time"], date_array),
184+
time=(["time"], TIME),
188185
nz1=(["nz1"], zc),
189186
),
190187
attrs=dict(description="pressure", units="N/m^2", location="node", mesh="delaunay", Conventions="UGRID-1.0"),

parcels/_reprs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from parcels import Field, FieldSet, ParticleSet
1010

1111

12-
def field_repr(field: Field) -> str:
12+
def field_repr(field: Field) -> str: # TODO v4: Rework or remove entirely
1313
"""Return a pretty repr for Field"""
1414
out = f"""<{type(field).__name__}>
1515
name : {field.name!r}
@@ -63,7 +63,7 @@ def particleset_repr(pset: ParticleSet) -> str:
6363
return textwrap.dedent(out).strip()
6464

6565

66-
def fieldset_repr(fieldset: FieldSet) -> str:
66+
def fieldset_repr(fieldset: FieldSet) -> str: # TODO v4: Rework or remove entirely
6767
"""Return a pretty repr for FieldSet"""
6868
fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()])
6969

parcels/field.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from parcels._core.utils.time import TimeInterval
1414
from parcels._core.utils.unstructured import get_vertical_location_from_dims
15-
from parcels._reprs import default_repr, field_repr
15+
from parcels._reprs import default_repr
1616
from parcels._typing import (
1717
Mesh,
1818
VectorType,
@@ -212,9 +212,6 @@ def __init__(
212212
if "time" not in self.data.dims:
213213
raise ValueError("Field is missing a 'time' dimension. ")
214214

215-
def __repr__(self):
216-
return field_repr(self)
217-
218215
@property
219216
def units(self):
220217
return self._units

parcels/fieldset.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from parcels._core.utils.time import get_datetime_type_calendar
1111
from parcels._core.utils.time import is_compatible as datetime_is_compatible
12-
from parcels._reprs import fieldset_repr
1312
from parcels._typing import Mesh
1413
from parcels.field import Field, VectorField
1514
from parcels.v4.grid import Grid
@@ -79,9 +78,6 @@ def time_interval(self):
7978
time_intervals = (t for t in time_intervals if t is not None)
8079
return functools.reduce(lambda x, y: x.intersection(y), time_intervals)
8180

82-
def __repr__(self):
83-
return fieldset_repr(self)
84-
8581
def dimrange(self, dim):
8682
"""Returns maximum value of a dimension (lon, lat, depth or time)
8783
on 'left' side and minimum value on 'right' side for all grids

parcels/v4/gridadapter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,17 @@ def depth(self):
7070
return self.grid._ds["depth"].values
7171

7272
@property
73-
def time(self):
73+
def _datetimes(self):
7474
try:
7575
axis = self.grid.axes["T"]
7676
except KeyError:
7777
return np.zeros(1)
7878
return get_time(axis)
7979

80+
@property
81+
def time(self):
82+
return self._datetimes.astype(np.float64) / 1e9
83+
8084
@property
8185
def xdim(self):
8286
return get_dimensionality(self.grid.axes.get("X"))
@@ -95,7 +99,7 @@ def tdim(self):
9599

96100
@property
97101
def time_origin(self):
98-
return TimeConverter(self.time[0])
102+
return TimeConverter(self._datetimes[0])
99103

100104
@property
101105
def _z4d(self) -> Literal[0, 1]:

tests/test_field.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

0 commit comments

Comments
 (0)