Skip to content

Commit 9cf43ab

Browse files
committed
Move palette plot to da.xrs.plot() accessor
The .xrs accessor (registered on all DataArrays by xrspatial) now has a plot() method that checks for an embedded TIFF colormap in attrs. If present, it applies BoundaryNorm with the ListedColormap so that integer class indices map to the correct palette colors. da = read_geotiff('landcover.tif') da.xrs.plot() # palette colors applied automatically For non-palette DataArrays, falls through to the standard da.plot(). The old plot_geotiff() function is kept as a thin wrapper.
1 parent 0161d37 commit 9cf43ab

File tree

3 files changed

+71
-32
lines changed

3 files changed

+71
-32
lines changed

xrspatial/accessor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,33 @@ class XrsSpatialDataArrayAccessor:
2121
def __init__(self, obj):
2222
self._obj = obj
2323

24+
# ---- Plot ----
25+
26+
def plot(self, **kwargs):
27+
"""Plot the DataArray, using an embedded TIFF colormap if present.
28+
29+
For palette/indexed-color GeoTIFFs (read via ``read_geotiff``),
30+
the TIFF's color table is applied automatically with correct
31+
normalization. For all other DataArrays, falls through to the
32+
standard ``da.plot()``.
33+
34+
Usage::
35+
36+
da = read_geotiff('landcover.tif')
37+
da.xrs.plot() # palette colors used automatically
38+
"""
39+
import numpy as np
40+
cmap = self._obj.attrs.get('cmap')
41+
if cmap is not None and 'cmap' not in kwargs:
42+
from matplotlib.colors import BoundaryNorm
43+
n_colors = len(cmap.colors)
44+
boundaries = np.arange(n_colors + 1) - 0.5
45+
norm = BoundaryNorm(boundaries, n_colors)
46+
kwargs.setdefault('cmap', cmap)
47+
kwargs.setdefault('norm', norm)
48+
kwargs.setdefault('add_colorbar', True)
49+
return self._obj.plot(**kwargs)
50+
2451
# ---- Surface ----
2552

2653
def slope(self, **kwargs):

xrspatial/geotiff/__init__.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from ._reader import read_to_array
2121
from ._writer import write
2222

23-
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask',
24-
'plot_geotiff']
23+
__all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask']
2524

2625

2726
def _geo_to_coords(geo_info, height: int, width: int) -> dict:
@@ -381,30 +380,8 @@ def _read():
381380

382381

383382
def plot_geotiff(da: xr.DataArray, **kwargs):
384-
"""Plot a DataArray read from a GeoTIFF, using its embedded colormap if present.
383+
"""Plot a DataArray using its embedded colormap if present.
385384
386-
For palette/indexed-color TIFFs, the TIFF's color table is used
387-
automatically. For other TIFFs, falls through to xarray's default plot.
388-
389-
Parameters
390-
----------
391-
da : xr.DataArray
392-
DataArray from read_geotiff.
393-
**kwargs
394-
Additional keyword arguments passed to da.plot().
395-
396-
Returns
397-
-------
398-
matplotlib artist (from da.plot())
385+
Deprecated: use ``da.xrs.plot()`` instead.
399386
"""
400-
cmap = da.attrs.get('cmap')
401-
if cmap is not None and 'cmap' not in kwargs:
402-
from matplotlib.colors import BoundaryNorm
403-
n_colors = len(cmap.colors)
404-
# Build a BoundaryNorm that maps integer index i to palette[i]
405-
boundaries = np.arange(n_colors + 1) - 0.5
406-
norm = BoundaryNorm(boundaries, n_colors)
407-
kwargs.setdefault('cmap', cmap)
408-
kwargs.setdefault('norm', norm)
409-
kwargs.setdefault('add_colorbar', True)
410-
return da.plot(**kwargs)
387+
return da.xrs.plot(**kwargs)

xrspatial/geotiff/tests/test_features.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -920,11 +920,11 @@ def test_palette_cmap_works_with_plot(self, tmp_path):
920920
assert cmap(0)[:3] == pytest.approx((1.0, 0.0, 0.0), abs=0.01)
921921
assert cmap(1 / 255)[:3] == pytest.approx((0.0, 1.0, 0.0), abs=0.01)
922922

923-
def test_plot_geotiff_with_palette(self, tmp_path):
924-
"""plot_geotiff() uses the embedded colormap."""
923+
def test_xrs_plot_with_palette(self, tmp_path):
924+
"""da.xrs.plot() uses the embedded colormap."""
925925
import matplotlib
926-
matplotlib.use('Agg') # non-interactive backend for tests
927-
from xrspatial.geotiff import plot_geotiff
926+
matplotlib.use('Agg')
927+
import xrspatial.accessor # register .xrs accessor
928928

929929
palette = [
930930
(65535, 0, 0),
@@ -941,7 +941,42 @@ def test_plot_geotiff_with_palette(self, tmp_path):
941941
f.write(tiff_data)
942942

943943
da = read_geotiff(path)
944-
# Should not raise
944+
artist = da.xrs.plot()
945+
assert artist is not None
946+
import matplotlib.pyplot as plt
947+
plt.close('all')
948+
949+
def test_xrs_plot_no_palette(self, tmp_path):
950+
"""da.xrs.plot() falls through to normal plot for non-palette data."""
951+
import matplotlib
952+
matplotlib.use('Agg')
953+
import xrspatial.accessor
954+
955+
arr = np.random.RandomState(42).rand(4, 4).astype(np.float32)
956+
path = str(tmp_path / 'no_palette.tif')
957+
write(arr, path, compression='none', tiled=False)
958+
959+
da = read_geotiff(path)
960+
artist = da.xrs.plot()
961+
assert artist is not None
962+
import matplotlib.pyplot as plt
963+
plt.close('all')
964+
965+
def test_plot_geotiff_deprecated(self, tmp_path):
966+
"""plot_geotiff still works as deprecated wrapper."""
967+
import matplotlib
968+
matplotlib.use('Agg')
969+
import xrspatial.accessor
970+
from xrspatial.geotiff import plot_geotiff
971+
972+
palette = [(65535, 0, 0), (0, 65535, 0)] + [(0, 0, 0)] * 254
973+
pixels = np.array([[0, 1], [1, 0]], dtype=np.uint8)
974+
tiff_data = _make_palette_tiff(2, 2, 8, pixels, palette)
975+
path = str(tmp_path / 'deprecated.tif')
976+
with open(path, 'wb') as f:
977+
f.write(tiff_data)
978+
979+
da = read_geotiff(path)
945980
artist = plot_geotiff(da)
946981
assert artist is not None
947982
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)