Skip to content

Commit f1127b0

Browse files
committed
add zonal_anomaly with shared face-band weight kernel
1 parent 6fb17d8 commit f1127b0

2 files changed

Lines changed: 216 additions & 71 deletions

File tree

uxarray/core/dataarray.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from uxarray.core.zonal import (
2525
_compute_conservative_zonal_mean_bands,
2626
_compute_non_conservative_zonal_mean,
27+
_compute_zonal_anomaly,
2728
)
2829
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
2930
from uxarray.formatting_html import array_repr
@@ -767,6 +768,70 @@ def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs)
767768
"""Alias of zonal_mean; prefer `zonal_mean` for primary API."""
768769
return self.zonal_mean(lat=lat, conservative=conservative, **kwargs)
769770

771+
def zonal_anomaly(self, lat=(-90, 90, 10), conservative: bool = False):
772+
"""Compute the zonal anomaly: each face value minus the mean of its latitude band.
773+
774+
Returns a new ``UxDataArray`` with the same dimensions as the input,
775+
where each face holds its original value minus the zonal mean of the
776+
latitude band it belongs to.
777+
778+
Parameters
779+
----------
780+
lat : tuple or array-like, default=(-90, 90, 10)
781+
Latitude band specification:
782+
- tuple (start, end, step): band edges via np.linspace(start, end, n)
783+
- array-like: explicit band edges in degrees
784+
conservative : bool, default=False
785+
If True, uses area-weighted band means and blends across bands for
786+
faces that straddle a band boundary, reusing the face-band weight
787+
matrix computed for zonal_mean so no geometry is duplicated.
788+
If False, assigns each face to a band by its centroid latitude.
789+
790+
Returns
791+
-------
792+
UxDataArray
793+
Same dimensions as input with per-face band mean subtracted.
794+
795+
Examples
796+
--------
797+
>>> uxds["var"].zonal_anomaly()
798+
>>> uxds["var"].zonal_anomaly(lat=(-60, 60, 5), conservative=True)
799+
"""
800+
if not self._face_centered():
801+
raise ValueError(
802+
"Zonal anomaly is only supported for face-centered data variables."
803+
)
804+
805+
if isinstance(lat, tuple):
806+
start, end, step = lat
807+
if step <= 0:
808+
raise ValueError("Step size must be positive.")
809+
num_points = int(round((end - start) / step)) + 1
810+
edges = np.linspace(start, end, num_points)
811+
edges = np.clip(edges, -90, 90)
812+
elif isinstance(lat, (list, np.ndarray)):
813+
edges = np.asarray(lat, dtype=float)
814+
else:
815+
raise ValueError(
816+
"Invalid value for 'lat'. Must be a tuple (start, end, step) or array-like band edges."
817+
)
818+
819+
if edges.ndim != 1 or edges.size < 2:
820+
raise ValueError("Band edges must be 1D with at least two values.")
821+
822+
res = _compute_zonal_anomaly(self, edges, conservative=conservative)
823+
824+
return UxDataArray(
825+
res,
826+
dims=self.dims,
827+
coords=self.coords,
828+
name=self.name + "_zonal_anomaly"
829+
if self.name is not None
830+
else "zonal_anomaly",
831+
attrs={"zonal_anomaly": True, "conservative": conservative},
832+
uxgrid=self.uxgrid,
833+
)
834+
770835
def azimuthal_mean(
771836
self,
772837
center_coord,

uxarray/core/zonal.py

Lines changed: 151 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -225,31 +225,25 @@ def _compute_band_overlap_area(
225225
return area
226226

227227

228-
def _compute_conservative_zonal_mean_bands(uxda, bands):
229-
"""
230-
Compute conservative zonal mean over latitude bands.
228+
def _compute_face_band_weights(uxgrid, bands):
229+
"""Compute overlap area between every face and every latitude band.
231230
232-
Uses get_faces_between_latitudes to optimize computation by avoiding
233-
overlap area calculations for fully contained faces.
231+
Shared geometry kernel used by both zonal_mean and zonal_anomaly so the
232+
expensive intersection calculations are never duplicated.
234233
235234
Parameters
236235
----------
237-
uxda : UxDataArray
238-
The data array to compute zonal means for
236+
uxgrid : Grid
239237
bands : array-like
240-
Latitude band edges in degrees
238+
Latitude band edges in degrees, shape (n_bands + 1,)
241239
242240
Returns
243241
-------
244-
result : array
245-
Zonal means for each band
242+
W : ndarray, shape (n_face, n_bands)
243+
W[f, b] is the overlap area between face f and band b.
244+
Fully-contained faces carry their full face area; partially-overlapping
245+
faces carry the exact intersection area.
246246
"""
247-
import dask.array as da
248-
249-
uxgrid = uxda.uxgrid
250-
face_axis = uxda.get_axis_num("n_face")
251-
252-
# Pre-compute face properties
253247
faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
254248
uxgrid.face_node_connectivity.values,
255249
uxgrid.n_face,
@@ -263,80 +257,166 @@ def _compute_conservative_zonal_mean_bands(uxda, bands):
263257
face_areas = uxgrid.face_areas.values
264258

265259
bands = np.asarray(bands, dtype=float)
266-
if bands.ndim != 1 or bands.size < 2:
267-
raise ValueError("bands must be 1D with at least two edges")
268-
269260
nb = bands.size - 1
270-
271-
# Initialize result array
272-
shape = list(uxda.shape)
273-
shape[face_axis] = nb
274-
if isinstance(uxda.data, da.Array):
275-
result = da.zeros(shape, dtype=uxda.dtype)
276-
else:
277-
result = np.zeros(shape, dtype=uxda.dtype)
261+
W = np.zeros((uxgrid.n_face, nb), dtype=float)
278262

279263
for bi in range(nb):
280264
lat0 = float(np.clip(bands[bi], -90.0, 90.0))
281265
lat1 = float(np.clip(bands[bi + 1], -90.0, 90.0))
282-
283-
# Ensure lat0 <= lat1
284266
if lat0 > lat1:
285267
lat0, lat1 = lat1, lat0
286268

287269
z0 = np.sin(np.deg2rad(lat0))
288270
z1 = np.sin(np.deg2rad(lat1))
289271
zmin, zmax = (z0, z1) if z0 <= z1 else (z1, z0)
290272

291-
# Step 1: Get fully contained faces
292-
fully_contained_faces = uxgrid.get_faces_between_latitudes((lat0, lat1))
293-
294-
# Step 2: Get all overlapping faces (including partial)
273+
fully_contained = uxgrid.get_faces_between_latitudes((lat0, lat1))
295274
mask = ~((face_bounds_lat[:, 1] < lat0) | (face_bounds_lat[:, 0] > lat1))
296-
all_overlapping_faces = np.nonzero(mask)[0]
275+
all_overlapping = np.nonzero(mask)[0]
297276

298-
if all_overlapping_faces.size == 0:
299-
# No faces in this band
300-
idx = [slice(None)] * result.ndim
301-
idx[face_axis] = bi
302-
result[tuple(idx)] = np.nan
277+
if all_overlapping.size == 0:
303278
continue
304279

305-
# Step 3: Partition faces into fully contained vs partially overlapping
306-
is_fully_contained = np.isin(all_overlapping_faces, fully_contained_faces)
307-
partially_overlapping_faces = all_overlapping_faces[~is_fully_contained]
308-
309-
# Step 4: Compute weights
310-
all_weights = np.zeros(all_overlapping_faces.size, dtype=float)
311-
312-
# For fully contained faces, use their full area
313-
if fully_contained_faces.size > 0:
314-
fully_contained_indices = np.where(is_fully_contained)[0]
315-
all_weights[fully_contained_indices] = face_areas[fully_contained_faces]
316-
317-
# For partially overlapping faces, compute fractional area
318-
if partially_overlapping_faces.size > 0:
319-
partial_indices = np.where(~is_fully_contained)[0]
320-
for i, face_idx in enumerate(partially_overlapping_faces):
321-
nedge = n_nodes_per_face[face_idx]
322-
face_edges = faces_edge_nodes_xyz[face_idx, :nedge]
323-
overlap_area = _compute_band_overlap_area(face_edges, zmin, zmax)
324-
all_weights[partial_indices[i]] = overlap_area
325-
326-
# Step 5: Compute weighted average
327-
data_slice = uxda.isel(n_face=all_overlapping_faces, ignore_grid=True).data
328-
total_weight = all_weights.sum()
329-
330-
if total_weight == 0.0:
331-
weighted = np.nan * data_slice[..., 0]
332-
else:
333-
w_shape = [1] * data_slice.ndim
334-
w_shape[face_axis] = all_weights.size
335-
w_reshaped = all_weights.reshape(w_shape)
336-
weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total_weight
280+
is_fully_contained = np.isin(all_overlapping, fully_contained)
281+
282+
fc = all_overlapping[is_fully_contained]
283+
W[fc, bi] = face_areas[fc]
284+
285+
for f in all_overlapping[~is_fully_contained]:
286+
nedge = n_nodes_per_face[f]
287+
W[f, bi] = _compute_band_overlap_area(
288+
faces_edge_nodes_xyz[f, :nedge], zmin, zmax
289+
)
290+
291+
return W
292+
293+
294+
def _compute_conservative_zonal_mean_bands(uxda, bands):
295+
"""Compute conservative zonal mean over latitude bands.
296+
297+
Parameters
298+
----------
299+
uxda : UxDataArray
300+
bands : array-like
301+
Latitude band edges in degrees
302+
303+
Returns
304+
-------
305+
result : array
306+
Zonal means for each band, with n_face axis replaced by n_bands
307+
"""
308+
import dask.array as da
309+
310+
bands = np.asarray(bands, dtype=float)
311+
if bands.ndim != 1 or bands.size < 2:
312+
raise ValueError("bands must be 1D with at least two edges")
313+
314+
W = _compute_face_band_weights(uxda.uxgrid, bands) # (n_face, n_bands)
315+
nb = W.shape[1]
316+
face_axis = uxda.get_axis_num("n_face")
317+
318+
shape = list(uxda.shape)
319+
shape[face_axis] = nb
320+
if isinstance(uxda.data, da.Array):
321+
result = da.full(shape, np.nan, dtype=float)
322+
else:
323+
result = np.full(shape, np.nan, dtype=float)
324+
325+
for bi in range(nb):
326+
overlapping = np.nonzero(W[:, bi] > 0)[0]
327+
if overlapping.size == 0:
328+
continue
329+
330+
w = W[overlapping, bi]
331+
total = w.sum()
332+
if total == 0.0:
333+
continue
334+
335+
data_slice = uxda.isel(n_face=overlapping, ignore_grid=True).data
336+
w_shape = [1] * data_slice.ndim
337+
w_shape[face_axis] = w.size
338+
weighted = (data_slice * w.reshape(w_shape)).sum(axis=face_axis) / total
337339

338340
idx = [slice(None)] * result.ndim
339341
idx[face_axis] = bi
340342
result[tuple(idx)] = weighted
341343

342344
return result
345+
346+
347+
def _compute_zonal_anomaly(uxda, bands, conservative=False):
348+
"""Compute zonal anomaly: each face value minus the mean of its latitude band.
349+
350+
Parameters
351+
----------
352+
uxda : UxDataArray
353+
bands : array-like
354+
Latitude band edges in degrees
355+
conservative : bool
356+
If True, uses area-weighted band means and blends across bands for
357+
faces that straddle a boundary, reusing the same weight matrix as
358+
zonal_mean so geometry is computed only once.
359+
If False, assigns each face to a band by centroid latitude.
360+
361+
Returns
362+
-------
363+
ndarray
364+
Same shape as uxda, with the per-face band mean subtracted.
365+
"""
366+
bands = np.asarray(bands, dtype=float)
367+
face_axis = uxda.get_axis_num("n_face")
368+
n_face = uxda.uxgrid.n_face
369+
nb = bands.size - 1
370+
371+
if conservative:
372+
# Single geometry pass shared with zonal_mean
373+
W = _compute_face_band_weights(uxda.uxgrid, bands) # (n_face, n_bands)
374+
375+
# Band means
376+
band_means = np.full(nb, np.nan)
377+
for bi in range(nb):
378+
overlapping = np.nonzero(W[:, bi] > 0)[0]
379+
if overlapping.size == 0:
380+
continue
381+
w = W[overlapping, bi]
382+
total = w.sum()
383+
if total > 0:
384+
vals = uxda.isel(n_face=overlapping, ignore_grid=True).values
385+
band_means[bi] = (w * vals).sum() / total
386+
387+
# Map band means back to faces; straddling faces get area-weighted blend
388+
face_totals = W.sum(axis=1)
389+
valid = face_totals > 0
390+
face_means = np.where(
391+
valid,
392+
np.where(
393+
valid,
394+
(
395+
W * np.where(np.isnan(band_means), 0.0, band_means)[np.newaxis, :]
396+
).sum(axis=1)
397+
/ np.where(valid, face_totals, 1.0),
398+
np.nan,
399+
),
400+
np.nan,
401+
)
402+
else:
403+
# Centroid-based: fast, no intersection geometry needed
404+
face_lats = uxda.uxgrid.face_lat.values
405+
band_indices = np.clip(np.digitize(face_lats, bands) - 1, 0, nb - 1)
406+
407+
band_means = np.full(nb, np.nan)
408+
for bi in range(nb):
409+
mask = band_indices == bi
410+
if mask.any():
411+
band_means[bi] = float(
412+
uxda.isel(
413+
n_face=np.nonzero(mask)[0], ignore_grid=True
414+
).values.mean()
415+
)
416+
417+
face_means = band_means[band_indices]
418+
419+
# Broadcast face_means to match uxda shape (face axis may not be last)
420+
shape = [1] * uxda.ndim
421+
shape[face_axis] = n_face
422+
return uxda.values - face_means.reshape(shape)

0 commit comments

Comments
 (0)