@@ -24,29 +24,54 @@ def __init__(self, obj):
2424 # ---- Plot ----
2525
2626 def plot (self , ** kwargs ):
27- """Plot the DataArray, using an embedded TIFF colormap if present .
27+ """Plot the DataArray with helpful defaults .
2828
29- For palette/indexed-color GeoTIFFs (read via ``open_geotiff``),
30- the TIFF's color table is applied automatically with correct
31- normalization. For all other DataArrays, falls through to the
32- standard ``da.plot()``.
29+ Computes dask arrays, applies embedded GeoTIFF colormaps,
30+ and sets equal aspect ratio.
3331
34- Usage::
32+ Parameters
33+ ----------
34+ **kwargs
35+ Passed to ``da.plot()``. Common extras: ``cmap``,
36+ ``figsize``, ``ax``, ``add_colorbar``.
3537
36- da = open_geotiff('landcover.tif')
37- da.xrs.plot() # palette colors used automatically
38+ Returns
39+ -------
40+ matplotlib artist (from ``da.plot()``)
3841 """
42+ import matplotlib .pyplot as plt
3943 import numpy as np
40- cmap = self ._obj .attrs .get ('cmap' )
44+
45+ da = self ._obj
46+
47+ # Materialise dask arrays so matplotlib can render them.
48+ try :
49+ da = da .compute ()
50+ except (AttributeError , TypeError ):
51+ pass
52+
53+ # Use embedded GeoTIFF colormap when present.
54+ cmap = da .attrs .get ('cmap' )
4155 if cmap is not None and 'cmap' not in kwargs :
4256 from matplotlib .colors import BoundaryNorm
4357 n_colors = len (cmap .colors )
4458 boundaries = np .arange (n_colors + 1 ) - 0.5
45- norm = BoundaryNorm (boundaries , n_colors )
4659 kwargs .setdefault ('cmap' , cmap )
47- kwargs .setdefault ('norm' , norm )
60+ kwargs .setdefault ('norm' , BoundaryNorm ( boundaries , n_colors ) )
4861 kwargs .setdefault ('add_colorbar' , True )
49- return self ._obj .plot (** kwargs )
62+
63+ # Create a figure with sensible size if none provided.
64+ if 'ax' not in kwargs :
65+ fig , ax = plt .subplots (
66+ figsize = kwargs .pop ('figsize' , (8 , 6 )),
67+ )
68+ kwargs ['ax' ] = ax
69+
70+ result = da .plot (** kwargs )
71+
72+ kwargs ['ax' ].set_aspect ('equal' )
73+ plt .tight_layout ()
74+ return result
5075
5176 # ---- Surface ----
5277
@@ -522,6 +547,86 @@ class XrsSpatialDatasetAccessor:
522547 def __init__ (self , obj ):
523548 self ._obj = obj
524549
550+ # ---- Plot ----
551+
552+ def plot (self , vars = None , cols = 3 , ** kwargs ):
553+ """Plot 2D data variables as a grid of subplots.
554+
555+ Parameters
556+ ----------
557+ vars : list of str, optional
558+ Variable names to plot. If None, plots all 2D variables.
559+ cols : int, default 3
560+ Maximum number of columns in the subplot grid.
561+ **kwargs
562+ Passed to each subplot's ``da.plot()``. Common extras:
563+ ``cmap``, ``figsize``, ``add_colorbar``.
564+
565+ Returns
566+ -------
567+ numpy.ndarray of matplotlib.axes.Axes
568+ """
569+ import math
570+ import matplotlib .pyplot as plt
571+ import numpy as np
572+ from matplotlib .colors import BoundaryNorm
573+
574+ ds = self ._obj
575+
576+ # Collect 2D variables to plot.
577+ if vars is not None :
578+ names = [v for v in vars if v in ds .data_vars ]
579+ else :
580+ names = [
581+ v for v in ds .data_vars
582+ if ds [v ].ndim == 2
583+ ]
584+
585+ if not names :
586+ raise ValueError ("No 2D variables found to plot" )
587+
588+ n = len (names )
589+ ncols = min (n , cols )
590+ nrows = math .ceil (n / ncols )
591+
592+ fig , axes = plt .subplots (
593+ nrows , ncols ,
594+ figsize = kwargs .pop ('figsize' , (5 * ncols , 4 * nrows )),
595+ squeeze = False ,
596+ )
597+
598+ for idx , name in enumerate (names ):
599+ ax = axes [idx // ncols ][idx % ncols ]
600+ da = ds [name ]
601+
602+ # Materialise dask arrays so matplotlib can render them.
603+ try :
604+ da = da .compute ()
605+ except (AttributeError , TypeError ):
606+ pass
607+
608+ # Use embedded GeoTIFF colormap when present.
609+ cmap = da .attrs .get ('cmap' )
610+ kw = dict (kwargs )
611+ if cmap is not None and 'cmap' not in kw :
612+ n_colors = len (cmap .colors )
613+ boundaries = np .arange (n_colors + 1 ) - 0.5
614+ kw .setdefault ('cmap' , cmap )
615+ kw .setdefault ('norm' , BoundaryNorm (boundaries , n_colors ))
616+ kw .setdefault ('add_colorbar' , True )
617+
618+ kw .setdefault ('ax' , ax )
619+ da .plot (** kw )
620+ ax .set_title (name )
621+ ax .set_aspect ('equal' )
622+
623+ # Hide unused axes.
624+ for idx in range (n , nrows * ncols ):
625+ axes [idx // ncols ][idx % ncols ].set_visible (False )
626+
627+ plt .tight_layout ()
628+ return axes
629+
525630 # ---- Surface ----
526631
527632 def slope (self , ** kwargs ):
@@ -918,3 +1023,17 @@ def open_geotiff(self, source, **kwargs):
9181023 y_min , y_max , x_min , x_max )
9191024 kwargs .pop ('window' , None )
9201025 return open_geotiff (source , window = window , ** kwargs )
1026+
1027+ # ---- Chunking ----
1028+
1029+ def rechunk_no_shuffle (self , ** kwargs ):
1030+ from .utils import rechunk_no_shuffle
1031+ return rechunk_no_shuffle (self ._obj , ** kwargs )
1032+
1033+ def fused_overlap (self , * stages , ** kwargs ):
1034+ from .utils import fused_overlap
1035+ return fused_overlap (self ._obj , * stages , ** kwargs )
1036+
1037+ def multi_overlap (self , func , n_outputs , ** kwargs ):
1038+ from .utils import multi_overlap
1039+ return multi_overlap (self ._obj , func , n_outputs , ** kwargs )
0 commit comments