Skip to content

Commit 6f90000

Browse files
Update UxGrid to use BaseGrid
1 parent 54bfed3 commit 6f90000

3 files changed

Lines changed: 22 additions & 121 deletions

File tree

parcels/basegrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
class BaseGrid(ABC):
55
@abstractmethod
6-
def ravel_index(self, zi, yi, xi):
6+
def ravel_index(self, zi: int, yi: int, xi: int):
77
"""Return the flat index of the given grid points.
88
99
Parameters
@@ -23,7 +23,7 @@ def ravel_index(self, zi, yi, xi):
2323
...
2424

2525
@abstractmethod
26-
def unravel_index(self, ei):
26+
def unravel_index(self, ei: int):
2727
"""Return the zi, yi, xi indices for a given flat index.
2828
Only used when working with fields on a structured grid.
2929

parcels/field.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
FieldSamplingError,
3030
_raise_field_out_of_bound_error,
3131
)
32-
from parcels.uxgrid import UxGrid, ensure_uxgrid
32+
from parcels.uxgrid import UxGrid
3333
from parcels.v4.grid import Grid
3434
from parcels.v4.gridadapter import GridAdapter
3535

@@ -145,7 +145,7 @@ def __init__(
145145
self,
146146
name: str,
147147
data: xr.DataArray | ux.UxDataArray,
148-
grid: ux.Grid | UxGrid | Grid,
148+
grid: UxGrid | Grid,
149149
mesh_type: Mesh = "flat",
150150
interp_method: Callable | None = None,
151151
):
@@ -166,10 +166,7 @@ def __init__(
166166

167167
self.name = name
168168
self.data = data
169-
if isinstance(grid, ux.Grid):
170-
self.grid = ensure_uxgrid(grid)
171-
else:
172-
self.grid = grid
169+
self.grid = grid
173170

174171
try:
175172
self.time_interval = get_time_interval(data)

parcels/uxgrid.py

Lines changed: 17 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,30 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
import uxarray as ux
35
from uxarray.grid.neighbors import _barycentric_coordinates
46

57
from parcels.field import FieldOutOfBoundError # Adjust import as necessary
68

9+
from .basegrid import BaseGrid
10+
711

8-
class UxGrid(ux.grid.Grid):
12+
class UxGrid(BaseGrid):
913
"""
1014
Extension of uxarray's Grid class that supports point-location search
1115
for interpolation on unstructured grids.
1216
"""
1317

14-
@classmethod
15-
def from_uxgrid(cls, grid: ux.grid.Grid) -> "UxGrid":
16-
"""
17-
Create a UxGrid instance from an existing uxarray Grid instance.
18-
19-
Parameters
20-
----------
21-
grid : uxarray.grid.Grid
22-
A previously constructed uxarray Grid object.
23-
24-
Returns
25-
-------
26-
UxGrid
27-
A new UxGrid object with the same internal state.
28-
"""
29-
if isinstance(grid, cls):
30-
return grid # Already an extended grid
31-
32-
new = cls.__new__(cls)
33-
new.__dict__.update(grid.__dict__)
34-
return new
18+
def __init__(self, grid: ux.grid.Grid) -> UxGrid:
19+
self.uxgrid = grid
3520

3621
def search(
3722
self, field, z: float, y: float, x: float, ei: int | None = None, search2D: bool = False
3823
) -> tuple[np.ndarray, int]:
39-
"""
40-
Locate the unstructured grid face containing the point (x, y),
41-
returning interpolation weights and a face-based encoded index.
42-
43-
Parameters
44-
----------
45-
field : parcels.Field
46-
The field requesting the search. Used to access unravel_index(),
47-
ravel_index(), and igrid metadata.
48-
z : float
49-
Vertical coordinate of the query point. Currently ignored.
50-
y : float
51-
Latitude of the query point.
52-
x : float
53-
Longitude of the query point.
54-
ei : int, optional
55-
Encoded index to test reuse of previous face. If valid, neighbors
56-
of that face are also checked before falling back to global search.
57-
search2D : bool, default=False
58-
Ignored for now. Included for interface compatibility.
59-
60-
Returns
61-
-------
62-
bcoords : np.ndarray
63-
Barycentric coordinates of the point in the containing face.
64-
ei : int
65-
Encoded index (e.g., raveled face index) corresponding to the face found.
66-
67-
Raises
68-
------
69-
FieldOutOfBoundError
70-
If no containing face is found within tolerance.
71-
"""
7224
tol = 1e-10
7325

7426
def try_face(fid):
75-
bcoords, err = self._get_barycentric_coordinates(y, x, fid)
27+
bcoords, err = self.uxgrid._get_barycentric_coordinates(y, x, fid)
7628
if (bcoords >= 0).all() and (bcoords <= 1).all() and err < tol:
7729
return bcoords, field.ravel_index(0, 0, fid) # Z and time indices are 0 for now
7830
return None, None
@@ -84,15 +36,15 @@ def try_face(fid):
8436
return bcoords, ei_new
8537

8638
# Try neighbors of current face
87-
for neighbor in self.face_face_connectivity[fi, :]:
39+
for neighbor in self.uxgrid.face_face_connectivity[fi, :]:
8840
if neighbor == -1:
8941
continue
9042
bcoords, ei_new = try_face(neighbor)
9143
if bcoords is not None:
9244
return bcoords, ei_new
9345

9446
# Global fallback using spatial hash
95-
fi, bcoords = self.get_spatial_hash().query([[x, y]])
47+
fi, bcoords = self.uxgrid.get_spatial_hash().query([[x, y]])
9648
if fi == -1:
9749
raise FieldOutOfBoundError(z, y, x)
9850

@@ -101,12 +53,12 @@ def try_face(fid):
10153
def _get_barycentric_coordinates(self, y, x, fi):
10254
"""Checks if a point is inside a given face id on a UxGrid."""
10355
# Check if particle is in the same face, otherwise search again.
104-
n_nodes = self.n_nodes_per_face[fi].to_numpy()
105-
node_ids = self.face_node_connectivity[fi, 0:n_nodes]
56+
n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy()
57+
node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes]
10658
nodes = np.column_stack(
10759
(
108-
np.deg2rad(self.grid.node_lon[node_ids].to_numpy()),
109-
np.deg2rad(self.grid.node_lat[node_ids].to_numpy()),
60+
np.deg2rad(self.uxgrid.grid.node_lon[node_ids].to_numpy()),
61+
np.deg2rad(self.uxgrid.grid.node_lat[node_ids].to_numpy()),
11062
)
11163
)
11264

@@ -116,57 +68,9 @@ def _get_barycentric_coordinates(self, y, x, fi):
11668
return bcoord, err
11769

11870
def ravel_index(self, zi, yi, xi):
119-
"""Return the flat index of the given grid points.
120-
121-
Parameters
122-
----------
123-
zi : int
124-
z index
125-
yi : int
126-
y index
127-
xi : int
128-
x index. When using an unstructured grid, this is the face index (fi)
129-
130-
Returns
131-
-------
132-
int
133-
flat index
134-
"""
135-
return xi + self.n_face * zi
71+
return xi + self.uxgrid.n_face * zi
13672

13773
def unravel_index(self, ei):
138-
"""Return the zi, yi, xi indices for a given flat index.
139-
Only used when working with fields on a structured grid.
140-
141-
Parameters
142-
----------
143-
ei : int
144-
The flat index to be unraveled.
145-
146-
Returns
147-
-------
148-
zi : int
149-
The z index.
150-
yi : int
151-
The y index.
152-
xi : int
153-
The x index.
154-
"""
155-
zi = ei // self.n_face
156-
fi = ei % self.n_face
74+
zi = ei // self.uxgrid.n_face
75+
fi = ei % self.uxgrid.n_face
15776
return zi, fi
158-
159-
160-
def ensure_uxgrid(grid: ux.grid.Grid) -> UxGrid:
161-
"""
162-
Ensure a given uxarray grid is an instance of UxGrid.
163-
164-
Parameters
165-
----------
166-
grid : uxarray.grid.Grid
167-
168-
Returns
169-
-------
170-
UxGrid
171-
"""
172-
return UxGrid.from_uxgrid(grid)

0 commit comments

Comments
 (0)