Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ v0.17.0 (2025-12-16)

New Features
^^^^^^^^^^^^
* Added support for a data mask in `subset_gridpoint`, where the subsetted grid points must be within the mask (True) (#493).
* Added a `method` argument to `subset_gridpoint` allowing a choice between true world distance (`distance`) and nearest neighbour based on lat, lon (`geographic`) for subsetting both regular and irregular grids (#493).
* Added an `engine` argument to `Grid.ds.to_netcdf()` to allow users to specify the engine used for writing NetCDF files (#439).
* Coding conventions have been updated to use Python 3.10+ features (#439).
* `Weights` will now use `post_mask_source='domain_edge'` introduced in `xesmf` version 0.9 when remapping a regional grid via nearest-neighbour to avoid extrapolation beyond the source domain (#447).
Expand All @@ -25,6 +27,7 @@ Bug Fixes

Breaking Changes
^^^^^^^^^^^^^^^^
* Default method for `subset_gridpoint` using regular lat,lon grids is now `distance` instead of previously employing the equivalent of the new `geographic` method (#493).
* Support for Python 3.10 has been dropped. `numpy >=1.26` is the new minimum supported version (#469).
* `Grid.detect_extent()` now returns a tuple `(lon_extent, lat_extent)` instead of only `lon_extent` (#447).
* `Grid.extent` now represents the combined lon/lat extent: `"global"` if both are global; otherwise `"regional"`. The new `Grid.extent_lon` and `Grid.extent_lat` attributes provide axis-specific extent information (#447).
Expand Down
137 changes: 109 additions & 28 deletions clisops/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyproj import Geod
from pyproj.crs import CRS
from pyproj.exceptions import CRSError
from scipy.spatial import KDTree
from shapely.geometry import LineString, MultiPolygon, Point, Polygon
from shapely.ops import split, unary_union
from xarray.core import indexing
Expand Down Expand Up @@ -1515,12 +1516,14 @@ def subset_gridpoint(
da: xarray.DataArray | xarray.Dataset,
lon: float | Sequence[float] | xarray.DataArray | None = None,
lat: float | Sequence[float] | xarray.DataArray | None = None,
method: str | None = "distance",
start_date: str | None = None,
end_date: str | None = None,
first_level: float | int | None = None,
last_level: float | int | None = None,
tolerance: float | None = None,
add_distance: bool = False,
mask: np.ndarray | xarray.DataArray | None = None,
) -> xarray.DataArray | xarray.Dataset:
"""
Extract the nearest gridpoint(s) from a DataArray or Dataset based on lat/lon coordinate(s).
Expand All @@ -1538,6 +1541,10 @@ def subset_gridpoint(
Longitude coordinate(s). Must be of the same length as lat.
lat : float, Sequence[float], xarray.DataArray, optional
Latitude coordinate(s). Must be of the same length as lon.
method : str, optional
Method to use for finding the nearest grid point. Options are "distance" (default) and "geographic";
"geographic" uses longitude and latitude coordinates directly while "distance" calculates distance
on the Earth's surface.
start_date : str, optional
Start date of the subset.
Date string format -- can be year ("%Y"), year-month ("%Y-%m") or year-month-day("%Y-%m-%d").
Expand All @@ -1558,6 +1565,9 @@ def subset_gridpoint(
Masks values if the distance to the nearest gridpoint is larger than tolerance in meters.
add_distance : bool
Whether to add a new coordinate "distance" to the output DataArray or Dataset.
mask : np.ndarray or xarray.DataArray, optional
2d boolean array with the same spatial dimensions as da, where True values indicate valid
grid points to be considered for subsetting.

Returns
-------
Expand All @@ -1580,47 +1590,105 @@ def subset_gridpoint(
ds = xr.open_mfdataset([path_to_tasmax_file, path_to_tasmin_file])
dsSub = subset_gridpoint(ds, lon=-75, lat=45)
"""
if lat is None or lon is None:
raise ValueError("Insufficient coordinates provided to locate grid point(s).")

ptdim = lat.dims[0]

lon_name = lon.name or "lon"
lat_name = lat.name or "lat"
def _subset_gridpoint_mask(
da: xarray.DataArray | xarray.Dataset,
lon: float | Sequence[float] | xarray.DataArray | None = None,
lat: float | Sequence[float] | xarray.DataArray | None = None,
mask: np.ndarray | xarray.DataArray | None = None,
ptdim: str | None = None,
) -> xarray.DataArray | xarray.Dataset:

# make sure input data has 'lon' and 'lat'(dims, coordinates, or data_vars)
if hasattr(da, lon_name) and hasattr(da, lat_name):
dims = list(da.dims)

dists = None
lon_name = lon.name or "lon"
lat_name = lat.name or "lat"
# if 'lon' and 'lat' are present as data dimensions use the .sel method.
if lat_name in dims and lon_name in dims:
dims_flag = lat_name in dims and lon_name in dims
if method == "geographic" and mask is None and dims_flag:
da = da.sel(lat=lat, lon=lon, method="nearest")

if tolerance is not None or add_distance:
# Calculate the geodesic distance between grid points and the point of interest.
dist = distance(da, lon=lon, lat=lat)
if add_distance or tolerance is not None:
dists = distance(da, lon=lon, lat=lat)
elif method == "geographic" and (mask is not None or not dims_flag):
# 1. Apply mask and extract coordinates of valid points
# Create a 2D array of (lon, lat) pairs from the valid data
if len(da[lat_name].dims) == 1 and len(da[lon_name].dims) == 1:
lon_grid, lat_grid = xarray.broadcast(da[lon_name], da[lat_name])
elif len(da[lat_name].dims) == 2 and len(da[lon_name].dims) == 2:
lon_grid = da[lon_name]
lat_grid = da[lat_name]
else:
dist = None
raise ValueError("Latitude and longitude coordinates must be either 1D or 2D arrays.")

else:
# Calculate the geodesic distance between grid points and the point of interest.
dist = distance(da, lon=lon, lat=lat)
pts = []
dists = []
dim0, dim1 = lat_grid.dims
if mask is not None:
lon_grid = lon_grid.where(mask)
lat_grid = lat_grid.where(mask)

# Flatten and remove NaNs (points that didn't pass the mask)
v_lons = lon_grid.values.ravel()
v_lats = lat_grid.values.ravel()
valid_idx = ~np.isnan(v_lons)

coords_pool = np.column_stack((v_lons[valid_idx], v_lats[valid_idx]))

# 2. Build the KD Tree
tree = KDTree(coords_pool)
Comment thread
tlogan2000 marked this conversation as resolved.

# 3. Batch Query for all sites
# target_lons/lats should be lists or 1D arrays of equal length
targets = np.column_stack((lon, lat))
distances, indices = tree.query(targets)

# 4. Extract the closest valid coordinates
nearest_coords = coords_pool[indices]

# 5. Convert result to DataArrays for xarray indexing
nearest_lons = nearest_coords[:, 0]
nearest_lats = nearest_coords[:, 1]

# 6. subset ds
idx = np.where((np.isin(lon_grid, nearest_lons)) & (np.isin(lat_grid, nearest_lats)))

da = da.isel({dim0: xarray.DataArray(idx[0], dims=ptdim), dim1: xarray.DataArray(idx[1], dims=ptdim)})

if add_distance is not None or tolerance is not None:
# Calculate the geodesic distance between grid points and the point of interest.
dists = distance(da, lon=lon, lat=lat, mask=mask)
elif method == "distance":
dist = distance(da, lon=lon, lat=lat, mask=mask)
args = {xydim: [] for xydim in dist.dims if xydim != ptdim}
for site in dist[ptdim]:
# Find the indices for the closest point
distances = dist.sel({ptdim: site})
inds = np.unravel_index(np.nanargmin(distances), distances.shape)
for xydim, ind in zip(dist.dims, inds, strict=False):
args[xydim].append(ind)
for xydim in args.keys():
args[xydim] = xarray.DataArray(args[xydim], dims=ptdim)
da = da.isel(**args)
if add_distance or tolerance is not None:
dists = dist.isel(**args)
else:
raise ValueError(f"Method {method} not recognized. Use 'geographic' or 'distance'.")

return da, dists

if lat is None or lon is None:
raise ValueError("Insufficient coordinates provided to locate grid point(s).")

ptdim = lat.dims[0]

lon_name = lon.name or "lon"
lat_name = lat.name or "lat"

# make sure input data has 'lon' and 'lat'(dims, coordinates, or data_vars)
if hasattr(da, lon_name) and hasattr(da, lat_name):
da, dist = _subset_gridpoint_mask(da=da, lat=lat, lon=lon, mask=mask, ptdim=ptdim)

# Select data from closest point
args = {xydim: ind for xydim, ind in zip(dist.dims, inds, strict=False)}
pts.append(da.isel(**args))
dists.append(dist.isel(**args))
da = xarray.concat(pts, dim=ptdim)
dist = xarray.concat(dists, dim=ptdim)
else:
raise (
Exception(
ValueError(
f'{subset_gridpoint.__name__} requires input data with "lon" and "lat" coordinates or data variables.'
)
)
Expand Down Expand Up @@ -1905,6 +1973,7 @@ def distance(
*,
lon: float | Sequence[float] | xarray.DataArray,
lat: float | Sequence[float] | xarray.DataArray,
mask: np.ndarray | xarray.DataArray | None = None,
) -> xarray.DataArray | xarray.Dataset:
"""
Return distance to a point in meters.
Expand All @@ -1917,6 +1986,9 @@ def distance(
Longitude coordinate.
lat : float, sequence of floats, or xarray.DataArray
Latitude coordinate.
mask : np.ndarray or xarray.DataArray, optional
2d boolean array with the same spatial dimensions as da,
where True values indicate valid grid points to be considered for distance calculation. Optional.

Returns
-------
Expand All @@ -1943,9 +2015,18 @@ def distance(
def _func(lons, lats, lon, lat):
return g.inv(lons, lats, lon, lat)[2]

if len(da.lon.dims) == 1 and len(da.lat.dims) == 1:
lon_grid, lat_grid = xarray.broadcast(da.lon, da.lat)
else:
lon_grid = da.lon
lat_grid = da.lat
if mask is not None:
lon_grid = lon_grid.where(mask)
lat_grid = lat_grid.where(mask)

out = xarray.apply_ufunc(
_func,
*xarray.broadcast(da.lon.load(), da.lat.load(), lon, lat),
*xarray.broadcast(lon_grid.load(), lat_grid.load(), lon, lat),
input_core_dims=[[ptdim]] * 4,
output_core_dims=[[ptdim]],
)
Expand Down
75 changes: 62 additions & 13 deletions tests/test_core_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,50 +158,54 @@ class TestSubsetGridPoint:
nc_tasmin_file = "NRCANdaily/nrcan_canada_daily_tasmin_1990.nc"
nc_2dlonlat = "CRCM5/tasmax_bby_198406_se.nc"

def test_time_simple(self, nimbus):
@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_time_simple(self, method, nimbus):
da = xr.open_dataset(nimbus.fetch(self.nc_poslons)).tas
da = da.assign_coords(lon=(da.lon - 360))
lon = -72.4
lat = 46.1
yr_st = "2050"
yr_ed = "2059"

out = subset.subset_gridpoint(da, lon=lon, lat=lat, start_date=yr_st, end_date=yr_ed)
out = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method, start_date=yr_st, end_date=yr_ed)
np.testing.assert_almost_equal(out.lon, lon, 1)
np.testing.assert_almost_equal(out.lat, lat, 1)
np.testing.assert_array_equal(len(np.unique(out.time.dt.year)), 10)
np.testing.assert_array_equal(out.time.dt.year.max(), int(yr_ed))
np.testing.assert_array_equal(out.time.dt.year.min(), int(yr_st))

def test_dataset(self, nimbus):
@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_dataset(self, method, nimbus):
da = xr.open_mfdataset(
[nimbus.fetch(self.nc_tasmax_file), nimbus.fetch(self.nc_tasmin_file)],
combine="by_coords",
compat="no_conflicts",
)
lon = -72.4
lat = 46.1
out = subset.subset_gridpoint(da, lon=lon, lat=lat)
out = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method)
np.testing.assert_almost_equal(out.lon, lon, 1)
np.testing.assert_almost_equal(out.lat, lat, 1)
np.testing.assert_array_equal(out.tasmin.shape, out.tasmax.shape)

@pytest.mark.parametrize("lon,lat", [([-72.4], [46.1]), ([-67.4, -67.3], [43.1, 46.1])])
@pytest.mark.parametrize("add_distance", [True, False])
def test_simple(self, lat, lon, add_distance, nimbus):
@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_simple(self, lat, lon, add_distance, method, nimbus):
da = xr.open_dataset(nimbus.fetch(self.nc_tasmax_file)).tasmax

out = subset.subset_gridpoint(da, lon=lon, lat=lat, add_distance=add_distance)
out = subset.subset_gridpoint(da, lon=lon, lat=lat, add_distance=add_distance, method=method)
np.testing.assert_almost_equal(out.lon, lon, 1)
np.testing.assert_almost_equal(out.lat, lat, 1)

assert ("site" in out.dims) ^ (len(lat) == 1)
assert ("distance" in out.coords) ^ (not add_distance)

def test_irregular(self, nimbus):
@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_irregular(self, method, nimbus):
da = xr.open_dataset(nimbus.fetch(self.nc_2dlonlat)).tasmax
lon = -72.4
lat = 46.1
out = subset.subset_gridpoint(da, lon=lon, lat=lat)
out = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method)
np.testing.assert_almost_equal(out.lon, lon, 1)
np.testing.assert_almost_equal(out.lat, lat, 1)
assert "site" not in out.dims
Expand Down Expand Up @@ -235,7 +239,7 @@ def test_irregular(self, nimbus):
out1 = subset.subset_gridpoint(daT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out1.lon, lon, 1)
np.testing.assert_almost_equal(out1.lat, lat, 1)
np.testing.assert_array_equal(out, out1)
np.testing.assert_array_equal(out, out1.transpose(*out.dims))

# Dataset with tasmax, lon and lat as data variables (i.e. lon, lat not coords of tasmax)
daT1 = xr.DataArray(np.transpose(da1.values), dims=dims)
Expand All @@ -250,11 +254,11 @@ def test_irregular(self, nimbus):
out2 = subset.subset_gridpoint(dsT, lon=lon, lat=lat)
np.testing.assert_almost_equal(out2.lon, lon, 1)
np.testing.assert_almost_equal(out2.lat, lat, 1)
np.testing.assert_array_equal(out, out2.tasmax)
np.testing.assert_array_equal(out, out2.tasmax.transpose(*out.dims))

# Dataset with lon and lat as 1D arrays
lon = -60
lat = -45
lat = 45
da = xr.DataArray(
np.random.rand(5, 4),
dims=("time", "site"),
Expand All @@ -272,6 +276,50 @@ def test_irregular(self, nimbus):
np.testing.assert_almost_equal(gp.lat, lat)
assert gp.site == 0

# extracting two points close together should give a duplicate point in the output
# extract the same grid cell for two 'sites'
lon = [-60, -59]
lat = [-45, 45]
gp = subset.subset_gridpoint(ds, lon=lon, lat=lat)
# 'site' dim already in input da so output has '_site' dim
np.testing.assert_array_equal(gp.da.isel(_site=0), gp.da.isel(_site=1))
assert len(gp._site) == 2 and len(gp.lon) == 2 and len(gp.lat) == 2
assert len(np.unique(gp._site)) == 2 and len(np.unique(gp.lon)) == 1 and len(np.unique(gp.lat)) == 1

@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_masked(self, method, nimbus):
da = xr.open_dataset(nimbus.fetch(self.nc_tasmax_file)).tasmax
# mask where there is valid data
mask = ~np.isnan(da.isel(time=0))
# lat lon close to coastline where there are masked gridcells in the dataset
# Halifax harbor
lon = -63.48131910815178
lat = 44.56206467361616
out = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method)
out_mask = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method, mask=mask)
assert out.isnull().all()
assert out_mask.notnull().all()
np.testing.assert_almost_equal(out_mask.mean(), 284.91546631)

@pytest.mark.parametrize("method", ["distance", "geographic"])
def test_masked_irregular(self, method, nimbus, clisops_test_data):
da = xr.open_dataset(nimbus.fetch(self.nc_2dlonlat)).tasmax
regions = gpd.read_file(clisops_test_data["multi_regions_geojson"])
reg_mask = subset.create_mask(x_dim=da.lon, y_dim=da.lat, poly=regions)
da = da.where(reg_mask == 0) # Quebec only

# mask where there is valid data
mask = ~np.isnan(da.isel(time=0))
# lat lon close to coastline where there are masked gridcells in the dataset
# Halifax harbor
lon = -63.009725545080705
lat = 48.25160814508184
out = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method)
out_mask = subset.subset_gridpoint(da, lon=lon, lat=lat, method=method, mask=mask)
assert out.isnull().all()
assert out_mask.notnull().all()
np.testing.assert_almost_equal(out_mask.mean(), 285.34141642)

def test_positive_lons(self, nimbus):
da = xr.open_dataset(nimbus.fetch(self.nc_poslons)).tas
lon = -72.4
Expand Down Expand Up @@ -300,7 +348,8 @@ def test_tolerance(self, nimbus):
out = subset.subset_gridpoint(da, lon=lon, lat=lat, tolerance=1)
assert out.isnull().all()

subset.subset_gridpoint(da, lon=lon, lat=lat, tolerance=1e5)
out = subset.subset_gridpoint(da, lon=lon, lat=lat, tolerance=1e5)
assert out.notnull().all()


class TestSubsetBbox:
Expand Down