Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ New Features
^^^^^^^^^^^^
* Added an `engine` argument to `Grid.ds.to_netcdf()` to allow users to specify the engine used for writing NetCDF files (#439).
* Coding conventions have been updated to use Python 3.10+ features (#439).
* `core.subset.subset_gridpoint` will find nearest neighbours using a KDTree based on euclidean distance in lat/lon space instead of using great circle distances. The small loss in precision is compensated by a significant performance boost, especially for large grids and long point lists (#452).

Bug Fixes
^^^^^^^^^
Expand Down
80 changes: 39 additions & 41 deletions clisops/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pyproj import Geod
from pyproj.crs import CRS
from pyproj.exceptions import CRSError
from scipy.spatial import KDTree
from shapely import vectorized
from shapely.geometry import LineString, MultiPolygon, Point, Polygon
from shapely.ops import split, unary_union
Expand Down Expand Up @@ -1526,8 +1527,7 @@ def subset_gridpoint(
Extract one or more of the nearest gridpoint(s) from datarray based on lat lon coordinate(s).

Return a subsetted data array (or Dataset) for the grid point(s) falling nearest the input longitude and latitude
coordinates. Optionally, subset the data array for years falling within provided date bounds.
Time series can optionally be subsetted by dates.
coordinates (as computed with a lat/lon euclidean distance). Time series can optionally be subsetted by dates.
If 1D sequences of coordinates are given, the gridpoints will be concatenated along the new dimension "site".

Parameters
Expand Down Expand Up @@ -1576,62 +1576,60 @@ def subset_gridpoint(
# Subset lat lon point
prSub = subset_gridpoint(ds.pr, lon=-75, lat=45)

# Subset multiple variables in a single dataset
ds = xr.open_mfdataset([path_to_tasmax_file, path_to_tasmin_file])
dsSub = subset_gridpoint(ds, lon=-75, lat=45)
# Drop locations where the closest gridpoint was too far (here 1000 km)
prSub = subset_gridpoint(ds.pr, lon=[-75, -60], lat=[45, 40], tolerance=1e6)
"""
if lat is None or lon is None:
raise ValueError("Insufficient coordinates provided to locate grid point(s).")

ptdim = lat.dims[0]
dist = None

lon_name = lon.name or "lon"
lat_name = lat.name or "lat"

srclon = get_coord_by_type(da, "longitude", ignore_aux_coords=False)
srclat = get_coord_by_type(da, "latitude", ignore_aux_coords=False)
# make sure input data has 'lon' and 'lat'(dims, coordinates, or data_vars)
if hasattr(da, lon_name) and hasattr(da, lat_name):
dims = list(da.dims)

# if 'lon' and 'lat' are present as data dimensions use the .sel method.
if lat_name in dims and lon_name in dims:
da = da.sel(lat=lat, lon=lon, method="nearest")

if tolerance is not None or add_distance:
# Calculate the geodesic distance between grid points and the point of interest.
dist = distance(da, lon=lon, lat=lat)
else:
dist = None

if srclon is not None and srclat is not None:
srclon = da[srclon]
srclat = da[srclat]
if srclon.ndim == 1 and srclat.ndim == 1 and srclon.dims != srclat.dims:
# lon and lat are 1D and don't share coords : rectilinear grid
da = da.sel({srclat.dims[0]: lat, srclon.dims[0]: lon}, method="nearest")
elif srclon.ndim == 2 and srclat.dims == srclon.dims:
# lon and lat are 2D and share coords : curvilinear grid
pts = np.vstack([srclon.values.flatten(), srclat.values.flatten()]).T
# The input is a grid, so already well-behaved, no need for the precision-improving features of KDTree
tree = KDTree(pts, compact_nodes=False, balanced_tree=False)
_, idxs = tree.query(np.vstack([lon.values, lat.values]).T)
iY, iX = np.unravel_index(idxs, shape=da.lon.shape)
iY = lon.copy(data=iY)
iX = lon.copy(data=iX)
da = da.isel({da.lon.dims[0]: iY, da.lon.dims[1]: iX})
elif srclon.ndim == 1 and srclat.ndim == 1 and srclon.dims == srclat.dims:
# lon and lat are 1D and share coords : list of points case
pts = np.vstack([srclon.values, srclat.values]).T
tree = KDTree(pts)
_, idxs = tree.query(np.vstack([lon.values, lat.values]).T)
idxs = lon.copy(data=idxs)
da = da.isel({srclon.dims[0]: idxs})
else:
# Calculate the geodesic distance between grid points and the point of interest.
dist = distance(da, lon=lon, lat=lat)
pts = []
dists = []
for site in dist[ptdim]:
# Find the indices for the closest point
inds = np.unravel_index(dist.sel({ptdim: site}).argmin(), dist.sel({ptdim: site}).shape)

# Select data from closest point
args = {xydim: ind for xydim, ind in zip(dist.dims, inds, strict=False)}
pts.append(da.isel(**args))
dists.append(dist.isel(**args))
da = xarray.concat(pts, dim=ptdim)
dist = xarray.concat(dists, dim=ptdim)
raise ValueError(f"Unrecognized coordinate type for longitude and latitude ({srclon.name}, {srclat.name})")
else:
raise (
Exception(
f'{subset_gridpoint.__name__} requires input data with "lon" and "lat" coordinates or data variables.'
)
)
raise ValueError("subset_gridpoint requires input data with longitude and latitude coordinates.")

if tolerance is not None or add_distance:
# Calculate the geodesic distance between grid points and the point of interest.
dist = distance(da, lon=lon, lat=lat)

if tolerance is not None and dist is not None:
if tolerance is not None:
da = da.where(dist < tolerance)

if add_distance:
da = da.assign_coords(distance=dist)

if len(lat) == 1:
da = da.squeeze(ptdim)
else:
da = da.transpose(..., ptdim)

if start_date or end_date:
da = subset_time(da, start_date=start_date, end_date=end_date)
Expand Down
6 changes: 6 additions & 0 deletions clisops/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def is_latitude(coord: xr.DataArray | xr.Dataset) -> bool:
if hasattr(coord, "long_name") and coord.long_name == "latitude":
return True

if coord.name == "lat":
return True

return False


Expand Down Expand Up @@ -203,6 +206,9 @@ def is_longitude(coord: xr.DataArray | xr.Dataset) -> bool:
if hasattr(coord, "long_name") and coord.long_name == "longitude":
return True

if coord.name == "lon":
return True

return False


Expand Down
21 changes: 3 additions & 18 deletions tests/test_core_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def test_dataset(self, nimbus):
da = xr.open_mfdataset(
[nimbus.fetch(self.nc_tasmax_file), nimbus.fetch(self.nc_tasmin_file)],
combine="by_coords",
compat="override",
)
lon = -72.4
lat = 46.1
Expand Down Expand Up @@ -222,31 +223,15 @@ def test_irregular(self, nimbus):

# test_irregular transposed:
da1 = xr.open_dataset(nimbus.fetch(self.nc_2dlonlat)).tasmax
dims = list(da1.dims)
dims.reverse()
daT = xr.DataArray(np.transpose(da1.values), dims=dims)
for d in daT.dims:
args = dict()
args[d] = da1[d]
daT = daT.assign_coords(**args)
daT = daT.assign_coords(lon=(["rlon", "rlat"], np.transpose(da1.lon.values)))
daT = daT.assign_coords(lat=(["rlon", "rlat"], np.transpose(da1.lat.values)))
daT = da1.transpose(*list(reversed(da1.dims)))

out1 = subset.subset_gridpoint(daT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out1.lon, lon, 1)
np.testing.assert_almost_equal(out1.lat, lat, 1)
np.testing.assert_array_equal(out, out1)

# Dataset with tasmax, lon and lat as data variables (i.e. lon, lat not coords of tasmax)
daT1 = xr.DataArray(np.transpose(da1.values), dims=dims)
for d in daT1.dims:
args = dict()
args[d] = da1[d]
daT1 = daT1.assign_coords(**args)
dsT = xr.Dataset(data_vars=None, coords=daT1.coords)
dsT["tasmax"] = daT1
dsT["lon"] = xr.DataArray(np.transpose(da1.lon.values), dims=["rlon", "rlat"])
dsT["lat"] = xr.DataArray(np.transpose(da1.lat.values), dims=["rlon", "rlat"])
dsT = daT.to_dataset().reset_coords(["lon", "lat"])
out2 = subset.subset_gridpoint(dsT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out2.lon, lon, 1)
np.testing.assert_almost_equal(out2.lat, lat, 1)
Expand Down
Loading