Skip to content

Commit 15d083f

Browse files
committed
Switch to using barotropic streamfunction from mpas_tools
1 parent 3c079f6 commit 15d083f

1 file changed

Lines changed: 23 additions & 231 deletions

File tree

mpas_analysis/ocean/climatology_map_bsf.py

Lines changed: 23 additions & 231 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
# distributed with this code, or at
1010
# https://raw.githubusercontent.com/MPAS-Dev/MPAS-Analysis/main/LICENSE
1111
import xarray as xr
12-
import numpy as np
13-
import scipy.sparse
14-
import scipy.sparse.linalg
12+
13+
from mpas_tools.ocean.barotropic_streamfunction import (
14+
compute_barotropic_streamfunction,
15+
shift_barotropic_streamfunction
16+
)
1517

1618
from mpas_analysis.shared import AnalysisTask
1719
from mpas_analysis.shared.climatology import RemapMpasClimatologySubtask
1820
from mpas_analysis.shared.plot import PlotClimatologyMapSubtask
19-
from mpas_analysis.ocean.utility import compute_zmid
2021
from mpas_analysis.shared.projection import comparison_grid_option_suffixes
2122

2223

@@ -280,8 +281,21 @@ def customize_masked_climatology(self, climatology, season):
280281
'edgesOnVertex', 'dcEdge', 'dvEdge', 'bottomDepth',
281282
'maxLevelCell', 'latVertex', 'areaTriangle',]]
282283
ds_mesh.load()
283-
bsf_vertex = self._compute_barotropic_streamfunction_vertex(
284-
ds_mesh, climatology)
284+
285+
cells_on_vertex = ds_mesh.cellsOnVertex - 1
286+
lat_vertex = ds_mesh.latVertex
287+
bsf_vertex = compute_barotropic_streamfunction(
288+
ds_mesh=ds_mesh, ds=climatology, min_depth=self.min_depth,
289+
max_depth=self.max_depth, include_bolus=self.include_bolus,
290+
include_submesoscale=self.include_submesoscale)
291+
292+
lat_range = config.getexpression(
293+
self.taskName, 'latitudeRangeForZeroBSF')
294+
295+
bsf_vertex = shift_barotropic_streamfunction(
296+
bsf_vertex=bsf_vertex, lat_range=lat_range,
297+
cells_on_vertex=cells_on_vertex, lat_vertex=lat_vertex)
298+
285299
logger.info('bsf on vertices computed.')
286300

287301
climatology['barotropicStreamfunction'] = bsf_vertex
@@ -305,234 +319,12 @@ def customize_masked_climatology(self, climatology, season):
305319

306320
lat_range = config.getexpression(
307321
config_section_name, 'latitudeRangeForZeroBSF')
308-
climatology[mpas_field_name] = _shift_bsf(
309-
bsf_vertex, lat_range, ds_mesh.cellsOnVertex - 1,
310-
ds_mesh.latVertex)
322+
climatology[mpas_field_name] = shift_barotropic_streamfunction(
323+
bsf_vertex=bsf_vertex, lat_range=lat_range,
324+
cells_on_vertex=cells_on_vertex, lat_vertex=lat_vertex)
311325
climatology[mpas_field_name].attrs['units'] = 'Sv'
312326
climatology[mpas_field_name].attrs['description'] = \
313327
f'barotropic streamfunction at vertices, offset for ' \
314328
f'{grid_suffix} plots'
315329

316330
return climatology
317-
318-
def _compute_vert_integ_velocity(self, ds_mesh, ds):
319-
320-
cells_on_edge = ds_mesh.cellsOnEdge - 1
321-
inner_edges = np.logical_and(cells_on_edge.isel(TWO=0) >= 0,
322-
cells_on_edge.isel(TWO=1) >= 0)
323-
324-
# convert from boolean mask to indices
325-
inner_edges = np.flatnonzero(inner_edges.values)
326-
327-
cell0 = cells_on_edge.isel(nEdges=inner_edges, TWO=0)
328-
cell1 = cells_on_edge.isel(nEdges=inner_edges, TWO=1)
329-
n_vert_levels = ds.sizes['nVertLevels']
330-
331-
layer_thickness = ds.timeMonthly_avg_layerThickness
332-
max_level_cell = ds_mesh.maxLevelCell - 1
333-
334-
vert_index = xr.DataArray.from_dict(
335-
{'dims': ('nVertLevels',), 'data': np.arange(n_vert_levels)})
336-
z_mid = compute_zmid(ds_mesh.bottomDepth, max_level_cell,
337-
layer_thickness)
338-
z_mid_edge = 0.5*(z_mid.isel(nCells=cell0) +
339-
z_mid.isel(nCells=cell1))
340-
341-
normal_velocity = ds.timeMonthly_avg_normalVelocity
342-
if self.include_bolus:
343-
normal_velocity += ds.timeMonthly_avg_normalGMBolusVelocity
344-
if self.include_submesoscale:
345-
normal_velocity += ds.timeMonthly_avg_normalMLEvelocity
346-
normal_velocity = normal_velocity.isel(nEdges=inner_edges)
347-
348-
layer_thickness_edge = 0.5*(layer_thickness.isel(nCells=cell0) +
349-
layer_thickness.isel(nCells=cell1))
350-
mask_bottom = (vert_index <= max_level_cell).T
351-
mask_bottom_edge = np.logical_and(mask_bottom.isel(nCells=cell0),
352-
mask_bottom.isel(nCells=cell1))
353-
masks = [mask_bottom_edge,
354-
z_mid_edge <= self.min_depth,
355-
z_mid_edge >= self.max_depth]
356-
for mask in masks:
357-
normal_velocity = normal_velocity.where(mask)
358-
layer_thickness_edge = layer_thickness_edge.where(mask)
359-
360-
vert_integ_velocity = np.zeros(ds_mesh.dims['nEdges'], dtype=float)
361-
inner_vert_integ_vel = (
362-
(layer_thickness_edge * normal_velocity).sum(dim='nVertLevels'))
363-
vert_integ_velocity[inner_edges] = inner_vert_integ_vel.values
364-
365-
vert_integ_velocity = xr.DataArray(vert_integ_velocity,
366-
dims=('nEdges',))
367-
368-
return vert_integ_velocity
369-
370-
def _compute_edge_sign_on_vertex(self, ds_mesh):
371-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
372-
vertices_on_edge = ds_mesh.verticesOnEdge - 1
373-
374-
nvertices = ds_mesh.sizes['nVertices']
375-
vertex_degree = ds_mesh.sizes['vertexDegree']
376-
377-
edge_sign_on_vertex = np.zeros((nvertices, vertex_degree), dtype=int)
378-
vertices = np.arange(nvertices)
379-
for iedge in range(vertex_degree):
380-
eov = edges_on_vertex.isel(vertexDegree=iedge)
381-
valid_edge = eov >= 0
382-
383-
v0_on_edge = vertices_on_edge.isel(nEdges=eov, TWO=0)
384-
v1_on_edge = vertices_on_edge.isel(nEdges=eov, TWO=1)
385-
valid_edge = np.logical_and(valid_edge, v0_on_edge >= 0)
386-
valid_edge = np.logical_and(valid_edge, v1_on_edge >= 0)
387-
388-
mask = np.logical_and(valid_edge, v0_on_edge == vertices)
389-
edge_sign_on_vertex[mask, iedge] = -1
390-
391-
mask = np.logical_and(valid_edge, v1_on_edge == vertices)
392-
edge_sign_on_vertex[mask, iedge] = 1
393-
394-
return edge_sign_on_vertex
395-
396-
def _compute_vert_integ_vorticity(self, ds_mesh, vert_integ_velocity,
397-
edge_sign_on_vertex):
398-
399-
area_vertex = ds_mesh.areaTriangle
400-
dc_edge = ds_mesh.dcEdge
401-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
402-
403-
vertex_degree = ds_mesh.sizes['vertexDegree']
404-
405-
vert_integ_vorticity = xr.zeros_like(ds_mesh.latVertex)
406-
for iedge in range(vertex_degree):
407-
eov = edges_on_vertex.isel(vertexDegree=iedge)
408-
edge_sign = edge_sign_on_vertex[:, iedge]
409-
dc = dc_edge.isel(nEdges=eov)
410-
vert_integ_vel = vert_integ_velocity.isel(nEdges=eov)
411-
vert_integ_vorticity += (
412-
dc / area_vertex * edge_sign * vert_integ_vel)
413-
414-
return vert_integ_vorticity
415-
416-
def _compute_barotropic_streamfunction_vertex(self, ds_mesh, ds):
417-
edge_sign_on_vertex = self._compute_edge_sign_on_vertex(ds_mesh)
418-
vert_integ_velocity = self._compute_vert_integ_velocity(ds_mesh, ds)
419-
vert_integ_vorticity = self._compute_vert_integ_vorticity(
420-
ds_mesh, vert_integ_velocity, edge_sign_on_vertex)
421-
self.logger.info('vertically integrated vorticity computed.')
422-
423-
config = self.config
424-
lat_range = config.getexpression(
425-
'climatologyMapBSF', 'latitudeRangeForZeroBSF')
426-
427-
nvertices = ds_mesh.sizes['nVertices']
428-
vertex_degree = ds_mesh.sizes['vertexDegree']
429-
430-
cells_on_vertex = ds_mesh.cellsOnVertex - 1
431-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
432-
vertices_on_edge = ds_mesh.verticesOnEdge - 1
433-
area_vertex = ds_mesh.areaTriangle
434-
dc_edge = ds_mesh.dcEdge
435-
dv_edge = ds_mesh.dvEdge
436-
437-
# one equation involving vertex degree + 1 vertices for each vertex
438-
# plus 2 entries for the boundary condition and Lagrange multiplier
439-
ndata = (vertex_degree + 1) * nvertices + 2
440-
indices = np.zeros((2, ndata), dtype=int)
441-
data = np.zeros(ndata, dtype=float)
442-
443-
# the laplacian on the dual mesh of the streamfunction is the
444-
# vertically integrated vorticity
445-
vertices = np.arange(nvertices, dtype=int)
446-
idata = (vertex_degree + 1) * vertices + 1
447-
indices[0, idata] = vertices
448-
indices[1, idata] = vertices
449-
for iedge in range(vertex_degree):
450-
eov = edges_on_vertex.isel(vertexDegree=iedge)
451-
dc = dc_edge.isel(nEdges=eov)
452-
dv = dv_edge.isel(nEdges=eov)
453-
454-
v0 = vertices_on_edge.isel(nEdges=eov, TWO=0)
455-
v1 = vertices_on_edge.isel(nEdges=eov, TWO=1)
456-
457-
edge_sign = edge_sign_on_vertex[:, iedge]
458-
459-
mask = v0 == vertices
460-
# the difference is v1 - v0, so we want to subtract this vertex
461-
# when it is v0 and add it when it is v1
462-
this_vert_sign = np.where(mask, -1., 1.)
463-
# the other vertex is obviously whichever one this is not
464-
other_vert_index = np.where(mask, v1, v0)
465-
# if there are invalid vertices, we need to make sure we don't
466-
# index out of bounds. The edge_sign will mask these out
467-
other_vert_index = np.where(other_vert_index >= 0,
468-
other_vert_index, 0)
469-
470-
idata_other = idata + iedge + 1
471-
472-
indices[0, idata] = vertices
473-
indices[1, idata] = vertices
474-
indices[0, idata_other] = vertices
475-
indices[1, idata_other] = other_vert_index
476-
477-
this_data = this_vert_sign * edge_sign * dc / (dv * area_vertex)
478-
data[idata] += this_data
479-
data[idata_other] = -this_data
480-
481-
# Now, the boundary condition: To begin with, we set the BSF at the
482-
# frist vertext to zero
483-
indices[0, -2] = nvertices
484-
indices[1, -2] = 0
485-
data[-2] = 1.
486-
487-
# The same in the final column
488-
indices[0, -1] = 0
489-
indices[1, -1] = nvertices
490-
data[-1] = 1.
491-
492-
# one extra spot for the Lagrange multiplier
493-
rhs = np.zeros(nvertices + 1, dtype=float)
494-
495-
rhs[0:-1] = vert_integ_vorticity.values
496-
497-
matrix = scipy.sparse.csr_matrix(
498-
(data, indices),
499-
shape=(nvertices + 1, nvertices + 1))
500-
501-
solution = scipy.sparse.linalg.spsolve(matrix, rhs)
502-
503-
# drop the Lagrange multiplier and convert to Sv with the desired sign
504-
# convention
505-
bsf_vertex = xr.DataArray(-1e-6 * solution[0:-1],
506-
dims=('nVertices',))
507-
508-
bsf_vertex = _shift_bsf(bsf_vertex, lat_range, cells_on_vertex,
509-
ds_mesh.latVertex)
510-
511-
return bsf_vertex
512-
513-
514-
def _shift_bsf(bsf_vertex, lat_range, cells_on_vertex, lat_vertex):
515-
"""
516-
Shift the barotropic streamfunction to be zero at the boundary over
517-
the given latitude range
518-
"""
519-
is_boundary_cov = cells_on_vertex == -1
520-
boundary_vertices = is_boundary_cov.sum(dim='vertexDegree') > 0
521-
522-
boundary_vertices = np.logical_and(
523-
boundary_vertices,
524-
lat_vertex >= np.deg2rad(lat_range[0])
525-
)
526-
boundary_vertices = np.logical_and(
527-
boundary_vertices,
528-
lat_vertex <= np.deg2rad(lat_range[1])
529-
)
530-
531-
# convert from boolean mask to indices
532-
boundary_vertices = np.flatnonzero(boundary_vertices.values)
533-
534-
mean_boundary_bsf = bsf_vertex.isel(nVertices=boundary_vertices).mean()
535-
536-
bsf_shifted = bsf_vertex - mean_boundary_bsf
537-
538-
return bsf_shifted

0 commit comments

Comments
 (0)