Skip to content

Commit 4256765

Browse files
committed
Add plot() accessors with helpful defaults
Enhance DataArray .xrs.plot() with dask compute, default figsize, and equal aspect ratio. Add Dataset .xrs.plot() that grids all 2D variables into subplots with GeoTIFF colormap support.
1 parent bd3fb9a commit 4256765

File tree

1 file changed

+131
-12
lines changed

1 file changed

+131
-12
lines changed

xrspatial/accessor.py

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,54 @@ def __init__(self, obj):
2424
# ---- Plot ----
2525

2626
def plot(self, **kwargs):
27-
"""Plot the DataArray, using an embedded TIFF colormap if present.
27+
"""Plot the DataArray with helpful defaults.
2828
29-
For palette/indexed-color GeoTIFFs (read via ``open_geotiff``),
30-
the TIFF's color table is applied automatically with correct
31-
normalization. For all other DataArrays, falls through to the
32-
standard ``da.plot()``.
29+
Computes dask arrays, applies embedded GeoTIFF colormaps,
30+
and sets equal aspect ratio.
3331
34-
Usage::
32+
Parameters
33+
----------
34+
**kwargs
35+
Passed to ``da.plot()``. Common extras: ``cmap``,
36+
``figsize``, ``ax``, ``add_colorbar``.
3537
36-
da = open_geotiff('landcover.tif')
37-
da.xrs.plot() # palette colors used automatically
38+
Returns
39+
-------
40+
matplotlib artist (from ``da.plot()``)
3841
"""
42+
import matplotlib.pyplot as plt
3943
import numpy as np
40-
cmap = self._obj.attrs.get('cmap')
44+
45+
da = self._obj
46+
47+
# Materialise dask arrays so matplotlib can render them.
48+
try:
49+
da = da.compute()
50+
except (AttributeError, TypeError):
51+
pass
52+
53+
# Use embedded GeoTIFF colormap when present.
54+
cmap = da.attrs.get('cmap')
4155
if cmap is not None and 'cmap' not in kwargs:
4256
from matplotlib.colors import BoundaryNorm
4357
n_colors = len(cmap.colors)
4458
boundaries = np.arange(n_colors + 1) - 0.5
45-
norm = BoundaryNorm(boundaries, n_colors)
4659
kwargs.setdefault('cmap', cmap)
47-
kwargs.setdefault('norm', norm)
60+
kwargs.setdefault('norm', BoundaryNorm(boundaries, n_colors))
4861
kwargs.setdefault('add_colorbar', True)
49-
return self._obj.plot(**kwargs)
62+
63+
# Create a figure with sensible size if none provided.
64+
if 'ax' not in kwargs:
65+
fig, ax = plt.subplots(
66+
figsize=kwargs.pop('figsize', (8, 6)),
67+
)
68+
kwargs['ax'] = ax
69+
70+
result = da.plot(**kwargs)
71+
72+
kwargs['ax'].set_aspect('equal')
73+
plt.tight_layout()
74+
return result
5075

5176
# ---- Surface ----
5277

@@ -522,6 +547,86 @@ class XrsSpatialDatasetAccessor:
522547
def __init__(self, obj):
523548
self._obj = obj
524549

550+
# ---- Plot ----
551+
552+
def plot(self, vars=None, cols=3, **kwargs):
553+
"""Plot 2D data variables as a grid of subplots.
554+
555+
Parameters
556+
----------
557+
vars : list of str, optional
558+
Variable names to plot. If None, plots all 2D variables.
559+
cols : int, default 3
560+
Maximum number of columns in the subplot grid.
561+
**kwargs
562+
Passed to each subplot's ``da.plot()``. Common extras:
563+
``cmap``, ``figsize``, ``add_colorbar``.
564+
565+
Returns
566+
-------
567+
numpy.ndarray of matplotlib.axes.Axes
568+
"""
569+
import math
570+
import matplotlib.pyplot as plt
571+
import numpy as np
572+
from matplotlib.colors import BoundaryNorm
573+
574+
ds = self._obj
575+
576+
# Collect 2D variables to plot.
577+
if vars is not None:
578+
names = [v for v in vars if v in ds.data_vars]
579+
else:
580+
names = [
581+
v for v in ds.data_vars
582+
if ds[v].ndim == 2
583+
]
584+
585+
if not names:
586+
raise ValueError("No 2D variables found to plot")
587+
588+
n = len(names)
589+
ncols = min(n, cols)
590+
nrows = math.ceil(n / ncols)
591+
592+
fig, axes = plt.subplots(
593+
nrows, ncols,
594+
figsize=kwargs.pop('figsize', (5 * ncols, 4 * nrows)),
595+
squeeze=False,
596+
)
597+
598+
for idx, name in enumerate(names):
599+
ax = axes[idx // ncols][idx % ncols]
600+
da = ds[name]
601+
602+
# Materialise dask arrays so matplotlib can render them.
603+
try:
604+
da = da.compute()
605+
except (AttributeError, TypeError):
606+
pass
607+
608+
# Use embedded GeoTIFF colormap when present.
609+
cmap = da.attrs.get('cmap')
610+
kw = dict(kwargs)
611+
if cmap is not None and 'cmap' not in kw:
612+
n_colors = len(cmap.colors)
613+
boundaries = np.arange(n_colors + 1) - 0.5
614+
kw.setdefault('cmap', cmap)
615+
kw.setdefault('norm', BoundaryNorm(boundaries, n_colors))
616+
kw.setdefault('add_colorbar', True)
617+
618+
kw.setdefault('ax', ax)
619+
da.plot(**kw)
620+
ax.set_title(name)
621+
ax.set_aspect('equal')
622+
623+
# Hide unused axes.
624+
for idx in range(n, nrows * ncols):
625+
axes[idx // ncols][idx % ncols].set_visible(False)
626+
627+
plt.tight_layout()
628+
return axes
629+
525630
# ---- Surface ----
526631

527632
def slope(self, **kwargs):
@@ -918,3 +1023,17 @@ def open_geotiff(self, source, **kwargs):
9181023
y_min, y_max, x_min, x_max)
9191024
kwargs.pop('window', None)
9201025
return open_geotiff(source, window=window, **kwargs)
1026+
1027+
# ---- Chunking ----
1028+
1029+
def rechunk_no_shuffle(self, **kwargs):
1030+
from .utils import rechunk_no_shuffle
1031+
return rechunk_no_shuffle(self._obj, **kwargs)
1032+
1033+
def fused_overlap(self, *stages, **kwargs):
1034+
from .utils import fused_overlap
1035+
return fused_overlap(self._obj, *stages, **kwargs)
1036+
1037+
def multi_overlap(self, func, n_outputs, **kwargs):
1038+
from .utils import multi_overlap
1039+
return multi_overlap(self._obj, func, n_outputs, **kwargs)

0 commit comments

Comments
 (0)