1+ from __future__ import annotations
2+
13import numpy as np
24import uxarray as ux
35from uxarray .grid .neighbors import _barycentric_coordinates
46
57from parcels .field import FieldOutOfBoundError # Adjust import as necessary
68
9+ from .basegrid import BaseGrid
10+
711
8- class UxGrid (ux . grid . Grid ):
12+ class UxGrid (BaseGrid ):
913 """
1014 Extension of uxarray's Grid class that supports point-location search
1115 for interpolation on unstructured grids.
1216 """
1317
14- @classmethod
15- def from_uxgrid (cls , grid : ux .grid .Grid ) -> "UxGrid" :
16- """
17- Create a UxGrid instance from an existing uxarray Grid instance.
18-
19- Parameters
20- ----------
21- grid : uxarray.grid.Grid
22- A previously constructed uxarray Grid object.
23-
24- Returns
25- -------
26- UxGrid
27- A new UxGrid object with the same internal state.
28- """
29- if isinstance (grid , cls ):
30- return grid # Already an extended grid
31-
32- new = cls .__new__ (cls )
33- new .__dict__ .update (grid .__dict__ )
34- return new
18+ def __init__ (self , grid : ux .grid .Grid ) -> UxGrid :
19+ self .uxgrid = grid
3520
3621 def search (
3722 self , field , z : float , y : float , x : float , ei : int | None = None , search2D : bool = False
3823 ) -> tuple [np .ndarray , int ]:
39- """
40- Locate the unstructured grid face containing the point (x, y),
41- returning interpolation weights and a face-based encoded index.
42-
43- Parameters
44- ----------
45- field : parcels.Field
46- The field requesting the search. Used to access unravel_index(),
47- ravel_index(), and igrid metadata.
48- z : float
49- Vertical coordinate of the query point. Currently ignored.
50- y : float
51- Latitude of the query point.
52- x : float
53- Longitude of the query point.
54- ei : int, optional
55- Encoded index to test reuse of previous face. If valid, neighbors
56- of that face are also checked before falling back to global search.
57- search2D : bool, default=False
58- Ignored for now. Included for interface compatibility.
59-
60- Returns
61- -------
62- bcoords : np.ndarray
63- Barycentric coordinates of the point in the containing face.
64- ei : int
65- Encoded index (e.g., raveled face index) corresponding to the face found.
66-
67- Raises
68- ------
69- FieldOutOfBoundError
70- If no containing face is found within tolerance.
71- """
7224 tol = 1e-10
7325
7426 def try_face (fid ):
75- bcoords , err = self ._get_barycentric_coordinates (y , x , fid )
27+ bcoords , err = self .uxgrid . _get_barycentric_coordinates (y , x , fid )
7628 if (bcoords >= 0 ).all () and (bcoords <= 1 ).all () and err < tol :
7729 return bcoords , field .ravel_index (0 , 0 , fid ) # Z and time indices are 0 for now
7830 return None , None
@@ -84,15 +36,15 @@ def try_face(fid):
8436 return bcoords , ei_new
8537
8638 # Try neighbors of current face
87- for neighbor in self .face_face_connectivity [fi , :]:
39+ for neighbor in self .uxgrid . face_face_connectivity [fi , :]:
8840 if neighbor == - 1 :
8941 continue
9042 bcoords , ei_new = try_face (neighbor )
9143 if bcoords is not None :
9244 return bcoords , ei_new
9345
9446 # Global fallback using spatial hash
95- fi , bcoords = self .get_spatial_hash ().query ([[x , y ]])
47+ fi , bcoords = self .uxgrid . get_spatial_hash ().query ([[x , y ]])
9648 if fi == - 1 :
9749 raise FieldOutOfBoundError (z , y , x )
9850
@@ -101,12 +53,12 @@ def try_face(fid):
10153 def _get_barycentric_coordinates (self , y , x , fi ):
10254 """Checks if a point is inside a given face id on a UxGrid."""
10355 # Check if particle is in the same face, otherwise search again.
104- n_nodes = self .n_nodes_per_face [fi ].to_numpy ()
105- node_ids = self .face_node_connectivity [fi , 0 :n_nodes ]
56+ n_nodes = self .uxgrid . n_nodes_per_face [fi ].to_numpy ()
57+ node_ids = self .uxgrid . face_node_connectivity [fi , 0 :n_nodes ]
10658 nodes = np .column_stack (
10759 (
108- np .deg2rad (self .grid .node_lon [node_ids ].to_numpy ()),
109- np .deg2rad (self .grid .node_lat [node_ids ].to_numpy ()),
60+ np .deg2rad (self .uxgrid . grid .node_lon [node_ids ].to_numpy ()),
61+ np .deg2rad (self .uxgrid . grid .node_lat [node_ids ].to_numpy ()),
11062 )
11163 )
11264
@@ -116,57 +68,9 @@ def _get_barycentric_coordinates(self, y, x, fi):
11668 return bcoord , err
11769
11870 def ravel_index (self , zi , yi , xi ):
119- """Return the flat index of the given grid points.
120-
121- Parameters
122- ----------
123- zi : int
124- z index
125- yi : int
126- y index
127- xi : int
128- x index. When using an unstructured grid, this is the face index (fi)
129-
130- Returns
131- -------
132- int
133- flat index
134- """
135- return xi + self .n_face * zi
71+ return xi + self .uxgrid .n_face * zi
13672
13773 def unravel_index (self , ei ):
138- """Return the zi, yi, xi indices for a given flat index.
139- Only used when working with fields on a structured grid.
140-
141- Parameters
142- ----------
143- ei : int
144- The flat index to be unraveled.
145-
146- Returns
147- -------
148- zi : int
149- The z index.
150- yi : int
151- The y index.
152- xi : int
153- The x index.
154- """
155- zi = ei // self .n_face
156- fi = ei % self .n_face
74+ zi = ei // self .uxgrid .n_face
75+ fi = ei % self .uxgrid .n_face
15776 return zi , fi
158-
159-
160- def ensure_uxgrid (grid : ux .grid .Grid ) -> UxGrid :
161- """
162- Ensure a given uxarray grid is an instance of UxGrid.
163-
164- Parameters
165- ----------
166- grid : uxarray.grid.Grid
167-
168- Returns
169- -------
170- UxGrid
171- """
172- return UxGrid .from_uxgrid (grid )
0 commit comments