Skip to content

Commit a6ec8d8

Browse files
committed
Add tests for .xrs.to_geotiff and .xrs.open_geotiff accessors (#1047)
1 parent 19f55d8 commit a6ec8d8

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Tests for .xrs.to_geotiff() and .xrs.open_geotiff() accessor methods."""
2+
from __future__ import annotations
3+
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
8+
import xrspatial # noqa: F401 -- registers .xrs accessor
9+
from xrspatial.geotiff import open_geotiff, to_geotiff
10+
11+
12+
# ---------------------------------------------------------------------------
13+
# Helpers
14+
# ---------------------------------------------------------------------------
15+
16+
def _make_da(height=8, width=10, crs=4326, name='elevation'):
17+
"""Build a georeferenced DataArray for testing."""
18+
arr = np.arange(height * width, dtype=np.float32).reshape(height, width)
19+
y = np.linspace(45.0, 44.0, height)
20+
x = np.linspace(-120.0, -119.0, width)
21+
return xr.DataArray(
22+
arr, dims=['y', 'x'],
23+
coords={'y': y, 'x': x},
24+
name=name,
25+
attrs={'crs': crs},
26+
)
27+
28+
29+
def _make_ds(height=8, width=10, crs=4326):
30+
"""Build a georeferenced Dataset for testing."""
31+
da = _make_da(height, width, crs, name='elevation')
32+
return xr.Dataset({'elevation': da})
33+
34+
35+
# ---------------------------------------------------------------------------
36+
# DataArray.xrs.to_geotiff
37+
# ---------------------------------------------------------------------------
38+
39+
class TestDataArrayToGeotiff:
40+
def test_round_trip(self, tmp_path):
41+
da = _make_da()
42+
path = str(tmp_path / 'test_1047_da_roundtrip.tif')
43+
da.xrs.to_geotiff(path, compression='none')
44+
45+
result = open_geotiff(path)
46+
np.testing.assert_array_equal(result.values, da.values)
47+
48+
def test_with_kwargs(self, tmp_path):
49+
da = _make_da()
50+
path = str(tmp_path / 'test_1047_da_kwargs.tif')
51+
da.xrs.to_geotiff(path, compression='deflate', tiled=True,
52+
tile_size=256)
53+
54+
result = open_geotiff(path)
55+
np.testing.assert_array_equal(result.values, da.values)
56+
57+
def test_preserves_crs(self, tmp_path):
58+
da = _make_da(crs=32610)
59+
path = str(tmp_path / 'test_1047_da_crs.tif')
60+
da.xrs.to_geotiff(path, compression='none')
61+
62+
result = open_geotiff(path)
63+
assert result.attrs.get('crs') == 32610
64+
65+
66+
# ---------------------------------------------------------------------------
67+
# Dataset.xrs.to_geotiff
68+
# ---------------------------------------------------------------------------
69+
70+
class TestDatasetToGeotiff:
71+
def test_round_trip(self, tmp_path):
72+
ds = _make_ds()
73+
path = str(tmp_path / 'test_1047_ds_roundtrip.tif')
74+
ds.xrs.to_geotiff(path, compression='none')
75+
76+
result = open_geotiff(path)
77+
np.testing.assert_array_equal(result.values, ds['elevation'].values)
78+
79+
def test_explicit_var(self, tmp_path):
80+
ds = _make_ds()
81+
ds['slope'] = ds['elevation'] * 2
82+
path = str(tmp_path / 'test_1047_ds_var.tif')
83+
ds.xrs.to_geotiff(path, var='slope', compression='none')
84+
85+
result = open_geotiff(path)
86+
np.testing.assert_array_equal(result.values, ds['slope'].values)
87+
88+
def test_no_yx_raises(self, tmp_path):
89+
ds = xr.Dataset({'vals': xr.DataArray(np.zeros(5), dims=['z'])})
90+
with pytest.raises(ValueError, match="no variable with 'y' and 'x'"):
91+
ds.xrs.to_geotiff(str(tmp_path / 'bad.tif'))
92+
93+
94+
# ---------------------------------------------------------------------------
95+
# Dataset.xrs.open_geotiff (spatially-windowed read)
96+
# ---------------------------------------------------------------------------
97+
98+
class TestDatasetOpenGeotiff:
99+
def test_windowed_read(self, tmp_path):
100+
"""Reading with a Dataset template should return a spatial subset."""
101+
# Write a 20x20 raster
102+
big = _make_da(height=20, width=20)
103+
big_path = str(tmp_path / 'test_1047_big.tif')
104+
to_geotiff(big, big_path, compression='none')
105+
106+
# Template dataset covers the center region
107+
y_sub = big.coords['y'].values[5:15]
108+
x_sub = big.coords['x'].values[5:15]
109+
template = xr.Dataset({
110+
'dummy': xr.DataArray(
111+
np.zeros((len(y_sub), len(x_sub))),
112+
dims=['y', 'x'],
113+
coords={'y': y_sub, 'x': x_sub},
114+
)
115+
})
116+
117+
result = template.xrs.open_geotiff(big_path)
118+
# Result should be smaller than the full raster
119+
assert result.shape[0] <= 20
120+
assert result.shape[1] <= 20
121+
# And at least as large as the template
122+
assert result.shape[0] >= len(y_sub)
123+
assert result.shape[1] >= len(x_sub)
124+
125+
def test_full_extent_returns_all(self, tmp_path):
126+
"""Template covering full extent should return the whole raster."""
127+
da = _make_da(height=8, width=10)
128+
path = str(tmp_path / 'test_1047_full.tif')
129+
to_geotiff(da, path, compression='none')
130+
131+
template = xr.Dataset({
132+
'dummy': xr.DataArray(
133+
np.zeros_like(da.values),
134+
dims=['y', 'x'],
135+
coords={'y': da.coords['y'].values,
136+
'x': da.coords['x'].values},
137+
)
138+
})
139+
result = template.xrs.open_geotiff(path)
140+
np.testing.assert_array_equal(result.values, da.values)
141+
142+
def test_no_coords_raises(self, tmp_path):
143+
da = _make_da()
144+
path = str(tmp_path / 'test_1047_nocoords.tif')
145+
to_geotiff(da, path, compression='none')
146+
147+
ds = xr.Dataset({'vals': xr.DataArray(np.zeros(5), dims=['z'])})
148+
with pytest.raises(ValueError, match="'y' and 'x' coordinates"):
149+
ds.xrs.open_geotiff(path)
150+
151+
def test_kwargs_forwarded(self, tmp_path):
152+
"""Extra kwargs like name= should be forwarded to open_geotiff."""
153+
da = _make_da(height=8, width=10)
154+
path = str(tmp_path / 'test_1047_kwargs.tif')
155+
to_geotiff(da, path, compression='none')
156+
157+
template = xr.Dataset({
158+
'dummy': xr.DataArray(
159+
np.zeros_like(da.values),
160+
dims=['y', 'x'],
161+
coords={'y': da.coords['y'].values,
162+
'x': da.coords['x'].values},
163+
)
164+
})
165+
result = template.xrs.open_geotiff(path, name='myname')
166+
assert result.name == 'myname'

0 commit comments

Comments
 (0)