-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathdataarray.py
More file actions
2024 lines (1684 loc) · 75.7 KB
/
Copy pathdataarray.py
File metadata and controls
2024 lines (1684 loc) · 75.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import warnings
from html import escape
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
from warnings import warn
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
from cartopy.mpl.geoaxes import GeoAxes
from xarray.core import dtypes
from xarray.core.options import OPTIONS
from xarray.core.utils import UncachedAccessor
import uxarray
from uxarray.core.aggregation import _uxda_grid_aggregate
from uxarray.core.gradient import (
_calculate_edge_face_difference,
_calculate_edge_node_difference,
_compute_gradient,
)
from uxarray.core.utils import _map_dims_to_ugrid
from uxarray.core.zonal import (
_compute_conservative_zonal_mean_bands,
_compute_non_conservative_zonal_mean,
)
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
from uxarray.formatting_html import array_repr
from uxarray.grid import Grid
from uxarray.grid.dual import construct_dual
from uxarray.grid.validation import _check_duplicate_nodes_indices
from uxarray.io._healpix import get_zoom_from_cells
from uxarray.plot.accessor import UxDataArrayPlotAccessor
from uxarray.remap.accessor import RemapAccessor
from uxarray.subset import DataArraySubsetAccessor
if TYPE_CHECKING:
from uxarray.core.dataset import UxDataset
class UxDataArray(xr.DataArray):
"""Grid informed ``xarray.DataArray`` with an attached ``Grid`` accessor
and grid-specific functionality.
Parameters
----------
uxgrid : uxarray.Grid, optional
The `Grid` object that makes this array aware of the unstructured
grid topology it belongs to.
If `None`, it needs to be an instance of `uxarray.Grid`.
Other Parameters
----------------
*args:
Arguments for the ``xarray.DataArray`` class
**kwargs:
Keyword arguments for the ``xarray.DataArray`` class
Notes
-----
See `xarray.DataArray <https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html>`__
for further information about DataArrays.
Grid-Aware Accessor Methods
---------------------------
The following methods return specialized accessors that preserve grid information:
- ``groupby``: Groups data by dimension/coordinate
- ``groupby_bins``: Groups data by bins
- ``resample``: Resamples timeseries data
- ``rolling``: Rolling window operations
- ``coarsen``: Coarsens data by integer factors
- ``weighted``: Weighted operations
- ``rolling_exp``: Exponentially weighted rolling (requires numbagg)
- ``cumulative``: Cumulative operations
All these methods work identically to xarray but maintain the uxgrid attribute
throughout operations.
"""
# expected instance attributes, required for subclassing with xarray (as of v0.13.0)
__slots__ = ("_uxgrid",)
def __init__(self, *args, uxgrid: Grid | None = None, **kwargs):
    """Create a ``UxDataArray``, optionally attached to a ``Grid``.

    Parameters
    ----------
    *args
        Positional arguments forwarded to ``xarray.DataArray``.
    uxgrid : uxarray.Grid or None, optional
        Grid the data lives on. ``None`` leaves the array without a grid.
    **kwargs
        Keyword arguments forwarded to ``xarray.DataArray``.

    Raises
    ------
    RuntimeError
        If ``uxgrid`` is neither ``None`` nor a ``Grid`` instance.
    """
    # Initialise the slot up front so the attribute exists even if validation fails.
    self._uxgrid = None

    grid_is_valid = uxgrid is None or isinstance(uxgrid, Grid)
    if not grid_is_valid:
        raise RuntimeError(
            "uxarray.UxDataArray.__init__: uxgrid can be either None or "
            "an instance of the uxarray.Grid class"
        )

    self.uxgrid = uxgrid
    super().__init__(*args, **kwargs)
# declare various accessors
# Attribute-style accessors (``uxda.plot``, ``uxda.subset``, ``uxda.remap``,
# ``uxda.cross_section``) exposed through xarray's ``UncachedAccessor`` helper.
# The accessor classes come from uxarray.plot.accessor, uxarray.subset,
# uxarray.remap.accessor and uxarray.cross_sections respectively.
plot = UncachedAccessor(UxDataArrayPlotAccessor)
subset = UncachedAccessor(DataArraySubsetAccessor)
remap = UncachedAccessor(RemapAccessor)
cross_section = UncachedAccessor(UxDataArrayCrossSectionAccessor)
def _repr_html_(self) -> str:
    """Rich HTML representation used by Jupyter.

    Honors xarray's ``display_style`` option. In ``"text"`` mode the plain
    ``repr`` is HTML-escaped and wrapped in ``<pre>``. Otherwise the uxarray
    HTML formatter (``array_repr``) renders the array.
    """
    if OPTIONS["display_style"] != "text":
        return array_repr(self)
    escaped_text = escape(repr(self))
    return f"<pre>{escaped_text}</pre>"
@classmethod
def _construct_direct(cls, *args, **kwargs):
    """Build through xarray's fast-path constructor, then wrap in ``cls``.

    Overridden so that xarray internals which call ``_construct_direct``
    produce a ``uxarray.UxDataArray`` instead of a plain ``xarray.DataArray``.
    """
    plain_array = xr.DataArray._construct_direct(*args, **kwargs)
    return cls(plain_array)
def _copy(self, **kwargs):
    """Copy the array and keep the result a complete ``UxDataArray``.

    Parameters
    ----------
    **kwargs
        Forwarded to ``xarray.DataArray._copy``. The ``deep`` flag also
        decides how the attached grid is handled.

    Returns
    -------
    UxDataArray
        For a deep copy, the grid is duplicated with ``Grid.copy()``. For a
        shallow copy, the same grid object is shared. An array without a grid
        (``uxgrid is None``) stays without one instead of raising.
    """
    copied = super()._copy(**kwargs)
    grid = self.uxgrid
    if kwargs.get("deep", None) and grid is not None:
        # Deep copy: give the result its own independent grid.
        copied.uxgrid = grid.copy()
    else:
        # Shallow copy, or nothing to copy: share the existing grid object.
        copied.uxgrid = grid
    return copied
def _replace(self, *args, **kwargs):
    """Run xarray's ``_replace`` and make sure the result is a ``UxDataArray``.

    The result always carries this array's ``uxgrid``. If xarray returned a
    plain ``DataArray``, it is wrapped in a new ``UxDataArray``.
    """
    replaced = super()._replace(*args, **kwargs)
    if not isinstance(replaced, UxDataArray):
        return UxDataArray(replaced, uxgrid=self.uxgrid)
    replaced.uxgrid = self.uxgrid
    return replaced
@property
def uxgrid(self):
    """Linked ``Grid`` describing the unstructured grid the data resides on.

    This may be ``None`` for an array that was created without a grid.
    """
    return self._uxgrid

@uxgrid.setter
def uxgrid(self, grid):
    # Stored as-is: the type check only happens in ``__init__``.
    self._uxgrid = grid
@property
def data_mapping(self):
    """Grid element the data variable is mapped to.

    Returns
    -------
    str or None
        ``"faces"``, ``"edges"`` or ``"nodes"``, checked in that order, or
        ``None`` when the variable is centered on none of them.
    """
    candidates = (
        (self._face_centered, "faces"),
        (self._edge_centered, "edges"),
        (self._node_centered, "nodes"),
    )
    for is_centered, location in candidates:
        if is_centered():
            return location
    return None
def to_geodataframe(
    self,
    periodic_elements: str | None = "exclude",
    projection=None,
    cache: bool | None = True,
    override: bool | None = False,
    engine: str | None = "spatialpandas",
    exclude_antimeridian: bool | None = None,
    **kwargs,
):
    """Constructs a ``GeoDataFrame`` consisting of polygons representing
    the faces of the current ``Grid`` with a face-centered data variable
    mapped to them.

    Periodic polygons (i.e. those that cross the antimeridian) are handled by
    the ``periodic_elements`` parameter:

    - ``'split'`` splits each periodic polygon along the antimeridian.
    - ``'exclude'`` leaves every periodic polygon out of the GeoDataFrame.
    - ``'ignore'`` computes the GeoDataFrame assuming no corrections are
      needed. This is best for grids that contain no periodic polygons.

    Parameters
    ----------
    periodic_elements : str, optional
        Method for handling periodic elements. One of ['exclude', 'split', or 'ignore']:
        - 'exclude': Periodic elements will be identified and excluded from the GeoDataFrame
        - 'split': Periodic elements will be identified and split using the ``antimeridian`` package
        - 'ignore': No processing will be applied to periodic elements.
    projection: ccrs.Projection, optional
        Geographic projection used to transform polygons. Only supported when periodic_elements is set to
        'ignore' or 'exclude'
    cache: bool, optional
        Flag used to select whether to cache the computed GeoDataFrame
    override: bool, optional
        Flag used to select whether to ignore any cached GeoDataFrame
    engine: str, optional
        Selects what library to use for creating a GeoDataFrame. One of ['spatialpandas', 'geopandas']. Defaults
        to spatialpandas
    exclude_antimeridian: bool, optional
        Flag used to select whether to exclude polygons that cross the antimeridian (Will be deprecated).
        When given, it overrides ``periodic_elements`` (``True`` -> 'exclude', ``False`` -> 'split')
        for the purpose of aligning the data with the polygons.
    **kwargs
        Only ``project`` (default ``True``) is read; it is forwarded to
        ``Grid.to_geodataframe``. Other keys are ignored.

    Returns
    -------
    gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame
        The output ``GeoDataFrame`` with a filled out "geometry" column of polygons and a data column with the
        same name as the ``UxDataArray`` (or named ``var`` if no name exists)

    Raises
    ------
    ValueError
        If the data is not 1-dimensional, or its size does not equal the
        number of faces of the grid.
    """
    # Validate using the ``ndim``/``size`` metadata. ``self.values`` would
    # materialize the whole array (computing it for dask-backed data) each time.
    if self.ndim > 1:
        # data is multidimensional, must be a 1D slice
        raise ValueError(
            f"Data Variable must be 1-dimensional, with shape {self.uxgrid.n_face} "
            f"for face-centered data."
        )

    size = self.size
    n_face = self.uxgrid.n_face
    if size != n_face:
        if size == self.uxgrid.n_node:
            raise ValueError(
                f"Data Variable with size {size} does not match the number of faces "
                f"({n_face}). Current size matches the number of nodes. Consider running "
                f"``UxDataArray.topological_mean(destination='face')`` to aggregate the data onto the faces."
            )
        if size == self.uxgrid.n_edge:
            raise ValueError(
                f"Data Variable with size {size} does not match the number of faces "
                f"({n_face}). Current size matches the number of edges."
            )
        # data is not mapped to nodes, edges or faces
        raise ValueError(
            f"Data Variable with size {size} does not match the number of faces "
            f"({n_face})."
        )

    gdf, non_nan_polygon_indices = self.uxgrid.to_geodataframe(
        periodic_elements=periodic_elements,
        projection=projection,
        project=kwargs.get("project", True),
        cache=cache,
        override=override,
        exclude_antimeridian=exclude_antimeridian,
        return_non_nan_polygon_indices=True,
        engine=engine,
    )

    # The legacy ``exclude_antimeridian`` flag overrides ``periodic_elements``
    # when deciding how the data must be indexed to line up with the polygons.
    if exclude_antimeridian is not None:
        periodic_elements = "exclude" if exclude_antimeridian else "split"

    # set a default variable name if the data array is not named
    var_name = self.name if self.name is not None else "var"

    data = self.values  # materialize the data exactly once
    if periodic_elements == "exclude":
        # drop values mapped to the excluded periodic (antimeridian) faces
        data = np.delete(
            data,
            self.uxgrid._gdf_cached_parameters["antimeridian_face_indices"],
            axis=0,
        )
    if non_nan_polygon_indices is not None:
        # keep only values whose polygons survived NaN filtering
        data = data[non_nan_polygon_indices]

    gdf[var_name] = data
    return gdf
def to_polycollection(
    self,
    periodic_elements: Optional[str] = "exclude",
    projection: Optional[ccrs.Projection] = None,
    return_indices: Optional[bool] = False,
    cache: Optional[bool] = True,
    override: Optional[bool] = False,
    **kwargs,
):
    """Constructs a ``matplotlib.collections.PolyCollection`` consisting
    of polygons representing the faces of the current ``UxDataArray`` with
    a face-centered data variable mapped to them.

    Parameters
    ----------
    periodic_elements : str, optional
        Method for handling periodic elements. One of ['exclude', 'split', or 'ignore']:
        - 'exclude': Periodic elements will be identified and excluded from the PolyCollection
        - 'split': Periodic elements will be identified and split using the ``antimeridian`` package
        - 'ignore': No processing will be applied to periodic elements.
    projection: ccrs.Projection
        Cartopy geographic projection to use
    return_indices: bool
        Flag to indicate whether to return the indices of corrected polygons, if any exist
    cache: bool
        Flag to indicate whether to cache the computed PolyCollection
    override: bool
        Flag to indicate whether to override a cached PolyCollection, if it exists
    **kwargs
        Forwarded to ``Grid.to_polycollection``.

    Returns
    -------
    poly_collection : matplotlib.collections.PolyCollection
        The collection with its array set to the (re-indexed) face data.
    corrected_to_original_faces
        Only when ``return_indices=True``: the face-index mapping returned by
        ``Grid.to_polycollection``.

    Raises
    ------
    ValueError
        If the data is not 1-dimensional or not face centered.
    """
    # ``ndim`` is metadata; ``self.values`` would materialize the array here
    # and again below (computing it twice for dask-backed data).
    if self.ndim > 1:
        # data is multidimensional, must be a 1D slice
        raise ValueError(
            f"Data Variable must be 1-dimensional, with shape {self.uxgrid.n_face} "
            f"for face-centered data."
        )

    if not self._face_centered():
        raise ValueError("Data variable must be face centered.")

    poly_collection, corrected_to_original_faces = self.uxgrid.to_polycollection(
        override=override,
        cache=cache,
        periodic_elements=periodic_elements,
        return_indices=True,
        projection=projection,
        **kwargs,
    )

    # Read after the call above, which is what populates the cached parameters.
    cached_params = self.uxgrid._poly_collection_cached_parameters

    data = self.values  # materialize the data exactly once
    if periodic_elements == "exclude":
        # drop values mapped to the excluded periodic (antimeridian) faces
        data = np.delete(data, cached_params["antimeridian_face_indices"], axis=0)
    elif periodic_elements == "split":
        # split polygons map back to their original faces through this index
        data = data[corrected_to_original_faces]

    non_nan_polygon_indices = cached_params["non_nan_polygon_indices"]
    if non_nan_polygon_indices is not None:
        # index data to ignore NaN polygons
        data = data[non_nan_polygon_indices]

    poly_collection.set_array(data)

    if return_indices:
        return poly_collection, corrected_to_original_faces
    return poly_collection
def to_raster(
    self,
    ax: GeoAxes,
    *,
    pixel_ratio: float | None = None,
    pixel_mapping: xr.DataArray | np.ndarray | None = None,
    return_pixel_mapping: bool = False,
):
    """
    Rasterizes a data variable stored on the faces of an unstructured grid onto the pixels of the provided Cartopy GeoAxes.

    Parameters
    ----------
    ax : GeoAxes
        A Cartopy :class:`~cartopy.mpl.geoaxes.GeoAxes` onto which the data will be rasterized.
        Each pixel in this axes will be sampled against the unstructured grid's face geometry.
        If the axes still has autoscaling enabled and no ``pixel_mapping`` is given,
        its extent is set from the grid's node lon/lat bounds (with a warning).
    pixel_ratio : float, optional
        A scaling factor to adjust the resolution of the rasterization.
        When ``None`` it defaults to 1.0, unless ``pixel_mapping`` is a
        ``xr.DataArray``, in which case the ratio stored in its attributes is used.
        A value greater than 1 increases the resolution (sharpens the image),
        while a value less than 1 will result in a coarser rasterization.
        The resolution also depends on what the figure's DPI setting is
        prior to calling :meth:`to_raster`.
        You can control DPI with the ``dpi`` keyword argument when creating the figure,
        or by using :meth:`~matplotlib.figure.Figure.set_dpi` after creation.
    pixel_mapping : xr.DataArray or array-like, optional
        Precomputed mapping from pixels within the Cartopy GeoAxes boundary
        to grid face indices (1-dimensional). Only a ``xr.DataArray`` mapping is
        checked for compatibility with ``ax``; plain arrays are used as given.
    return_pixel_mapping : bool, default=False
        If ``True``, the pixel mapping will be returned in addition to the raster,
        and then you can pass it via the `pixel_mapping` parameter for future rasterizations
        using the same or equivalent :attr:`uxgrid` and `ax`.
        Note that this is also specific to the pixel ratio setting.

    Returns
    -------
    raster : numpy.ndarray, shape (ny, nx)
        Array of resampled data values corresponding to each pixel.
    pixel_mapping : xr.DataArray, shape (n,)
        If ``return_pixel_mapping=True``, the computed pixel mapping is returned
        so that you can reuse it.
        Axes and pixel ratio info are included as attributes.

    Raises
    ------
    TypeError
        If ``ax`` is not a Cartopy ``GeoAxes``.
    ValueError
        If a ``xr.DataArray`` ``pixel_mapping`` was computed for axes attributes
        that differ from those of ``ax``.

    Notes
    -----
    - This method currently employs a nearest-neighbor resampling approach. For every pixel in the GeoAxes,
      it finds the face of the unstructured grid that contains the pixel's geographic coordinate and colors
      that pixel with the face's data value.
    - If a pixel does not intersect any face (i.e., lies outside the grid domain),
      it will be left empty (transparent).

    Examples
    --------
    >>> import cartopy.crs as ccrs
    >>> import matplotlib.pyplot as plt

    Create a :class:`~cartopy.mpl.geoaxes.GeoAxes` with a Robinson projection and global extent

    >>> fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Robinson()})
    >>> ax.set_global()

    Rasterize data onto the GeoAxes

    >>> raster = uxds["psi"].to_raster(ax=ax)

    Use :meth:`~cartopy.mpl.geoaxes.GeoAxes.imshow` to visualize the raster

    >>> ax.imshow(raster, origin="lower", extent=ax.get_xlim() + ax.get_ylim())
    """
    from uxarray.constants import INT_DTYPE
    from uxarray.plot.matplotlib import (
        _ensure_dimensions,
        _nearest_neighbor_resample,
        _RasterAxAttrs,
    )

    # Validate/normalize the variable's dimensions before resampling.
    # NOTE(review): presumably this enforces face-centered, 1-D data — confirm.
    data = _ensure_dimensions(self)

    if not isinstance(ax, GeoAxes):
        raise TypeError("`ax` must be an instance of cartopy.mpl.geoaxes.GeoAxes")

    # Remember whether the caller chose a ratio explicitly, so a conflicting
    # ratio stored on a precomputed mapping can be reported; else default to 1.0.
    pixel_ratio_set = pixel_ratio is not None
    if not pixel_ratio_set:
        pixel_ratio = 1.0

    if pixel_mapping is not None:
        # Reusing a precomputed mapping: the axes extent is not modified here.
        input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
        if isinstance(pixel_mapping, xr.DataArray):
            # The mapping records the ratio it was built with; that value wins.
            pixel_ratio_input = pixel_ratio
            pixel_ratio = pixel_mapping.attrs["pixel_ratio"]
            if pixel_ratio_set and pixel_ratio_input != pixel_ratio:
                warn(
                    "Pixel ratio mismatch: "
                    f"{pixel_ratio_input} passed but {pixel_ratio} in pixel_mapping. "
                    "Using the pixel_mapping attribute.",
                    stacklevel=2,
                )
            # Recompute the axes attributes with the adopted ratio and require
            # them to equal the attributes recorded on the mapping.
            input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
            pm_ax_attrs = _RasterAxAttrs.from_xr_attrs(pixel_mapping.attrs)
            if input_ax_attrs != pm_ax_attrs:
                raise ValueError(
                    "Pixel mapping incompatible with ax. "
                    + input_ax_attrs._value_comparison_message(pm_ax_attrs)
                )
        # Non-DataArray mappings skip the compatibility check above.
        pixel_mapping = np.asarray(pixel_mapping, dtype=INT_DTYPE)
    else:

        def _is_default_extent() -> bool:
            # Autoscaling still enabled is treated as "the user never set an extent".
            return ax.get_autoscale_on()

        if _is_default_extent():
            # Best effort: fit the axes to the grid's lon/lat bounds. Any failure
            # is reported as a warning instead of aborting the rasterization.
            try:
                # NOTE(review): cartopy.crs is already imported at module level;
                # this local import is redundant.
                import cartopy.crs as ccrs

                lon_min = float(self.uxgrid.node_lon.min(skipna=True).values)
                lon_max = float(self.uxgrid.node_lon.max(skipna=True).values)
                lat_min = float(self.uxgrid.node_lat.min(skipna=True).values)
                lat_max = float(self.uxgrid.node_lat.max(skipna=True).values)
                ax.set_extent(
                    (lon_min, lon_max, lat_min, lat_max),
                    crs=ccrs.PlateCarree(),
                )
                warn(
                    "Axes extent was default; auto-setting from grid lon/lat bounds for rasterization. "
                    "Set the extent explicitly to control this, e.g. via ax.set_global(), "
                    "ax.set_extent(...), or ax.set_xlim(...) + ax.set_ylim(...).",
                    stacklevel=2,
                )
            except Exception as e:
                warn(
                    f"Failed to auto-set extent from grid bounds: {e}",
                    stacklevel=2,
                )

        # Captured after any auto-extent change so a returned mapping records
        # the axes state it was actually computed for.
        input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)

    raster, pixel_mapping_np = _nearest_neighbor_resample(
        data,
        ax,
        pixel_ratio=pixel_ratio,
        pixel_mapping=pixel_mapping,
    )

    if return_pixel_mapping:
        # Wrap the mapping with the axes/ratio attributes so a later call can
        # validate it against its own ``ax``.
        pixel_mapping_da = xr.DataArray(
            pixel_mapping_np,
            name="pixel_mapping",
            dims=("n_pixel",),
            attrs={
                "long_name": "pixel_mapping",
                "description": (
                    "Mapping from raster pixels within a Cartopy GeoAxes "
                    "to nearest grid face index."
                ),
                **input_ax_attrs.to_xr_attrs(),
            },
        )
        return raster, pixel_mapping_da
    else:
        return raster
def to_dataset(
    self,
    dim: Hashable | None = None,
    *,
    name: Hashable | None = None,
    promote_attrs: bool = False,
) -> UxDataset:
    """Convert this ``UxDataArray`` into a ``UxDataset`` sharing the same grid.

    Parameters
    ----------
    dim : Hashable, optional
        Dimension along which to split the array into separate variables.
        When omitted, the result holds a single variable.
    name : Hashable, optional
        Replacement name for the variable. Only valid when ``dim`` is not given.
    promote_attrs : bool, default: False
        If True, shallow-copy this array's attrs onto the returned dataset.

    Returns
    -------
    UxDataset
        The converted dataset, attached to this array's ``uxgrid``.
    """
    plain_dataset = super().to_dataset(
        dim=dim, name=name, promote_attrs=promote_attrs
    )
    return uxarray.core.dataset.UxDataset(plain_dataset, uxgrid=self.uxgrid)
def to_xarray(self):
    """Convert to a plain ``xarray.DataArray`` that carries no ``uxgrid``.

    Returns
    -------
    xarray.DataArray
    """
    plain_array = xr.DataArray(self)
    return plain_array
def integrate(
    self, quadrature_rule: str | None = "triangular", order: int | None = 4
) -> UxDataArray:
    """Computes the integral of a face-centered data variable.

    The integral is the dot product of the grid's face areas with the last
    dimension of the data, which must have one entry per face.

    Parameters
    ----------
    quadrature_rule : str, optional
        Quadrature rule to use. Defaults to "triangular". Currently not used
        by this implementation.
    order : int, optional
        Order of quadrature rule. Defaults to 4. Currently not used by this
        implementation.

    Returns
    -------
    uxda : UxDataArray
        UxDataArray containing the integrated data variable, with the last
        dimension removed.

    Raises
    ------
    ValueError
        If the last dimension matches the number of nodes or edges (not yet
        supported), or none of nodes, edges or faces.

    Examples
    --------
    Open a UXarray dataset and compute the integral

    >>> import uxarray as ux
    >>> uxds = ux.open_dataset("grid.ug", "centroid_pressure_data_ug")
    >>> integral = uxds["psi"].integrate()
    """
    grid = self.uxgrid
    # ``self.shape`` is metadata only. ``self.values`` would materialize
    # (and, for dask-backed data, compute) the array on every access.
    n_last = self.shape[-1]

    if n_last == grid.n_face:
        face_areas = grid.face_areas.values
        # perform dot product between face areas and last dimension of data
        integral = np.einsum("i,...i", face_areas, self.values)
    elif n_last == grid.n_node:
        raise ValueError("Integrating data mapped to each node not yet supported.")
    elif n_last == grid.n_edge:
        raise ValueError("Integrating data mapped to each edge not yet supported.")
    else:
        raise ValueError(
            f"The final dimension of the data variable does not match the number of nodes, edges, "
            f"or faces. Expected one of "
            f"{grid.n_node}, {grid.n_edge}, or {grid.n_face}, "
            f"but received {n_last}"
        )

    # construct a uxda with integrated quantity
    return UxDataArray(integral, uxgrid=grid, dims=self.dims[:-1], name=self.name)
def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
    """Compute non-conservative or conservative averages of a face-centered variable along lines of constant latitude or latitude bands.

    A zonal mean in UXarray operates differently depending on the ``conservative`` flag:

    - **Non-conservative**: Calculates the mean by sampling face values at specific latitude lines and weighting each contribution by the length of the line where each face intersects that latitude.
    - **Conservative**: Preserves integral quantities by calculating the mean by sampling face values within latitude bands and weighting contributions by their area overlap with latitude bands.

    Parameters
    ----------
    lat : tuple, float, or array-like, default=(-90, 90, 10)
        Latitude specification:

        - tuple (start, end, step): expanded to ``round((end - start) / step) + 1``
          evenly spaced values from ``start`` to ``end`` (inclusive), clipped to
          [-90, 90]. These are the sample latitudes (non-conservative) or the
          band edges (conservative).
        - scalar (Python or NumPy int/float): single latitude, non-conservative only.
        - list or ndarray: for non-conservative, latitudes to sample; for
          conservative, band edges (1-D, at least two values).
    conservative : bool, default=False
        If True, performs conservative (area-weighted) zonal averaging over latitude bands.
        If False, performs non-conservative (intersection-weighted) averaging at latitude lines.
    **kwargs
        Forwarded to the non-conservative computation only.

    Returns
    -------
    xarray.DataArray
        Contains zonal means with the ``n_face`` dimension replaced by a
        ``latitudes`` dimension (sample latitudes, or band centers when
        conservative). Coordinates not on ``n_face`` are kept.
        Name will be original_name + '_zonal_mean' or 'zonal_mean' if unnamed.

    Raises
    ------
    ValueError
        If the variable is not face centered, the tuple step is not positive,
        ``lat`` has an unsupported type, or conservative band edges are not
        1-D with at least two values.

    Examples
    --------
    # Non-conservative averaging from -90° to 90° at 10° intervals by default
    >>> uxds["var"].zonal_mean()

    # Single latitude (non-conservative) over 30° latitude
    >>> uxds["var"].zonal_mean(lat=30.0)

    # Conservative averaging over latitude bands
    >>> uxds["var"].zonal_mean(lat=(-60, 60, 10), conservative=True)

    # Conservative with explicit band edges
    >>> uxds["var"].zonal_mean(lat=[-90, -30, 0, 30, 90], conservative=True)

    Notes
    -----
    Only supported for face-centered data variables.

    Conservative averaging preserves integral quantities and is recommended for
    physical analysis. Non-conservative averaging samples at latitude lines.
    """
    if not self._face_centered():
        raise ValueError(
            "Zonal mean computations are currently only supported for face-centered data variables."
        )

    face_axis = self.dims.index("n_face")

    # A (start, end, step) tuple expands identically in both modes: into sample
    # latitudes (non-conservative) or band edges (conservative).
    tuple_values = None
    if isinstance(lat, tuple):
        start, end, step = lat
        if step <= 0:
            raise ValueError(
                "Step size must be positive for conservative averaging."
                if conservative
                else "Step size must be positive."
            )
        if step < 0.1:
            warnings.warn(
                f"Very small step size ({step}°) may lead to performance issues...",
                UserWarning,
                stacklevel=2,
            )
        num_points = int(round((end - start) / step)) + 1
        tuple_values = np.clip(np.linspace(start, end, num_points), -90, 90)

    def _wrap_result(res, latitude_coord, attrs):
        """Wrap ``res`` in a DataArray with ``n_face`` replaced by ``latitudes``."""
        dims = list(self.dims)
        dims[face_axis] = "latitudes"
        face_dim = self.dims[face_axis]
        # Keep every coordinate from ``self`` except those on the reduced face dimension.
        new_coords = {k: v for k, v in self.coords.items() if face_dim not in v.dims}
        new_coords["latitudes"] = latitude_coord
        # f-string rather than ``+`` so non-str Hashable names also work.
        name = f"{self.name}_zonal_mean" if self.name is not None else "zonal_mean"
        return xr.DataArray(res, dims=dims, coords=new_coords, name=name, attrs=attrs)

    if not conservative:
        # Non-conservative (traditional) zonal averaging
        if tuple_values is not None:
            latitudes = tuple_values
        elif isinstance(lat, (float, int, np.floating, np.integer)):
            latitudes = [lat]
        elif isinstance(lat, (list, np.ndarray)):
            latitudes = np.asarray(lat)
        else:
            raise ValueError(
                "Invalid value for 'lat' provided. Must be a scalar, tuple (min_lat, max_lat, step), or array-like."
            )
        res = _compute_non_conservative_zonal_mean(
            uxda=self, latitudes=latitudes, **kwargs
        )
        return _wrap_result(res, latitudes, {"zonal_mean": True, "conservative": False})

    # Conservative zonal averaging
    if tuple_values is not None:
        edges = tuple_values
    elif isinstance(lat, (list, np.ndarray)):
        edges = np.asarray(lat, dtype=float)
    else:
        raise ValueError(
            "For conservative averaging, 'lat' must be a tuple (start, end, step) or array-like band edges."
        )

    if edges.ndim != 1 or edges.size < 2:
        raise ValueError("Band edges must be 1D with at least two values")

    res = _compute_conservative_zonal_mean_bands(self, edges)
    # Use band centers as coordinate values
    centers = 0.5 * (edges[:-1] + edges[1:])
    return _wrap_result(
        res,
        centers,
        {"zonal_mean": True, "conservative": True, "lat_band_edges": edges},
    )
def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
    """Alias of :meth:`zonal_mean`; prefer ``zonal_mean`` in new code.

    All arguments are forwarded unchanged to :meth:`zonal_mean`.
    """
    forwarded = {"lat": lat, "conservative": conservative}
    return self.zonal_mean(**forwarded, **kwargs)
def azimuthal_mean(
    self,
    center_coord,
    outer_radius: int | float,
    radius_step: int | float,
    return_hit_counts: bool = False,
):
    """Compute averages along circles of constant great-circle distance from a point.

    Parameters
    ----------
    center_coord: tuple, list, ndarray
        Longitude and latitude (degrees) of the center of the bounding circle
    outer_radius: scalar, int, float
        The maximum radius, in great-circle degrees, at which the azimuthal mean will be computed.
    radius_step: scalar, int, float
        Means will be computed at radii ``np.arange(0, outer_radius + radius_step, radius_step)``.
        Must be positive.
    return_hit_counts: bool, false
        Indicates whether to return the number of hits at each radius

    Returns
    -------
    azimuthal_mean: xr.DataArray
        Contains a variable with a dimension 'radius' (replacing ``n_face``) corresponding
        to the azimuthal average. Radii without any faces are NaN.
    hit_counts: xr.DataArray
        The number of faces contributing at each radius (only when ``return_hit_counts=True``).

    Raises
    ------
    ValueError
        If the variable is not face centered, or ``outer_radius`` or
        ``radius_step`` is not positive.

    Examples
    --------
    # Range from 0° to 5° at 0.5° intervals, around the central point lon,lat=10,50
    >>> az = uxds["var"].azimuthal_mean(
    ...     center_coord=(10, 50), outer_radius=5.0, radius_step=0.5
    ... )
    >>> az.plot(title="Azimuthal Mean")

    Notes
    -----
    Only supported for face-centered data variables. Candidate faces are determined
    using bounding circles - for radii = [r1, r2, r3, ...] faces whose centers lie at distance d,
    r2 < d <= r3 are included in calculations for r3.
    """
    from uxarray.grid.coordinates import _lonlat_rad_to_xyz

    if not self._face_centered():
        raise ValueError(
            "Azimuthal mean computations are currently only supported for face-centered data variables."
        )

    if outer_radius <= 0:
        raise ValueError("Radius must be a positive scalar.")

    if radius_step <= 0:
        # A zero step makes np.arange fail, and a negative one silently yields no radii.
        raise ValueError("Radius step must be a positive scalar.")

    kdtree = self.uxgrid._get_scipy_kd_tree()

    lon_deg, lat_deg = map(float, np.asarray(center_coord))
    center_xyz = np.array(
        _lonlat_rad_to_xyz(np.deg2rad(lon_deg), np.deg2rad(lat_deg))
    )

    radii_deg = np.arange(0.0, outer_radius + radius_step, radius_step, dtype=float)
    radii_rad = np.deg2rad(radii_deg)
    # Great-circle angle -> straight-line chord length on the unit sphere (the KD-tree metric).
    chord_radii = 2.0 * np.sin(radii_rad / 2.0)

    faces_processed = np.array([], dtype=np.int_)
    # One row of results per radius; the trailing shape is the data without n_face.
    means = np.full(
        (radii_deg.size, *self.to_xarray().isel(drop=True, n_face=0).shape), np.nan
    )
    hit_count = np.zeros_like(radii_deg, dtype=np.int_)

    for ii, r_chord in enumerate(chord_radii):
        # indices of faces within the bounding circle for this radius
        within = np.array(
            kdtree.query_ball_point(center_xyz, r_chord), dtype=np.int_
        )
        if within.size:
            within.sort()

        # include only the new ring: r_(i-1) < d <= r_i
        faces_in_bin = np.setdiff1d(within, faces_processed, assume_unique=True)
        hit_count[ii] = faces_in_bin.size

        if hit_count[ii] == 0:
            continue

        faces_processed = within  # cumulative set for next iteration
        tpose = self.isel(n_face=faces_in_bin).transpose(..., "n_face")
        means[ii, ...] = tpose.weighted_mean().data

    # swap the leading 'radius' axis into the former n_face position
    face_axis = self.dims.index("n_face")
    dims = list(self.dims)
    dims[face_axis] = "radius"
    means = np.moveaxis(means, 0, face_axis)

    hit_count_da = xr.DataArray(
        data=hit_count, dims="radius", coords={"radius": radii_deg}
    )

    # Assign coords from `self` to the result except one that corresponds to `dims[face_axis]`
    new_coords = {
        k: v for k, v in self.coords.items() if self.dims[face_axis] not in v.dims
    }
    # Add radii_deg to the resulting coords
    new_coords["radius"] = radii_deg

    uxda = xr.DataArray(
        means,
        dims=dims,
        coords=new_coords,
        # f-string rather than ``+`` so non-str Hashable names also work.
        name=f"{self.name}_azimuthal_mean"
        if self.name is not None
        else "azimuthal_mean",
        attrs={
            "azimuthal_mean": True,
            "center_lon": lon_deg,
            "center_lat": lat_deg,
            "radius_units": "degrees",
        },
    )

    if return_hit_counts:
        return uxda, hit_count_da
    return uxda

azimuthal_average = azimuthal_mean
def weighted_mean(self, weights=None):
    """Computes a weighted mean along the last (geometry) dimension.

    If no weights are provided, appropriate weights are selected from the grid
    based on whether the variable is face-centered or edge-centered. If the
    variable is neither, a warning is raised and an unweighted mean is computed
    with uniform weights.

    Parameters
    ----------
    weights : array-like or None, optional
        The weights to use for the weighted mean calculation. If `None`, the function will
        determine weights based on the variable's association:
        - For face-centered variables: uses `self.uxgrid.face_areas.data`
        - For edge-centered variables: uses `self.uxgrid.edge_node_distances.data`
        If the variable is neither face-centered nor edge-centered, a warning is raised, and
        an unweighted mean is computed instead. User-defined weights should match the shape
        of the data variable's last dimension. Sequences without a ``shape``
        (e.g. lists) are converted with ``np.asarray``.

    Returns
    -------
    UxDataArray
        A new `UxDataArray` object representing the weighted mean of the input variable. The
        result is attached to the same `uxgrid` attribute as the original variable.

    Example
    -------
    >>> weighted_mean = uxds["t2m"].weighted_mean()

    Raises
    ------
    AssertionError
        If user-defined `weights` are provided and the shape of `weights` does not match
        the shape of the data variable's last dimension.

    Warnings
    --------
    UserWarning
        Raised when attempting to compute a weighted mean on a variable without associated
        weights. An unweighted mean will be computed in this case.

    Notes
    -----
    - The weighted mean is computed along the last dimension of the data variable, which is
      assumed to be the geometry dimension (e.g., faces, edges, or nodes).
    """
    if weights is None:
        if self._face_centered():
            weights = self.uxgrid.face_areas.data
        elif self._edge_centered():
            weights = self.uxgrid.edge_node_distances.data
        else:
            warnings.warn(
                "Attempting to perform a weighted mean calculation on a variable that does not have "
                "associated weights. Weighted mean is only supported for face or edge centered "
                "variables. Performing an unweighted mean.",
                stacklevel=2,
            )
            # Uniform weights reduce the computation below to a plain mean
            # (previously ``weights`` stayed None and ``.sum()`` crashed).
            weights = np.ones(self.shape[-1])
    else:
        # user-defined weights; array-likes with a ``shape`` (ndarray,
        # DataArray, ...) are kept as-is so xarray alignment is unchanged
        if not hasattr(weights, "shape"):
            weights = np.asarray(weights)
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # AssertionError is kept for backward compatibility.
        if weights.shape[-1] != self.shape[-1]:
            raise AssertionError(
                f"weights must match the last dimension of the data "
                f"({self.shape[-1]}), got {weights.shape[-1]}"
            )

    # compute the total weight
    total_weight = weights.sum()

    # compute the weighted mean, with an assumption on the index of dimension (last one is geometry)
    weighted_mean = (self * weights).sum(axis=-1) / total_weight

    # create a UxDataArray and return it
    return UxDataArray(weighted_mean, uxgrid=self.uxgrid)
def topological_mean(
self,
destination: Literal["node", "edge", "face"],
**kwargs,
):
"""Performs a topological mean aggregation.
See Also
--------
numpy.mean
dask.array.mean
xarray.DataArray.mean
Parameters
----------
destination: str,
Destination grid dimension for aggregation.
Node-Centered Variable:
- ``destination='edge'``: Aggregation is applied on the nodes that saddle each edge, with the result stored
on each edge
- ``destination='face'``: Aggregation is applied on the nodes that surround each face, with the result stored
on each face.
Edge-Centered Variable:
- ``destination='node'``: Aggregation is applied on the edges that intersect each node, with the result stored
on each node.
- ``Destination='face'``: Aggregation is applied on the edges that surround each face, with the result stored
on each face.