1111import os
1212
1313import xarray as xr
14- import numpy as np
15- import scipy .sparse
16- import scipy .sparse .linalg
14+
15+ from mpas_tools .ocean .barotropic_streamfunction import (
16+ compute_barotropic_streamfunction ,
17+ shift_barotropic_streamfunction
18+ )
1719
1820from mpas_analysis .shared import AnalysisTask
1921from mpas_analysis .shared .climatology import RemapMpasClimatologySubtask
2022from mpas_analysis .shared .plot import PlotClimatologyMapSubtask
21- from mpas_analysis .ocean .utility import compute_zmid
2223from mpas_analysis .shared .projection import comparison_grid_option_suffixes
2324
2425
@@ -315,8 +316,30 @@ def customize_masked_climatology(self, climatology, season):
315316 'edgesOnVertex' , 'dcEdge' , 'dvEdge' , 'bottomDepth' ,
316317 'maxLevelCell' , 'latVertex' , 'areaTriangle' ,]]
317318 ds_mesh .load ()
318- bsf_vertex = self ._compute_barotropic_streamfunction_vertex (
319- ds_mesh , climatology )
319+
320+ cells_on_vertex = ds_mesh .cellsOnVertex - 1
321+ lat_vertex = ds_mesh .latVertex
322+ bsf_vertex = compute_barotropic_streamfunction (
323+ ds_mesh = ds_mesh ,
324+ ds = climatology ,
325+ min_depth = self .min_depth ,
326+ max_depth = self .max_depth ,
327+ include_bolus = self .include_bolus ,
328+ include_submesoscale = self .include_submesoscale ,
329+ logger = logger ,
330+ )
331+
332+ lat_range = config .getexpression (
333+ self .taskName , 'latitudeRangeForZeroBSF' )
334+
335+ bsf_vertex = shift_barotropic_streamfunction (
336+ bsf_vertex = bsf_vertex ,
337+ lat_range = lat_range ,
338+ cells_on_vertex = cells_on_vertex ,
339+ lat_vertex = lat_vertex ,
340+ logger = logger ,
341+ )
342+
320343 logger .info ('bsf on vertices computed.' )
321344
322345 climatology ['barotropicStreamfunction' ] = bsf_vertex
@@ -340,234 +363,16 @@ def customize_masked_climatology(self, climatology, season):
340363
341364 lat_range = config .getexpression (
342365 config_section_name , 'latitudeRangeForZeroBSF' )
343- climatology [mpas_field_name ] = _shift_bsf (
344- bsf_vertex , lat_range , ds_mesh .cellsOnVertex - 1 ,
345- ds_mesh .latVertex )
366+ climatology [mpas_field_name ] = shift_barotropic_streamfunction (
367+ bsf_vertex = bsf_vertex ,
368+ lat_range = lat_range ,
369+ cells_on_vertex = cells_on_vertex ,
370+ lat_vertex = lat_vertex ,
371+ logger = logger ,
372+ )
346373 climatology [mpas_field_name ].attrs ['units' ] = 'Sv'
347374 climatology [mpas_field_name ].attrs ['description' ] = \
348375 f'barotropic streamfunction at vertices, offset for ' \
349376 f'{ grid_suffix } plots'
350377
351378 return climatology
352-
353- def _compute_vert_integ_velocity (self , ds_mesh , ds ):
354-
355- cells_on_edge = ds_mesh .cellsOnEdge - 1
356- inner_edges = np .logical_and (cells_on_edge .isel (TWO = 0 ) >= 0 ,
357- cells_on_edge .isel (TWO = 1 ) >= 0 )
358-
359- # convert from boolean mask to indices
360- inner_edges = np .flatnonzero (inner_edges .values )
361-
362- cell0 = cells_on_edge .isel (nEdges = inner_edges , TWO = 0 )
363- cell1 = cells_on_edge .isel (nEdges = inner_edges , TWO = 1 )
364- n_vert_levels = ds .sizes ['nVertLevels' ]
365-
366- layer_thickness = ds .timeMonthly_avg_layerThickness
367- max_level_cell = ds_mesh .maxLevelCell - 1
368-
369- vert_index = xr .DataArray .from_dict (
370- {'dims' : ('nVertLevels' ,), 'data' : np .arange (n_vert_levels )})
371- z_mid = compute_zmid (ds_mesh .bottomDepth , max_level_cell ,
372- layer_thickness )
373- z_mid_edge = 0.5 * (z_mid .isel (nCells = cell0 ) +
374- z_mid .isel (nCells = cell1 ))
375-
376- normal_velocity = ds .timeMonthly_avg_normalVelocity
377- if self .include_bolus :
378- normal_velocity += ds .timeMonthly_avg_normalGMBolusVelocity
379- if self .include_submesoscale :
380- normal_velocity += ds .timeMonthly_avg_normalMLEvelocity
381- normal_velocity = normal_velocity .isel (nEdges = inner_edges )
382-
383- layer_thickness_edge = 0.5 * (layer_thickness .isel (nCells = cell0 ) +
384- layer_thickness .isel (nCells = cell1 ))
385- mask_bottom = (vert_index <= max_level_cell ).T
386- mask_bottom_edge = np .logical_and (mask_bottom .isel (nCells = cell0 ),
387- mask_bottom .isel (nCells = cell1 ))
388- masks = [mask_bottom_edge ,
389- z_mid_edge <= self .min_depth ,
390- z_mid_edge >= self .max_depth ]
391- for mask in masks :
392- normal_velocity = normal_velocity .where (mask )
393- layer_thickness_edge = layer_thickness_edge .where (mask )
394-
395- vert_integ_velocity = np .zeros (ds_mesh .dims ['nEdges' ], dtype = float )
396- inner_vert_integ_vel = (
397- (layer_thickness_edge * normal_velocity ).sum (dim = 'nVertLevels' ))
398- vert_integ_velocity [inner_edges ] = inner_vert_integ_vel .values
399-
400- vert_integ_velocity = xr .DataArray (vert_integ_velocity ,
401- dims = ('nEdges' ,))
402-
403- return vert_integ_velocity
404-
405- def _compute_edge_sign_on_vertex (self , ds_mesh ):
406- edges_on_vertex = ds_mesh .edgesOnVertex - 1
407- vertices_on_edge = ds_mesh .verticesOnEdge - 1
408-
409- nvertices = ds_mesh .sizes ['nVertices' ]
410- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
411-
412- edge_sign_on_vertex = np .zeros ((nvertices , vertex_degree ), dtype = int )
413- vertices = np .arange (nvertices )
414- for iedge in range (vertex_degree ):
415- eov = edges_on_vertex .isel (vertexDegree = iedge )
416- valid_edge = eov >= 0
417-
418- v0_on_edge = vertices_on_edge .isel (nEdges = eov , TWO = 0 )
419- v1_on_edge = vertices_on_edge .isel (nEdges = eov , TWO = 1 )
420- valid_edge = np .logical_and (valid_edge , v0_on_edge >= 0 )
421- valid_edge = np .logical_and (valid_edge , v1_on_edge >= 0 )
422-
423- mask = np .logical_and (valid_edge , v0_on_edge == vertices )
424- edge_sign_on_vertex [mask , iedge ] = - 1
425-
426- mask = np .logical_and (valid_edge , v1_on_edge == vertices )
427- edge_sign_on_vertex [mask , iedge ] = 1
428-
429- return edge_sign_on_vertex
430-
431- def _compute_vert_integ_vorticity (self , ds_mesh , vert_integ_velocity ,
432- edge_sign_on_vertex ):
433-
434- area_vertex = ds_mesh .areaTriangle
435- dc_edge = ds_mesh .dcEdge
436- edges_on_vertex = ds_mesh .edgesOnVertex - 1
437-
438- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
439-
440- vert_integ_vorticity = xr .zeros_like (ds_mesh .latVertex )
441- for iedge in range (vertex_degree ):
442- eov = edges_on_vertex .isel (vertexDegree = iedge )
443- edge_sign = edge_sign_on_vertex [:, iedge ]
444- dc = dc_edge .isel (nEdges = eov )
445- vert_integ_vel = vert_integ_velocity .isel (nEdges = eov )
446- vert_integ_vorticity += (
447- dc / area_vertex * edge_sign * vert_integ_vel )
448-
449- return vert_integ_vorticity
450-
451- def _compute_barotropic_streamfunction_vertex (self , ds_mesh , ds ):
452- edge_sign_on_vertex = self ._compute_edge_sign_on_vertex (ds_mesh )
453- vert_integ_velocity = self ._compute_vert_integ_velocity (ds_mesh , ds )
454- vert_integ_vorticity = self ._compute_vert_integ_vorticity (
455- ds_mesh , vert_integ_velocity , edge_sign_on_vertex )
456- self .logger .info ('vertically integrated vorticity computed.' )
457-
458- config = self .config
459- lat_range = config .getexpression (
460- 'climatologyMapBSF' , 'latitudeRangeForZeroBSF' )
461-
462- nvertices = ds_mesh .sizes ['nVertices' ]
463- vertex_degree = ds_mesh .sizes ['vertexDegree' ]
464-
465- cells_on_vertex = ds_mesh .cellsOnVertex - 1
466- edges_on_vertex = ds_mesh .edgesOnVertex - 1
467- vertices_on_edge = ds_mesh .verticesOnEdge - 1
468- area_vertex = ds_mesh .areaTriangle
469- dc_edge = ds_mesh .dcEdge
470- dv_edge = ds_mesh .dvEdge
471-
472- # one equation involving vertex degree + 1 vertices for each vertex
473- # plus 2 entries for the boundary condition and Lagrange multiplier
474- ndata = (vertex_degree + 1 ) * nvertices + 2
475- indices = np .zeros ((2 , ndata ), dtype = int )
476- data = np .zeros (ndata , dtype = float )
477-
478- # the laplacian on the dual mesh of the streamfunction is the
479- # vertically integrated vorticity
480- vertices = np .arange (nvertices , dtype = int )
481- idata = (vertex_degree + 1 ) * vertices + 1
482- indices [0 , idata ] = vertices
483- indices [1 , idata ] = vertices
484- for iedge in range (vertex_degree ):
485- eov = edges_on_vertex .isel (vertexDegree = iedge )
486- dc = dc_edge .isel (nEdges = eov )
487- dv = dv_edge .isel (nEdges = eov )
488-
489- v0 = vertices_on_edge .isel (nEdges = eov , TWO = 0 )
490- v1 = vertices_on_edge .isel (nEdges = eov , TWO = 1 )
491-
492- edge_sign = edge_sign_on_vertex [:, iedge ]
493-
494- mask = v0 == vertices
495- # the difference is v1 - v0, so we want to subtract this vertex
496- # when it is v0 and add it when it is v1
497- this_vert_sign = np .where (mask , - 1. , 1. )
498- # the other vertex is obviously whichever one this is not
499- other_vert_index = np .where (mask , v1 , v0 )
500- # if there are invalid vertices, we need to make sure we don't
501- # index out of bounds. The edge_sign will mask these out
502- other_vert_index = np .where (other_vert_index >= 0 ,
503- other_vert_index , 0 )
504-
505- idata_other = idata + iedge + 1
506-
507- indices [0 , idata ] = vertices
508- indices [1 , idata ] = vertices
509- indices [0 , idata_other ] = vertices
510- indices [1 , idata_other ] = other_vert_index
511-
512- this_data = this_vert_sign * edge_sign * dc / (dv * area_vertex )
513- data [idata ] += this_data
514- data [idata_other ] = - this_data
515-
516- # Now, the boundary condition: To begin with, we set the BSF at the
517- # frist vertext to zero
518- indices [0 , - 2 ] = nvertices
519- indices [1 , - 2 ] = 0
520- data [- 2 ] = 1.
521-
522- # The same in the final column
523- indices [0 , - 1 ] = 0
524- indices [1 , - 1 ] = nvertices
525- data [- 1 ] = 1.
526-
527- # one extra spot for the Lagrange multiplier
528- rhs = np .zeros (nvertices + 1 , dtype = float )
529-
530- rhs [0 :- 1 ] = vert_integ_vorticity .values
531-
532- matrix = scipy .sparse .csr_matrix (
533- (data , indices ),
534- shape = (nvertices + 1 , nvertices + 1 ))
535-
536- solution = scipy .sparse .linalg .spsolve (matrix , rhs )
537-
538- # drop the Lagrange multiplier and convert to Sv with the desired sign
539- # convention
540- bsf_vertex = xr .DataArray (- 1e-6 * solution [0 :- 1 ],
541- dims = ('nVertices' ,))
542-
543- bsf_vertex = _shift_bsf (bsf_vertex , lat_range , cells_on_vertex ,
544- ds_mesh .latVertex )
545-
546- return bsf_vertex
547-
548-
549- def _shift_bsf (bsf_vertex , lat_range , cells_on_vertex , lat_vertex ):
550- """
551- Shift the barotropic streamfunction to be zero at the boundary over
552- the given latitude range
553- """
554- is_boundary_cov = cells_on_vertex == - 1
555- boundary_vertices = is_boundary_cov .sum (dim = 'vertexDegree' ) > 0
556-
557- boundary_vertices = np .logical_and (
558- boundary_vertices ,
559- lat_vertex >= np .deg2rad (lat_range [0 ])
560- )
561- boundary_vertices = np .logical_and (
562- boundary_vertices ,
563- lat_vertex <= np .deg2rad (lat_range [1 ])
564- )
565-
566- # convert from boolean mask to indices
567- boundary_vertices = np .flatnonzero (boundary_vertices .values )
568-
569- mean_boundary_bsf = bsf_vertex .isel (nVertices = boundary_vertices ).mean ()
570-
571- bsf_shifted = bsf_vertex - mean_boundary_bsf
572-
573- return bsf_shifted
0 commit comments