Skip to content

Commit b23310e

Browse files
committed
Improve nearest-neighbor query perf
1 parent b8c4b74 commit b23310e

2 files changed

Lines changed: 97 additions & 9 deletions

File tree

pyresample/future/resamplers/nearest.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,32 @@ def query_no_distance(target_lons, target_lats, valid_output_index,
6666
distance_upper_bound=radius,
6767
mask=mask)
6868

69+
if neighbours == 1:
70+
# Nearest-neighbor resampling only consumes one neighbor, so avoid
71+
# building and masking the full (rows, cols, neighbours) array shape.
72+
index_array = np.asarray(index_array, dtype=np.int64)
73+
index_array[index_array >= kdtree.n] = -1
74+
out = np.full(voi.size, -1, dtype=np.int64)
75+
out[voir] = index_array
76+
return out.reshape(voi.shape + (1,))
77+
6978
if index_array.ndim == 1:
7079
index_array = index_array[:, None]
7180

7281
# KDTree query returns out-of-bounds neighbors as `len(arr)`
7382
# which is an invalid index, we mask those out so -1 represents
7483
# invalid values
84+
#
7585
# voi is 2D (trows, tcols)
7686
# index_array is 2D (valid output pixels, neighbors)
7787
# there are as many Trues in voi as rows in index_array
78-
good_pixels = index_array < kdtree.n
79-
res_ia = np.empty(shape, dtype=int)
80-
mask = np.zeros(shape, dtype=bool)
81-
mask[voi, :] = good_pixels
82-
res_ia[mask] = index_array[good_pixels]
83-
res_ia[~mask] = -1
84-
return res_ia
88+
#
89+
# Write (valid_output_pixels, neighbours) index array into an output filled with
90+
# -1 and then overwrite out-of-bounds values in-place.
91+
out = np.full(shape, -1, dtype=np.int64)
92+
out[voi, :] = index_array
93+
out[out >= kdtree.n] = -1
94+
return out
8595

8696

8797
def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,
@@ -144,7 +154,12 @@ def _compute_radius_of_influence(self):
144154
logger.warning("Could not calculate destination definition "
145155
"resolution")
146156
dst_res = np.nan
147-
radius_of_influence = np.nanmax([src_res, dst_res])
157+
if np.isnan(src_res):
158+
radius_of_influence = dst_res
159+
elif np.isnan(dst_res):
160+
radius_of_influence = src_res
161+
else:
162+
radius_of_influence = max(src_res, dst_res)
148163
if np.isnan(radius_of_influence):
149164
logger.warning("Could not calculate radius_of_influence, falling "
150165
"back to 10000 meters. This may produce lower "
@@ -487,7 +502,9 @@ def _verify_input_object_type(self, data):
487502
"to dask arrays for computation and then converted back. To "
488503
"avoid this warning convert your numpy array before providing "
489504
"it to the resampler.", PerformanceWarning, stacklevel=3)
490-
data = data.copy()
505+
# Avoid copying the underlying ndarray; we only need a new wrapper
506+
# object so we can replace `.data` with a dask array.
507+
data = data.copy(deep=False)
491508
data.data = da.from_array(data.data, chunks="auto")
492509
return data
493510

pyresample/test/test_resamplers/test_nearest.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
import numpy as np
2525
import pytest
2626
import xarray as xr
27+
from pykdtree.kdtree import KDTree
2728
from pytest_lazy_fixtures import lf
2829

2930
from pyresample.future.geometry import AreaDefinition, SwathDefinition
3031
from pyresample.future.resamplers import KDTreeNearestXarrayResampler
32+
from pyresample.future.resamplers._transform_utils import lonlat2xyz
33+
from pyresample.future.resamplers.nearest import query_no_distance
3134
from pyresample.test.utils import assert_maximum_dask_computes, assert_warnings_contain, catch_warnings
3235
from pyresample.utils.errors import PerformanceWarning
3336

@@ -300,3 +303,71 @@ def test_inconsistent_input_shapes(self, src_geom, match, call_precompute,
300303
resampler.precompute(mask=data_2d_float32_xarray_dask.notnull())
301304
else:
302305
resampler.resample(data_2d_float32_xarray_dask)
306+
307+
308+
class TestQueryNoDistance:
309+
"""Tests for direct KDTree query index remapping."""
310+
311+
def test_unselected_and_oob_are_minus_one(self):
312+
voi = np.array([[True, False], [True, False]])
313+
tlons = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float64)
314+
tlats = np.zeros_like(tlons)
315+
316+
src_lons = np.array([0.0], dtype=np.float64)
317+
src_lats = np.array([0.0], dtype=np.float64)
318+
src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False)
319+
kdtree = KDTree(src_xyz)
320+
321+
res = query_no_distance(
322+
tlons,
323+
tlats,
324+
voi,
325+
neighbours=1,
326+
epsilon=0.0,
327+
radius=1.0, # meters; only exact match is within this ROI
328+
kdtree=kdtree,
329+
)
330+
331+
np.testing.assert_array_equal(res[..., 0], np.array([[0, -1], [-1, -1]]))
332+
333+
def test_forwards_filtered_source_mask(self):
334+
voi = np.array([[True]])
335+
336+
src_lons = np.array([[0.0, 0.0001], [0.0002, 0.0003]], dtype=np.float64)
337+
src_lats = np.zeros_like(src_lons)
338+
valid_input_index = np.array([[True, True], [True, False]])
339+
340+
src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False)
341+
kdtree = KDTree(src_xyz[valid_input_index.ravel()])
342+
343+
target_lons = np.array([[0.0]], dtype=np.float64)
344+
target_lats = np.array([[0.0]], dtype=np.float64)
345+
346+
res_unmasked = query_no_distance(
347+
target_lons,
348+
target_lats,
349+
voi,
350+
neighbours=1,
351+
epsilon=0.0,
352+
radius=1000.0,
353+
kdtree=kdtree,
354+
)
355+
356+
# Mask out the nearest source point (after valid_input_index filtering).
357+
source_mask = np.array([[True, False], [False, True]])
358+
res_masked = query_no_distance(
359+
target_lons,
360+
target_lats,
361+
voi,
362+
mask=source_mask,
363+
valid_input_index=valid_input_index,
364+
neighbours=1,
365+
epsilon=0.0,
366+
radius=1000.0,
367+
kdtree=kdtree,
368+
)
369+
370+
assert res_unmasked.shape == (1, 1, 1)
371+
assert res_masked.shape == (1, 1, 1)
372+
assert res_unmasked[0, 0, 0] == 0
373+
assert res_masked[0, 0, 0] == 1

0 commit comments

Comments
 (0)