Skip to content

Commit dd89755

Browse files
fix: address review comments from PR #124 (utils cleanup, accessor behavior, tests) (#130)
* test: cover utils helpers and xsg accessor; fix deprecations - Add tests for format_bytes, asdatetime, compute_2d_subset_mask, and utils.assign_ugrid_topology forwarding with DeprecationWarning - Add accessor tests for unknown-grid warning and None subset behavior - Fix warnings.warn argument order in assign_ugrid_topology; correct message - Coerce polygon coords via np.asarray in normalize_polygon_x_coords - Define zarr__version__ when fsspec/zarr import fails (test collection) Made-with: Cursor * fix: address PR #124 review (subset_vars, utils cleanup, tests) - Remove deprecated assign_ugrid_topology from utils (use grids.ugrid only) - Raise ValueError from xsg.subset_vars when no grid is recognized - test_utils: single utils import style; stronger mask checks; list polygon test - test_accessor: expect ValueError for subset_vars without grid - test_sgrid: drop zarr>=3 skip for online kerchunk test (per review) Made-with: Cursor
1 parent 3150e3f commit dd89755

5 files changed

Lines changed: 132 additions & 38 deletions

File tree

tests/test_accessor.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
5+
import xarray_subset_grid.accessor # noqa: F401 -- register accessor
6+
7+
8+
def test_accessor_warns_when_no_grid_recognized():
9+
ds = xr.Dataset()
10+
with pytest.warns(UserWarning, match="no grid type"):
11+
accessor = ds.xsg
12+
assert accessor.grid is None
13+
14+
15+
def test_subset_polygon_and_bbox_return_none_without_grid():
16+
ds = xr.Dataset()
17+
poly = np.array(
18+
[
19+
[-72.0, 41.0],
20+
[-70.0, 41.0],
21+
[-71.0, 39.0],
22+
[-72.0, 41.0],
23+
]
24+
)
25+
with pytest.warns(UserWarning, match="no grid type"):
26+
assert ds.xsg.subset_polygon(poly) is None
27+
assert ds.xsg.subset_bbox((-72, 39, -70, 41)) is None
28+
29+
30+
def test_subset_vars_raises_without_grid():
31+
ds = xr.Dataset({"a": (("x",), [1, 2, 3])})
32+
with pytest.warns(UserWarning, match="no grid type"):
33+
with pytest.raises(ValueError, match="subset_vars requires a recognized grid"):
34+
ds.xsg.subset_vars(["a"])
35+
36+
37+
def test_has_vertical_levels_false_without_grid():
38+
ds = xr.Dataset()
39+
with pytest.warns(UserWarning, match="no grid type"):
40+
assert ds.xsg.has_vertical_levels is False

tests/test_grids/test_sgrid.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
# open dataset as zarr object using fsspec reference file system and xarray
1111
try:
1212
import fsspec
13-
import zarr
14-
15-
zarr__version__ = int(zarr.__version__.split(".")[0])
1613
except ImportError:
1714
fsspec = None
1815

@@ -51,9 +48,6 @@ def test_grid_topology_location_parse():
5148
}
5249

5350

54-
@pytest.mark.skipif(
55-
zarr__version__ >= 3, reason="zarr3.0.8 doesn't support FSpec AWS (it might soon)"
56-
)
5751
@pytest.mark.online
5852
def test_polygon_subset():
5953
"""

tests/test_utils.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import os
2+
from datetime import datetime
23

4+
import cftime
35
import numpy as np
46
import pytest
7+
import xarray as xr
58

6-
from xarray_subset_grid.utils import (
7-
normalize_bbox_x_coords,
8-
normalize_polygon_x_coords,
9-
ray_tracing_numpy,
10-
)
9+
import xarray_subset_grid.utils as xsg_utils
1110

1211
# normalize_polygon_x_coords tests.
1312

@@ -69,12 +68,7 @@ def get_test_file_dir():
6968
)
7069
def test_normalize_x_coords(lons, poly, norm_poly):
7170
lons = np.array(lons)
72-
normalized_polygon = normalize_polygon_x_coords(lons, np.array(poly))
73-
print(f"{lons=}")
74-
print(f"{poly=}")
75-
print(f"{norm_poly=}")
76-
print(f"{normalized_polygon=}")
77-
71+
normalized_polygon = xsg_utils.normalize_polygon_x_coords(lons, np.array(poly))
7872
assert np.allclose(normalized_polygon, norm_poly)
7973

8074

@@ -97,7 +91,7 @@ def test_normalize_x_coords(lons, poly, norm_poly):
9791
)
9892
def test_normalize_x_coords_bbox(lons, bbox, norm_bbox):
9993
lons = np.array(lons)
100-
normalized_polygon = normalize_bbox_x_coords(lons, bbox)
94+
normalized_polygon = xsg_utils.normalize_bbox_x_coords(lons, bbox)
10195
assert np.allclose(normalized_polygon, norm_bbox)
10296

10397

@@ -122,6 +116,82 @@ def test_ray_tracing_numpy():
122116
]
123117
)
124118

125-
result = ray_tracing_numpy(points[:, 0], points[:, 1], poly)
119+
result = xsg_utils.ray_tracing_numpy(points[:, 0], points[:, 1], poly)
126120

127121
assert np.array_equal(result, [False, True, False])
122+
123+
124+
@pytest.mark.parametrize(
125+
"num, unit",
126+
[
127+
(512, "bytes"),
128+
(2048, "KB"),
129+
(3 * 1024**2, "MB"),
130+
],
131+
)
132+
def test_format_bytes(num, unit):
133+
assert unit in xsg_utils.format_bytes(num)
134+
135+
136+
def test_asdatetime_none():
137+
assert xsg_utils.asdatetime(None) is None
138+
139+
140+
def test_asdatetime_datetime_passthrough():
141+
dt = datetime(2020, 6, 15, 12, 30, 0)
142+
assert xsg_utils.asdatetime(dt) is dt
143+
144+
145+
def test_asdatetime_cftime_passthrough():
146+
dt = cftime.datetime(2020, 6, 15, 12)
147+
assert xsg_utils.asdatetime(dt) is dt
148+
149+
150+
def test_asdatetime_parse_string():
151+
dt = xsg_utils.asdatetime("2020-06-15T12:30:00")
152+
assert dt.year == 2020 and dt.month == 6 and dt.day == 15
153+
154+
155+
def test_compute_2d_subset_mask_all_inside():
156+
ny, nx = 5, 5
157+
lat = np.linspace(40.0, 44.0, ny)
158+
lon = np.linspace(-74.0, -70.0, nx)
159+
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
160+
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
161+
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
162+
poly = np.array([(-75.0, 39.0), (-69.0, 39.0), (-69.0, 45.0), (-75.0, 45.0)])
163+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
164+
assert mask.dims == ("y", "x")
165+
assert mask.all()
166+
167+
168+
def test_compute_2d_subset_mask_partial():
169+
# Include explicit lon/lat nodes inside the polygon so the mask can be checked at a
170+
# non-boundary grid point (ray-casting is ambiguous on polygon edges).
171+
lat = np.array([40.0, 40.5, 41.0, 43.0, 46.0])
172+
lon = np.array([-74.5, -73.75, -73.0, -71.0, -68.0])
173+
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
174+
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
175+
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
176+
poly = np.array([(-74.5, 40.0), (-73.0, 40.0), (-73.0, 41.0), (-74.5, 41.0)])
177+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
178+
assert mask.dims == ("y", "x")
179+
assert mask.any()
180+
assert not mask.all()
181+
i_inside = int(np.where(lat == 40.5)[0][0])
182+
j_inside = int(np.where(lon == -73.75)[0][0])
183+
assert mask.values[i_inside, j_inside]
184+
assert not mask.values[-1, -1]
185+
186+
187+
def test_compute_2d_subset_mask_list_polygon_coerced():
188+
"""list/tuple polygon vertices are accepted (coerced via normalize_polygon_x_coords)."""
189+
ny, nx = 5, 5
190+
lat = np.linspace(40.0, 44.0, ny)
191+
lon = np.linspace(-74.0, -70.0, nx)
192+
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
193+
lat_da = xr.DataArray(lat2d, dims=("y", "x"))
194+
lon_da = xr.DataArray(lon2d, dims=("y", "x"))
195+
poly = [(-75.0, 39.0), (-69.0, 39.0), (-69.0, 45.0), (-75.0, 45.0)]
196+
mask = xsg_utils.compute_2d_subset_mask(lat_da, lon_da, poly)
197+
assert mask.all()

xarray_subset_grid/accessor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,12 @@ def subset_vars(self, vars: list[str]) -> xr.Dataset:
112112
:param vars: The variables to keep
113113
:return: The subsetted dataset
114114
"""
115-
if self._grid:
116-
return self._grid.subset_vars(self._ds, vars)
117-
return self._ds
115+
if not self._grid:
116+
raise ValueError(
117+
"subset_vars requires a recognized grid; this dataset has no grid type "
118+
"that xarray-subset-grid can use."
119+
)
120+
return self._grid.subset_vars(self._ds, vars)
118121

119122
@property
120123
def has_vertical_levels(self) -> bool:

xarray_subset_grid/utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from datetime import datetime
32

43
import cf_xarray # noqa
@@ -19,12 +18,14 @@ def normalize_polygon_x_coords(x, poly):
1918
If the x coords are between 0 and 180 (i.e. both will work), the polygon
2019
is not changed.
2120
22-
NOTE: polygon is normalized in place!
21+
NOTE: ``poly`` is normalized in place when it is already an ndarray;
22+
a copy is made when ``poly`` is a sequence.
2323
2424
Args:
2525
x (np.array): x-coordinates of the vertices
2626
poly (np.array): polygon vertices
2727
"""
28+
poly = np.asarray(poly)
2829
x_min, x_max = x.min(), x.max()
2930

3031
poly_x = poly[:, 0]
@@ -93,20 +94,6 @@ def ray_tracing_numpy(x, y, poly):
9394
return inside
9495

9596

96-
# This is defined in ugrid.py
97-
# this placeholder for backwards compatibility for a brief period
98-
def assign_ugrid_topology(*args, **kwargs):
99-
warnings.warn(
100-
DeprecationWarning,
101-
"The function `assign_grid_topology` has been moved to the "
102-
"`grids.ugrid` module. It will not be able to be called from "
103-
"the utils `module` in the future.",
104-
)
105-
from .grids.ugrid import assign_ugrid_topology
106-
107-
return assign_ugrid_topology(*args, **kwargs)
108-
109-
11097
def format_bytes(num):
11198
"""This function will convert bytes to MB....
11299

0 commit comments

Comments
 (0)