99# distributed with this code, or at
1010# https://raw.githubusercontent.com/MPAS-Dev/MPAS-Analysis/main/LICENSE
1111import xarray as xr
12- import numpy as np
13- import scipy .sparse
14- import scipy .sparse .linalg
12+
13+ from mpas_tools .ocean .barotropic_streamfunction import (
14+ compute_barotropic_streamfunction ,
15+ shift_barotropic_streamfunction
16+ )
1517
1618from mpas_analysis .shared import AnalysisTask
1719from mpas_analysis .shared .climatology import RemapMpasClimatologySubtask
1820from mpas_analysis .shared .plot import PlotClimatologyMapSubtask
19- from mpas_analysis .ocean .utility import compute_zmid
2021from mpas_analysis .shared .projection import comparison_grid_option_suffixes
2122
2223
@@ -280,8 +281,21 @@ def customize_masked_climatology(self, climatology, season):
280281 'edgesOnVertex' , 'dcEdge' , 'dvEdge' , 'bottomDepth' ,
281282 'maxLevelCell' , 'latVertex' , 'areaTriangle' ,]]
282283 ds_mesh .load ()
283- bsf_vertex = self ._compute_barotropic_streamfunction_vertex (
284- ds_mesh , climatology )
284+
285+ cells_on_vertex = ds_mesh .cellsOnVertex - 1
286+ lat_vertex = ds_mesh .latVertex
287+ bsf_vertex = compute_barotropic_streamfunction (
288+ ds_mesh = ds_mesh , ds = climatology , min_depth = self .min_depth ,
289+ max_depth = self .max_depth , include_bolus = self .include_bolus ,
290+ include_submesoscale = self .include_submesoscale )
291+
292+ lat_range = config .getexpression (
293+ self .taskName , 'latitudeRangeForZeroBSF' )
294+
295+ bsf_vertex = shift_barotropic_streamfunction (
296+ bsf_vertex = bsf_vertex , lat_range = lat_range ,
297+ cells_on_vertex = cells_on_vertex , lat_vertex = lat_vertex )
298+
285299 logger .info ('bsf on vertices computed.' )
286300
287301 climatology ['barotropicStreamfunction' ] = bsf_vertex
@@ -305,234 +319,12 @@ def customize_masked_climatology(self, climatology, season):
305319
306320 lat_range = config .getexpression (
307321 config_section_name , 'latitudeRangeForZeroBSF' )
308- climatology [mpas_field_name ] = _shift_bsf (
309- bsf_vertex , lat_range , ds_mesh . cellsOnVertex - 1 ,
310- ds_mesh . latVertex )
322+ climatology [mpas_field_name ] = shift_barotropic_streamfunction (
323+ bsf_vertex = bsf_vertex , lat_range = lat_range ,
324+ cells_on_vertex = cells_on_vertex , lat_vertex = lat_vertex )
311325 climatology [mpas_field_name ].attrs ['units' ] = 'Sv'
312326 climatology [mpas_field_name ].attrs ['description' ] = \
313327 f'barotropic streamfunction at vertices, offset for ' \
314328 f'{ grid_suffix } plots'
315329
316330 return climatology
317-
318- def _compute_vert_integ_velocity (self , ds_mesh , ds ):
319-
320- cells_on_edge = ds_mesh .cellsOnEdge - 1
321- inner_edges = np .logical_and (cells_on_edge .isel (TWO = 0 ) >= 0 ,
322- cells_on_edge .isel (TWO = 1 ) >= 0 )
323-
324- # convert from boolean mask to indices
325- inner_edges = np .flatnonzero (inner_edges .values )
326-
327- cell0 = cells_on_edge .isel (nEdges = inner_edges , TWO = 0 )
328- cell1 = cells_on_edge .isel (nEdges = inner_edges , TWO = 1 )
329- n_vert_levels = ds .sizes ['nVertLevels' ]
330-
331- layer_thickness = ds .timeMonthly_avg_layerThickness
332- max_level_cell = ds_mesh .maxLevelCell - 1
333-
334- vert_index = xr .DataArray .from_dict (
335- {'dims' : ('nVertLevels' ,), 'data' : np .arange (n_vert_levels )})
336- z_mid = compute_zmid (ds_mesh .bottomDepth , max_level_cell ,
337- layer_thickness )
338- z_mid_edge = 0.5 * (z_mid .isel (nCells = cell0 ) +
339- z_mid .isel (nCells = cell1 ))
340-
341- normal_velocity = ds .timeMonthly_avg_normalVelocity
342- if self .include_bolus :
343- normal_velocity += ds .timeMonthly_avg_normalGMBolusVelocity
344- if self .include_submesoscale :
345- normal_velocity += ds .timeMonthly_avg_normalMLEvelocity
346- normal_velocity = normal_velocity .isel (nEdges = inner_edges )
347-
348- layer_thickness_edge = 0.5 * (layer_thickness .isel (nCells = cell0 ) +
349- layer_thickness .isel (nCells = cell1 ))
350- mask_bottom = (vert_index <= max_level_cell ).T
351- mask_bottom_edge = np .logical_and (mask_bottom .isel (nCells = cell0 ),
352- mask_bottom .isel (nCells = cell1 ))
353- masks = [mask_bottom_edge ,
354- z_mid_edge <= self .min_depth ,
355- z_mid_edge >= self .max_depth ]
356- for mask in masks :
357- normal_velocity = normal_velocity .where (mask )
358- layer_thickness_edge = layer_thickness_edge .where (mask )
359-
360- vert_integ_velocity = np .zeros (ds_mesh .dims ['nEdges' ], dtype = float )
361- inner_vert_integ_vel = (
362- (layer_thickness_edge * normal_velocity ).sum (dim = 'nVertLevels' ))
363- vert_integ_velocity [inner_edges ] = inner_vert_integ_vel .values
364-
365- vert_integ_velocity = xr .DataArray (vert_integ_velocity ,
366- dims = ('nEdges' ,))
367-
368- return vert_integ_velocity
369-
370- def _compute_edge_sign_on_vertex (self , ds_mesh ):
371- edges_on_vertex = ds_mesh .edgesOnVertex - 1
372- vertices_on_edge = ds_mesh .verticesOnEdge - 1
373-
374- nvertices = ds_mesh .sizes ['nVertices' ]
375- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
376-
377- edge_sign_on_vertex = np .zeros ((nvertices , vertex_degree ), dtype = int )
378- vertices = np .arange (nvertices )
379- for iedge in range (vertex_degree ):
380- eov = edges_on_vertex .isel (vertexDegree = iedge )
381- valid_edge = eov >= 0
382-
383- v0_on_edge = vertices_on_edge .isel (nEdges = eov , TWO = 0 )
384- v1_on_edge = vertices_on_edge .isel (nEdges = eov , TWO = 1 )
385- valid_edge = np .logical_and (valid_edge , v0_on_edge >= 0 )
386- valid_edge = np .logical_and (valid_edge , v1_on_edge >= 0 )
387-
388- mask = np .logical_and (valid_edge , v0_on_edge == vertices )
389- edge_sign_on_vertex [mask , iedge ] = - 1
390-
391- mask = np .logical_and (valid_edge , v1_on_edge == vertices )
392- edge_sign_on_vertex [mask , iedge ] = 1
393-
394- return edge_sign_on_vertex
395-
396- def _compute_vert_integ_vorticity (self , ds_mesh , vert_integ_velocity ,
397- edge_sign_on_vertex ):
398-
399- area_vertex = ds_mesh .areaTriangle
400- dc_edge = ds_mesh .dcEdge
401- edges_on_vertex = ds_mesh .edgesOnVertex - 1
402-
403- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
404-
405- vert_integ_vorticity = xr .zeros_like (ds_mesh .latVertex )
406- for iedge in range (vertex_degree ):
407- eov = edges_on_vertex .isel (vertexDegree = iedge )
408- edge_sign = edge_sign_on_vertex [:, iedge ]
409- dc = dc_edge .isel (nEdges = eov )
410- vert_integ_vel = vert_integ_velocity .isel (nEdges = eov )
411- vert_integ_vorticity += (
412- dc / area_vertex * edge_sign * vert_integ_vel )
413-
414- return vert_integ_vorticity
415-
416- def _compute_barotropic_streamfunction_vertex (self , ds_mesh , ds ):
417- edge_sign_on_vertex = self ._compute_edge_sign_on_vertex (ds_mesh )
418- vert_integ_velocity = self ._compute_vert_integ_velocity (ds_mesh , ds )
419- vert_integ_vorticity = self ._compute_vert_integ_vorticity (
420- ds_mesh , vert_integ_velocity , edge_sign_on_vertex )
421- self .logger .info ('vertically integrated vorticity computed.' )
422-
423- config = self .config
424- lat_range = config .getexpression (
425- 'climatologyMapBSF' , 'latitudeRangeForZeroBSF' )
426-
427- nvertices = ds_mesh .sizes ['nVertices' ]
428- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
429-
430- cells_on_vertex = ds_mesh .cellsOnVertex - 1
431- edges_on_vertex = ds_mesh .edgesOnVertex - 1
432- vertices_on_edge = ds_mesh .verticesOnEdge - 1
433- area_vertex = ds_mesh .areaTriangle
434- dc_edge = ds_mesh .dcEdge
435- dv_edge = ds_mesh .dvEdge
436-
437- # one equation involving vertex degree + 1 vertices for each vertex
438- # plus 2 entries for the boundary condition and Lagrange multiplier
439- ndata = (vertex_degree + 1 ) * nvertices + 2
440- indices = np .zeros ((2 , ndata ), dtype = int )
441- data = np .zeros (ndata , dtype = float )
442-
443- # the laplacian on the dual mesh of the streamfunction is the
444- # vertically integrated vorticity
445- vertices = np .arange (nvertices , dtype = int )
446- idata = (vertex_degree + 1 ) * vertices + 1
447- indices [0 , idata ] = vertices
448- indices [1 , idata ] = vertices
449- for iedge in range (vertex_degree ):
450- eov = edges_on_vertex .isel (vertexDegree = iedge )
451- dc = dc_edge .isel (nEdges = eov )
452- dv = dv_edge .isel (nEdges = eov )
453-
454- v0 = vertices_on_edge .isel (nEdges = eov , TWO = 0 )
455- v1 = vertices_on_edge .isel (nEdges = eov , TWO = 1 )
456-
457- edge_sign = edge_sign_on_vertex [:, iedge ]
458-
459- mask = v0 == vertices
460- # the difference is v1 - v0, so we want to subtract this vertex
461- # when it is v0 and add it when it is v1
462- this_vert_sign = np .where (mask , - 1. , 1. )
463- # the other vertex is obviously whichever one this is not
464- other_vert_index = np .where (mask , v1 , v0 )
465- # if there are invalid vertices, we need to make sure we don't
466- # index out of bounds. The edge_sign will mask these out
467- other_vert_index = np .where (other_vert_index >= 0 ,
468- other_vert_index , 0 )
469-
470- idata_other = idata + iedge + 1
471-
472- indices [0 , idata ] = vertices
473- indices [1 , idata ] = vertices
474- indices [0 , idata_other ] = vertices
475- indices [1 , idata_other ] = other_vert_index
476-
477- this_data = this_vert_sign * edge_sign * dc / (dv * area_vertex )
478- data [idata ] += this_data
479- data [idata_other ] = - this_data
480-
481- # Now, the boundary condition: To begin with, we set the BSF at the
482- # frist vertext to zero
483- indices [0 , - 2 ] = nvertices
484- indices [1 , - 2 ] = 0
485- data [- 2 ] = 1.
486-
487- # The same in the final column
488- indices [0 , - 1 ] = 0
489- indices [1 , - 1 ] = nvertices
490- data [- 1 ] = 1.
491-
492- # one extra spot for the Lagrange multiplier
493- rhs = np .zeros (nvertices + 1 , dtype = float )
494-
495- rhs [0 :- 1 ] = vert_integ_vorticity .values
496-
497- matrix = scipy .sparse .csr_matrix (
498- (data , indices ),
499- shape = (nvertices + 1 , nvertices + 1 ))
500-
501- solution = scipy .sparse .linalg .spsolve (matrix , rhs )
502-
503- # drop the Lagrange multiplier and convert to Sv with the desired sign
504- # convention
505- bsf_vertex = xr .DataArray (- 1e-6 * solution [0 :- 1 ],
506- dims = ('nVertices' ,))
507-
508- bsf_vertex = _shift_bsf (bsf_vertex , lat_range , cells_on_vertex ,
509- ds_mesh .latVertex )
510-
511- return bsf_vertex
512-
513-
514- def _shift_bsf (bsf_vertex , lat_range , cells_on_vertex , lat_vertex ):
515- """
516- Shift the barotropic streamfunction to be zero at the boundary over
517- the given latitude range
518- """
519- is_boundary_cov = cells_on_vertex == - 1
520- boundary_vertices = is_boundary_cov .sum (dim = 'vertexDegree' ) > 0
521-
522- boundary_vertices = np .logical_and (
523- boundary_vertices ,
524- lat_vertex >= np .deg2rad (lat_range [0 ])
525- )
526- boundary_vertices = np .logical_and (
527- boundary_vertices ,
528- lat_vertex <= np .deg2rad (lat_range [1 ])
529- )
530-
531- # convert from boolean mask to indices
532- boundary_vertices = np .flatnonzero (boundary_vertices .values )
533-
534- mean_boundary_bsf = bsf_vertex .isel (nVertices = boundary_vertices ).mean ()
535-
536- bsf_shifted = bsf_vertex - mean_boundary_bsf
537-
538- return bsf_shifted
0 commit comments