Skip to content

Commit cfaf93d

Browse files
committed
Add WKT/PROJ CRS support via pyproj
CRS can now be specified as WKT strings, PROJ strings, or EPSG integers. pyproj (lazy import) resolves between them.

Read side:
- The crs_wkt attr is populated by resolving EPSG -> WKT via pyproj.
- Falls back gracefully if pyproj is not installed (EPSG still works).

Write side:
- The crs= parameter on write_geotiff accepts an int (EPSG), a WKT string, or a PROJ string. String inputs are resolved to EPSG via pyproj.CRS.from_user_input().to_epsg().
- A DataArray with a crs_wkt attr (and no integer crs) is also handled: the WKT is resolved to EPSG for the GeoKeyDirectory.

This means files with a user-defined CRS no longer lose their spatial reference when round-tripped, as long as pyproj can resolve the WKT/PROJ string to an EPSG code.

5 new tests: WKT from EPSG, write with WKT string, write with PROJ string, crs_wkt attr round-trip, and no-CRS baseline.
1 parent 6601bcf commit cfaf93d

File tree

3 files changed

+125
-5
lines changed

3 files changed

+125
-5
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask']
2424

2525

26+
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
27+
"""Try to extract an EPSG code from a WKT or PROJ string.
28+
29+
Returns None if pyproj is not installed or the string can't be parsed.
30+
"""
31+
try:
32+
from pyproj import CRS
33+
crs = CRS.from_user_input(wkt_or_proj)
34+
epsg = crs.to_epsg()
35+
return epsg
36+
except Exception:
37+
return None
38+
39+
2640
def _geo_to_coords(geo_info, height: int, width: int) -> dict:
2741
"""Build y/x coordinate arrays from GeoInfo.
2842
@@ -132,6 +146,8 @@ def read_geotiff(source: str, *, window=None,
132146
attrs = {}
133147
if geo_info.crs_epsg is not None:
134148
attrs['crs'] = geo_info.crs_epsg
149+
if geo_info.crs_wkt is not None:
150+
attrs['crs_wkt'] = geo_info.crs_wkt
135151
if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
136152
attrs['raster_type'] = 'point'
137153

@@ -214,7 +230,7 @@ def read_geotiff(source: str, *, window=None,
214230

215231

216232
def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
217-
crs: int | None = None,
233+
crs: int | str | None = None,
218234
nodata=None,
219235
compression: str = 'deflate',
220236
tiled: bool = True,
@@ -231,8 +247,10 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
231247
2D raster data.
232248
path : str
233249
Output file path.
234-
crs : int or None
235-
EPSG code. If None and data is a DataArray, tries to read from attrs.
250+
crs : int, str, or None
251+
EPSG code (int), WKT string, or PROJ string. If None and data
252+
is a DataArray, tries to read from attrs ('crs' for EPSG,
253+
'crs_wkt' for WKT).
236254
nodata : float, int, or None
237255
NoData value.
238256
compression : str
@@ -252,18 +270,29 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
252270
'min', 'max', 'median', 'mode', or 'cubic'.
253271
"""
254272
geo_transform = None
255-
epsg = crs
273+
epsg = None
256274
raster_type = RASTER_PIXEL_IS_AREA
257275
x_res = None
258276
y_res = None
259277
res_unit = None
260278

279+
# Resolve crs argument: can be int (EPSG) or str (WKT/PROJ)
280+
if isinstance(crs, int):
281+
epsg = crs
282+
elif isinstance(crs, str):
283+
epsg = _wkt_to_epsg(crs) # try to extract EPSG from WKT/PROJ
284+
261285
if isinstance(data, xr.DataArray):
262286
arr = data.values
263287
if geo_transform is None:
264288
geo_transform = _coords_to_transform(data)
265-
if epsg is None:
289+
if epsg is None and crs is None:
266290
epsg = data.attrs.get('crs')
291+
if epsg is None:
292+
# Try resolving EPSG from a WKT string in attrs
293+
wkt = data.attrs.get('crs_wkt')
294+
if isinstance(wkt, str):
295+
epsg = _wkt_to_epsg(wkt)
267296
if nodata is None:
268297
nodata = data.attrs.get('nodata')
269298
if data.attrs.get('raster_type') == 'point':

xrspatial/geotiff/_geotags.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,24 @@ class GeoInfo:
107107
vertical_datum: int | None = None
108108
vertical_units: str | None = None
109109
vertical_units_code: int | None = None
110+
# WKT CRS string (resolved from EPSG via pyproj, or provided by caller)
111+
crs_wkt: str | None = None
110112
# Raw geokeys dict for anything else
111113
geokeys: dict[int, int | float | str] = field(default_factory=dict)
112114

113115

116+
def _epsg_to_wkt(epsg: int) -> str | None:
117+
"""Resolve an EPSG code to a WKT string using pyproj.
118+
119+
Returns None if pyproj is not installed or the code is unknown.
120+
"""
121+
try:
122+
from pyproj import CRS
123+
return CRS.from_epsg(epsg).to_wkt()
124+
except Exception:
125+
return None
126+
127+
114128
def _parse_geokeys(ifd: IFD, data: bytes | memoryview,
115129
byte_order: str) -> dict[int, int | float | str]:
116130
"""Parse the GeoKeyDirectory and resolve values from param tags.
@@ -385,6 +399,11 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
385399
b = raw_cmap[2 * n_colors + i] / 65535.0
386400
colormap.append((r, g, b, 1.0))
387401

402+
# Resolve EPSG -> WKT via pyproj if available
403+
crs_wkt = None
404+
if epsg is not None:
405+
crs_wkt = _epsg_to_wkt(epsg)
406+
388407
return GeoInfo(
389408
transform=transform,
390409
crs_epsg=epsg,
@@ -410,6 +429,7 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
410429
vertical_datum=vert_datum,
411430
vertical_units=vert_units_name,
412431
vertical_units_code=vert_units_code,
432+
crs_wkt=crs_wkt,
413433
geokeys=geokeys,
414434
)
415435

xrspatial/geotiff/tests/test_features.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,77 @@ def test_angular_unit_lookup(self):
337337
assert LINEAR_UNITS[9002] == 'foot'
338338
assert LINEAR_UNITS[9003] == 'us_survey_foot'
339339

340+
def test_crs_wkt_from_epsg(self, tmp_path):
    """crs_wkt is resolved from EPSG via pyproj."""
    import pytest
    # crs_wkt is only populated when the optional pyproj dependency is
    # installed, so skip rather than fail without it.
    pytest.importorskip('pyproj')
    from xrspatial.geotiff._geotags import GeoTransform

    arr = np.ones((4, 4), dtype=np.float32)
    gt = GeoTransform(-120.0, 45.0, 0.001, -0.001)
    path = str(tmp_path / 'wkt.tif')
    write(arr, path, compression='none', tiled=False,
          geo_transform=gt, crs_epsg=4326)

    da = read_geotiff(path)
    assert 'crs_wkt' in da.attrs
    wkt = da.attrs['crs_wkt']
    # Accept either the CRS name or the authority code in the WKT text.
    assert 'WGS 84' in wkt or '4326' in wkt
353+
354+
def test_write_with_wkt_string(self, tmp_path):
    """crs= accepts a WKT string and resolves to EPSG."""
    import pytest
    # Resolving WKT to an EPSG code needs the optional pyproj dependency.
    pytest.importorskip('pyproj')

    arr = np.ones((4, 4), dtype=np.float32)
    wkt = ('GEOGCRS["WGS 84",DATUM["World Geodetic System 1984",'
           'ELLIPSOID["WGS 84",6378137,298.257223563]],'
           'CS[ellipsoidal,2],'
           'AXIS["geodetic latitude (Lat)",north],'
           'AXIS["geodetic longitude (Lon)",east],'
           'UNIT["degree",0.0174532925199433],'
           'ID["EPSG",4326]]')
    path = str(tmp_path / 'wkt_in.tif')
    write_geotiff(arr, path, crs=wkt, compression='none')

    da = read_geotiff(path)
    assert da.attrs['crs'] == 4326
369+
370+
def test_write_with_proj_string(self, tmp_path):
    """crs= accepts a PROJ string."""
    import pytest
    # Resolving a PROJ string to an EPSG code needs the optional pyproj
    # dependency.
    pytest.importorskip('pyproj')

    arr = np.ones((4, 4), dtype=np.float32)
    path = str(tmp_path / 'proj_in.tif')
    write_geotiff(arr, path, crs='+proj=utm +zone=18 +datum=NAD83',
                  compression='none')

    da = read_geotiff(path)
    # pyproj is expected to resolve this to EPSG:26918; only the presence
    # of some resolved code is asserted here, not the exact value.
    assert da.attrs.get('crs') is not None
380+
381+
def test_crs_wkt_attr_round_trip(self, tmp_path):
    """DataArray with crs_wkt attr (no int crs) round-trips."""
    import pytest
    # Both directions (WKT -> EPSG on write, EPSG -> WKT on read) rely on
    # the optional pyproj dependency.
    pytest.importorskip('pyproj')

    wkt = ('GEOGCRS["WGS 84",DATUM["World Geodetic System 1984",'
           'ELLIPSOID["WGS 84",6378137,298.257223563]],'
           'CS[ellipsoidal,2],'
           'AXIS["geodetic latitude (Lat)",north],'
           'AXIS["geodetic longitude (Lon)",east],'
           'UNIT["degree",0.0174532925199433],'
           'ID["EPSG",4326]]')
    y = np.linspace(45.0, 44.0, 4)
    x = np.linspace(-120.0, -119.0, 4)
    da = xr.DataArray(np.ones((4, 4), dtype=np.float32),
                      dims=['y', 'x'], coords={'y': y, 'x': x},
                      attrs={'crs_wkt': wkt})
    path = str(tmp_path / 'wkt_rt.tif')
    # No crs= argument: the writer must pick the CRS up from attrs.
    write_geotiff(da, path, compression='none')

    result = read_geotiff(path)
    assert result.attrs['crs'] == 4326
    assert 'crs_wkt' in result.attrs
401+
402+
def test_no_crs_no_wkt(self, tmp_path):
    """A file written without any CRS yields no crs_wkt attr on read."""
    out_path = str(tmp_path / 'no_wkt.tif')
    pixels = np.ones((4, 4), dtype=np.float32)
    # Baseline: no CRS arguments are passed to the writer.
    write(pixels, out_path, compression='none', tiled=False)

    loaded = read_geotiff(out_path)
    assert 'crs_wkt' not in loaded.attrs
410+
340411

341412
# -----------------------------------------------------------------------
342413
# Resolution / DPI tags

0 commit comments

Comments
 (0)