|
7 | 7 | import xarray as xr |
8 | 8 |
|
9 | 9 | from parcels.grid import GridCode |
| 10 | +from parcels.grid import CurvilinearGrid |
10 | 11 | from parcels.kernel import Kernel |
11 | 12 | from parcels.particle import JITParticle |
12 | 13 | from parcels.particlefile import ParticleFile |
|
20 | 21 | from mpi4py import MPI |
21 | 22 | except: |
22 | 23 | MPI = None |
| 24 | +# == comment CK: prevents us from adding KDTree as 'mandatory' dependency == # |
| 25 | +try: |
| 26 | + from pykdtree.kdtree import KDTree |
| 27 | +except: |
| 28 | + KDTree = None |
23 | 29 |
|
24 | 30 | __all__ = ['ParticleSet', 'ParticleSetSOA'] |
25 | 31 |
|
@@ -187,6 +193,36 @@ def indexed_subset(self, indices): |
187 | 193 | return ParticleCollectionIteratorSOA(self._collection, |
188 | 194 | subset=indices) |
189 | 195 |
|
| 196 | + def populate_indices(self): |
| 197 | + """Pre-populate guesses of particle xi/yi indices using a kdtree. |
| 198 | +
|
| 199 | + This is only intended for curvilinear grids, where the initial index search |
| 200 | + may be quite expensive. |
| 201 | + """ |
| 202 | + |
| 203 | + if self.fieldset is None: |
| 204 | + # we need to be attached to a fieldset to have a valid |
| 205 | + # gridset to search for indices |
| 206 | + return |
| 207 | + |
| 208 | + if KDTree is None: |
| 209 | + return |
| 210 | + else: |
| 211 | + for i, grid in enumerate(self.fieldset.gridset.grids): |
| 212 | + if not isinstance(grid, CurvilinearGrid): |
| 213 | + continue |
| 214 | + |
| 215 | + tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1) |
| 216 | + tree = KDTree(tree_data) |
| 217 | + # stack all the particle positions for a single query |
| 218 | + pts = np.stack((self._collection.data['lon'], self._collection.data['lat']), axis=-1) |
| 219 | + # query datatype needs to match tree datatype |
| 220 | + _, idx = tree.query(pts.astype(tree_data.dtype)) |
| 221 | + yi, xi = np.unravel_index(idx, grid.lon.shape) |
| 222 | + |
| 223 | + self._collection.data['xi'][:, i] = xi |
| 224 | + self._collection.data['yi'][:, i] = yi |
| 225 | + |
190 | 226 | @property |
191 | 227 | def error_particles(self): |
192 | 228 | """Get an iterator over all particles that are in an error state. |
|
0 commit comments