Skip to content

Commit 3db3eb2

Browse files
committed
fix: address PR ioos#124 review (subset_vars, utils cleanup, tests)
- Remove deprecated assign_ugrid_topology from utils (use grids.ugrid only) - Raise ValueError from xsg.subset_vars when no grid is recognized - test_utils: single utils import style; stronger mask checks; list polygon test - test_accessor: expect ValueError for subset_vars without grid - test_sgrid: drop zarr>=3 skip for online kerchunk test (per review) Made-with: Cursor
1 parent bda654b commit 3db3eb2

5 files changed

Lines changed: 44 additions & 72 deletions

File tree

tests/test_accessor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@ def test_subset_polygon_and_bbox_return_none_without_grid():
2727
assert ds.xsg.subset_bbox((-72, 39, -70, 41)) is None
2828

2929

30-
def test_subset_vars_passthrough_without_grid():
30+
def test_subset_vars_raises_without_grid():
3131
ds = xr.Dataset({"a": (("x",), [1, 2, 3])})
3232
with pytest.warns(UserWarning, match="no grid type"):
33-
out = ds.xsg.subset_vars(["a"])
34-
# Without a recognized grid, subset_vars returns the dataset unchanged.
35-
assert "a" in out.data_vars
33+
with pytest.raises(ValueError, match="subset_vars requires a recognized grid"):
34+
ds.xsg.subset_vars(["a"])
3635

3736

3837
def test_has_vertical_levels_false_without_grid():

tests/test_grids/test_sgrid.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@
88
from xarray_subset_grid.grids.sgrid import _get_location_info_from_topology
99

1010
# open dataset as zarr object using fsspec reference file system and xarray
11-
zarr__version__ = 0
1211
try:
1312
import fsspec
14-
import zarr
15-
16-
zarr__version__ = int(zarr.__version__.split(".")[0])
1713
except ImportError:
1814
fsspec = None
1915

@@ -52,9 +48,6 @@ def test_grid_topology_location_parse():
5248
}
5349

5450

55-
@pytest.mark.skipif(
56-
zarr__version__ >= 3, reason="zarr3.0.8 doesn't support FSpec AWS (it might soon)"
57-
)
5851
@pytest.mark.online
5952
def test_polygon_subset():
6053
"""

tests/test_utils.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,7 @@
66
import pytest
77
import xarray as xr
88

9-
from tests.conftest import EXAMPLE_DATA
10-
from xarray_subset_grid import utils as xsg_utils
11-
from xarray_subset_grid.utils import (
12-
asdatetime,
13-
compute_2d_subset_mask,
14-
format_bytes,
15-
normalize_bbox_x_coords,
16-
normalize_polygon_x_coords,
17-
ray_tracing_numpy,
18-
)
9+
import xarray_subset_grid.utils as xsg_utils
1910

2011
# normalize_polygon_x_coords tests.
2112

@@ -77,12 +68,7 @@ def get_test_file_dir():
7768
)
7869
def test_normalize_x_coords(lons, poly, norm_poly):
7970
lons = np.array(lons)
80-
normalized_polygon = normalize_polygon_x_coords(lons, np.array(poly))
81-
print(f"{lons=}")
82-
print(f"{poly=}")
83-
print(f"{norm_poly=}")
84-
print(f"{normalized_polygon=}")
85-
71+
normalized_polygon = xsg_utils.normalize_polygon_x_coords(lons, np.array(poly))
8672
assert np.allclose(normalized_polygon, norm_poly)
8773

8874

@@ -105,7 +91,7 @@ def test_normalize_x_coords(lons, poly, norm_poly):
10591
)
10692
def test_normalize_x_coords_bbox(lons, bbox, norm_bbox):
10793
lons = np.array(lons)
108-
normalized_polygon = normalize_bbox_x_coords(lons, bbox)
94+
normalized_polygon = xsg_utils.normalize_bbox_x_coords(lons, bbox)
10995
assert np.allclose(normalized_polygon, norm_bbox)
11096

11197

@@ -130,7 +116,7 @@ def test_ray_tracing_numpy():
130116
]
131117
)
132118

133-
result = ray_tracing_numpy(points[:, 0], points[:, 1], poly)
119+
result = xsg_utils.ray_tracing_numpy(points[:, 0], points[:, 1], poly)
134120

135121
assert np.array_equal(result, [False, True, False])
136122

@@ -144,25 +130,25 @@ def test_ray_tracing_numpy():
144130
],
145131
)
146132
def test_format_bytes(num, unit):
147-
assert unit in format_bytes(num)
133+
assert unit in xsg_utils.format_bytes(num)
148134

149135

150136
def test_asdatetime_none():
151-
assert asdatetime(None) is None
137+
assert xsg_utils.asdatetime(None) is None
152138

153139

154140
def test_asdatetime_datetime_passthrough():
155141
dt = datetime(2020, 6, 15, 12, 30, 0)
156-
assert asdatetime(dt) is dt
142+
assert xsg_utils.asdatetime(dt) is dt
157143

158144

159145
def test_asdatetime_cftime_passthrough():
160146
dt = cftime.datetime(2020, 6, 15, 12)
161-
assert asdatetime(dt) is dt
147+
assert xsg_utils.asdatetime(dt) is dt
162148

163149

164150
def test_asdatetime_parse_string():
165-
dt = asdatetime("2020-06-15T12:30:00")
151+
dt = xsg_utils.asdatetime("2020-06-15T12:30:00")
166152
assert dt.year == 2020 and dt.month == 6 and dt.day == 15
167153

168154

@@ -174,31 +160,38 @@ def test_compute_2d_subset_mask_all_inside():
174160
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
175161
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
176162
poly = np.array([(-75.0, 39.0), (-69.0, 39.0), (-69.0, 45.0), (-75.0, 45.0)])
177-
mask = compute_2d_subset_mask(lat_da, lon_da, poly)
163+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
178164
assert mask.dims == ("y", "x")
179-
assert bool(mask.all())
165+
assert mask.all()
180166

181167

182168
def test_compute_2d_subset_mask_partial():
183-
ny, nx = 7, 7
184-
lat = np.linspace(40.0, 46.0, ny)
185-
lon = np.linspace(-74.0, -68.0, nx)
169+
# Include explicit lon/lat nodes inside the polygon so the mask can be checked at a
170+
# non-boundary grid point (ray-casting is ambiguous on polygon edges).
171+
lat = np.array([40.0, 40.5, 41.0, 43.0, 46.0])
172+
lon = np.array([-74.5, -73.75, -73.0, -71.0, -68.0])
186173
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
187174
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
188175
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
189-
# Small polygon over the south-west corner only
190176
poly = np.array([(-74.5, 40.0), (-73.0, 40.0), (-73.0, 41.0), (-74.5, 41.0)])
191-
mask = compute_2d_subset_mask(lat_da, lon_da, poly)
177+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
192178
assert mask.dims == ("y", "x")
193-
assert bool(mask.any())
194-
assert not bool(mask.all())
195-
196-
197-
def test_assign_ugrid_topology_utils_deprecation_wrapper():
198-
nc = EXAMPLE_DATA / "SFBOFS_subset1.nc"
199-
if not nc.is_file():
200-
pytest.skip("example NetCDF not present")
201-
ds = xr.open_dataset(nc)
202-
with pytest.warns(DeprecationWarning, match="assign_ugrid_topology"):
203-
ds2 = xsg_utils.assign_ugrid_topology(ds, face_node_connectivity="nv")
204-
assert "mesh" in ds2.variables
179+
assert mask.any()
180+
assert not mask.all()
181+
i_inside = int(np.where(lat == 40.5)[0][0])
182+
j_inside = int(np.where(lon == -73.75)[0][0])
183+
assert mask.values[i_inside, j_inside]
184+
assert not mask.values[-1, -1]
185+
186+
187+
def test_compute_2d_subset_mask_list_polygon_coerced():
188+
"""list/tuple polygon vertices are accepted (coerced via normalize_polygon_x_coords)."""
189+
ny, nx = 5, 5
190+
lat = np.linspace(40.0, 44.0, ny)
191+
lon = np.linspace(-74.0, -70.0, nx)
192+
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
193+
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
194+
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
195+
poly = [(-75.0, 39.0), (-69.0, 39.0), (-69.0, 45.0), (-75.0, 45.0)]
196+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
197+
assert mask.all()

xarray_subset_grid/accessor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,12 @@ def subset_vars(self, vars: list[str]) -> xr.Dataset:
112112
:param vars: The variables to keep
113113
:return: The subsetted dataset
114114
"""
115-
if self._grid:
116-
return self._grid.subset_vars(self._ds, vars)
117-
return self._ds
115+
if not self._grid:
116+
raise ValueError(
117+
"subset_vars requires a recognized grid; this dataset has no grid type "
118+
"that xarray-subset-grid can use."
119+
)
120+
return self._grid.subset_vars(self._ds, vars)
118121

119122
@property
120123
def has_vertical_levels(self) -> bool:

xarray_subset_grid/utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from datetime import datetime
32

43
import cf_xarray # noqa
@@ -95,21 +94,6 @@ def ray_tracing_numpy(x, y, poly):
9594
return inside
9695

9796

98-
# This is defined in ugrid.py
99-
# this placeholder for backwards compatibility for a brief period
100-
def assign_ugrid_topology(*args, **kwargs):
101-
warnings.warn(
102-
"The function `assign_ugrid_topology` has been moved to the "
103-
"`grids.ugrid` module. It will not be able to be called from "
104-
"the utils `module` in the future.",
105-
DeprecationWarning,
106-
stacklevel=2,
107-
)
108-
from .grids.ugrid import assign_ugrid_topology
109-
110-
return assign_ugrid_topology(*args, **kwargs)
111-
112-
11397
def format_bytes(num):
11498
"""This function will convert bytes to MB....
11599

0 commit comments

Comments
 (0)