@@ -225,31 +225,25 @@ def _compute_band_overlap_area(
225225 return area
226226
227227
228- def _compute_conservative_zonal_mean_bands (uxda , bands ):
229- """
230- Compute conservative zonal mean over latitude bands.
228+ def _compute_face_band_weights (uxgrid , bands ):
229+ """Compute overlap area between every face and every latitude band.
231230
232- Uses get_faces_between_latitudes to optimize computation by avoiding
233- overlap area calculations for fully contained faces .
231+ Shared geometry kernel used by both zonal_mean and zonal_anomaly so the
232+ expensive intersection calculations are never duplicated .
234233
235234 Parameters
236235 ----------
237- uxda : UxDataArray
238- The data array to compute zonal means for
236+ uxgrid : Grid
239237 bands : array-like
240- Latitude band edges in degrees
238+ Latitude band edges in degrees, shape (n_bands + 1,)
241239
242240 Returns
243241 -------
244- result : array
245- Zonal means for each band
242+ W : ndarray, shape (n_face, n_bands)
243+ W[f, b] is the overlap area between face f and band b.
244+ Fully-contained faces carry their full face area; partially-overlapping
245+ faces carry the exact intersection area.
246246 """
247- import dask .array as da
248-
249- uxgrid = uxda .uxgrid
250- face_axis = uxda .get_axis_num ("n_face" )
251-
252- # Pre-compute face properties
253247 faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array (
254248 uxgrid .face_node_connectivity .values ,
255249 uxgrid .n_face ,
@@ -263,80 +257,166 @@ def _compute_conservative_zonal_mean_bands(uxda, bands):
263257 face_areas = uxgrid .face_areas .values
264258
265259 bands = np .asarray (bands , dtype = float )
266- if bands .ndim != 1 or bands .size < 2 :
267- raise ValueError ("bands must be 1D with at least two edges" )
268-
269260 nb = bands .size - 1
270-
271- # Initialize result array
272- shape = list (uxda .shape )
273- shape [face_axis ] = nb
274- if isinstance (uxda .data , da .Array ):
275- result = da .zeros (shape , dtype = uxda .dtype )
276- else :
277- result = np .zeros (shape , dtype = uxda .dtype )
261+ W = np .zeros ((uxgrid .n_face , nb ), dtype = float )
278262
279263 for bi in range (nb ):
280264 lat0 = float (np .clip (bands [bi ], - 90.0 , 90.0 ))
281265 lat1 = float (np .clip (bands [bi + 1 ], - 90.0 , 90.0 ))
282-
283- # Ensure lat0 <= lat1
284266 if lat0 > lat1 :
285267 lat0 , lat1 = lat1 , lat0
286268
287269 z0 = np .sin (np .deg2rad (lat0 ))
288270 z1 = np .sin (np .deg2rad (lat1 ))
289271 zmin , zmax = (z0 , z1 ) if z0 <= z1 else (z1 , z0 )
290272
291- # Step 1: Get fully contained faces
292- fully_contained_faces = uxgrid .get_faces_between_latitudes ((lat0 , lat1 ))
293-
294- # Step 2: Get all overlapping faces (including partial)
273+ fully_contained = uxgrid .get_faces_between_latitudes ((lat0 , lat1 ))
295274 mask = ~ ((face_bounds_lat [:, 1 ] < lat0 ) | (face_bounds_lat [:, 0 ] > lat1 ))
296- all_overlapping_faces = np .nonzero (mask )[0 ]
275+ all_overlapping = np .nonzero (mask )[0 ]
297276
298- if all_overlapping_faces .size == 0 :
299- # No faces in this band
300- idx = [slice (None )] * result .ndim
301- idx [face_axis ] = bi
302- result [tuple (idx )] = np .nan
277+ if all_overlapping .size == 0 :
303278 continue
304279
305- # Step 3: Partition faces into fully contained vs partially overlapping
306- is_fully_contained = np .isin (all_overlapping_faces , fully_contained_faces )
307- partially_overlapping_faces = all_overlapping_faces [~ is_fully_contained ]
308-
309- # Step 4: Compute weights
310- all_weights = np .zeros (all_overlapping_faces .size , dtype = float )
311-
312- # For fully contained faces, use their full area
313- if fully_contained_faces .size > 0 :
314- fully_contained_indices = np .where (is_fully_contained )[0 ]
315- all_weights [fully_contained_indices ] = face_areas [fully_contained_faces ]
316-
317- # For partially overlapping faces, compute fractional area
318- if partially_overlapping_faces .size > 0 :
319- partial_indices = np .where (~ is_fully_contained )[0 ]
320- for i , face_idx in enumerate (partially_overlapping_faces ):
321- nedge = n_nodes_per_face [face_idx ]
322- face_edges = faces_edge_nodes_xyz [face_idx , :nedge ]
323- overlap_area = _compute_band_overlap_area (face_edges , zmin , zmax )
324- all_weights [partial_indices [i ]] = overlap_area
325-
326- # Step 5: Compute weighted average
327- data_slice = uxda .isel (n_face = all_overlapping_faces , ignore_grid = True ).data
328- total_weight = all_weights .sum ()
329-
330- if total_weight == 0.0 :
331- weighted = np .nan * data_slice [..., 0 ]
332- else :
333- w_shape = [1 ] * data_slice .ndim
334- w_shape [face_axis ] = all_weights .size
335- w_reshaped = all_weights .reshape (w_shape )
336- weighted = (data_slice * w_reshaped ).sum (axis = face_axis ) / total_weight
280+ is_fully_contained = np .isin (all_overlapping , fully_contained )
281+
282+ fc = all_overlapping [is_fully_contained ]
283+ W [fc , bi ] = face_areas [fc ]
284+
285+ for f in all_overlapping [~ is_fully_contained ]:
286+ nedge = n_nodes_per_face [f ]
287+ W [f , bi ] = _compute_band_overlap_area (
288+ faces_edge_nodes_xyz [f , :nedge ], zmin , zmax
289+ )
290+
291+ return W
292+
293+
294+ def _compute_conservative_zonal_mean_bands (uxda , bands ):
295+ """Compute conservative zonal mean over latitude bands.
296+
297+ Parameters
298+ ----------
299+ uxda : UxDataArray
300+ bands : array-like
301+ Latitude band edges in degrees
302+
303+ Returns
304+ -------
305+ result : array
306+ Zonal means for each band, with n_face axis replaced by n_bands
307+ """
308+ import dask .array as da
309+
310+ bands = np .asarray (bands , dtype = float )
311+ if bands .ndim != 1 or bands .size < 2 :
312+ raise ValueError ("bands must be 1D with at least two edges" )
313+
314+ W = _compute_face_band_weights (uxda .uxgrid , bands ) # (n_face, n_bands)
315+ nb = W .shape [1 ]
316+ face_axis = uxda .get_axis_num ("n_face" )
317+
318+ shape = list (uxda .shape )
319+ shape [face_axis ] = nb
320+ if isinstance (uxda .data , da .Array ):
321+ result = da .full (shape , np .nan , dtype = float )
322+ else :
323+ result = np .full (shape , np .nan , dtype = float )
324+
325+ for bi in range (nb ):
326+ overlapping = np .nonzero (W [:, bi ] > 0 )[0 ]
327+ if overlapping .size == 0 :
328+ continue
329+
330+ w = W [overlapping , bi ]
331+ total = w .sum ()
332+ if total == 0.0 :
333+ continue
334+
335+ data_slice = uxda .isel (n_face = overlapping , ignore_grid = True ).data
336+ w_shape = [1 ] * data_slice .ndim
337+ w_shape [face_axis ] = w .size
338+ weighted = (data_slice * w .reshape (w_shape )).sum (axis = face_axis ) / total
337339
338340 idx = [slice (None )] * result .ndim
339341 idx [face_axis ] = bi
340342 result [tuple (idx )] = weighted
341343
342344 return result
345+
346+
347+ def _compute_zonal_anomaly (uxda , bands , conservative = False ):
348+ """Compute zonal anomaly: each face value minus the mean of its latitude band.
349+
350+ Parameters
351+ ----------
352+ uxda : UxDataArray
353+ bands : array-like
354+ Latitude band edges in degrees
355+ conservative : bool
356+ If True, uses area-weighted band means and blends across bands for
357+ faces that straddle a boundary, reusing the same weight matrix as
358+ zonal_mean so geometry is computed only once.
359+ If False, assigns each face to a band by centroid latitude.
360+
361+ Returns
362+ -------
363+ ndarray
364+ Same shape as uxda, with the per-face band mean subtracted.
365+ """
366+ bands = np .asarray (bands , dtype = float )
367+ face_axis = uxda .get_axis_num ("n_face" )
368+ n_face = uxda .uxgrid .n_face
369+ nb = bands .size - 1
370+
371+ if conservative :
372+ # Single geometry pass shared with zonal_mean
373+ W = _compute_face_band_weights (uxda .uxgrid , bands ) # (n_face, n_bands)
374+
375+ # Band means
376+ band_means = np .full (nb , np .nan )
377+ for bi in range (nb ):
378+ overlapping = np .nonzero (W [:, bi ] > 0 )[0 ]
379+ if overlapping .size == 0 :
380+ continue
381+ w = W [overlapping , bi ]
382+ total = w .sum ()
383+ if total > 0 :
384+ vals = uxda .isel (n_face = overlapping , ignore_grid = True ).values
385+ band_means [bi ] = (w * vals ).sum () / total
386+
387+ # Map band means back to faces; straddling faces get area-weighted blend
388+ face_totals = W .sum (axis = 1 )
389+ valid = face_totals > 0
390+ face_means = np .where (
391+ valid ,
392+ np .where (
393+ valid ,
394+ (
395+ W * np .where (np .isnan (band_means ), 0.0 , band_means )[np .newaxis , :]
396+ ).sum (axis = 1 )
397+ / np .where (valid , face_totals , 1.0 ),
398+ np .nan ,
399+ ),
400+ np .nan ,
401+ )
402+ else :
403+ # Centroid-based: fast, no intersection geometry needed
404+ face_lats = uxda .uxgrid .face_lat .values
405+ band_indices = np .clip (np .digitize (face_lats , bands ) - 1 , 0 , nb - 1 )
406+
407+ band_means = np .full (nb , np .nan )
408+ for bi in range (nb ):
409+ mask = band_indices == bi
410+ if mask .any ():
411+ band_means [bi ] = float (
412+ uxda .isel (
413+ n_face = np .nonzero (mask )[0 ], ignore_grid = True
414+ ).values .mean ()
415+ )
416+
417+ face_means = band_means [band_indices ]
418+
419+ # Broadcast face_means to match uxda shape (face axis may not be last)
420+ shape = [1 ] * uxda .ndim
421+ shape [face_axis ] = n_face
422+ return uxda .values - face_means .reshape (shape )
0 commit comments