Skip to content

Commit e9d0bc6

Browse files
committed
Switch to using barotropic streamfunction from mpas_tools
1 parent fbddf79 commit e9d0bc6

1 file changed

Lines changed: 36 additions & 231 deletions

File tree

mpas_analysis/ocean/climatology_map_bsf.py

Lines changed: 36 additions & 231 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
import os
1212

1313
import xarray as xr
14-
import numpy as np
15-
import scipy.sparse
16-
import scipy.sparse.linalg
14+
15+
from mpas_tools.ocean.barotropic_streamfunction import (
16+
compute_barotropic_streamfunction,
17+
shift_barotropic_streamfunction
18+
)
1719

1820
from mpas_analysis.shared import AnalysisTask
1921
from mpas_analysis.shared.climatology import RemapMpasClimatologySubtask
2022
from mpas_analysis.shared.plot import PlotClimatologyMapSubtask
21-
from mpas_analysis.ocean.utility import compute_zmid
2223
from mpas_analysis.shared.projection import comparison_grid_option_suffixes
2324

2425

@@ -315,8 +316,30 @@ def customize_masked_climatology(self, climatology, season):
315316
'edgesOnVertex', 'dcEdge', 'dvEdge', 'bottomDepth',
316317
'maxLevelCell', 'latVertex', 'areaTriangle',]]
317318
ds_mesh.load()
318-
bsf_vertex = self._compute_barotropic_streamfunction_vertex(
319-
ds_mesh, climatology)
319+
320+
cells_on_vertex = ds_mesh.cellsOnVertex - 1
321+
lat_vertex = ds_mesh.latVertex
322+
bsf_vertex = compute_barotropic_streamfunction(
323+
ds_mesh=ds_mesh,
324+
ds=climatology,
325+
min_depth=self.min_depth,
326+
max_depth=self.max_depth,
327+
include_bolus=self.include_bolus,
328+
include_submesoscale=self.include_submesoscale,
329+
logger=logger,
330+
)
331+
332+
lat_range = config.getexpression(
333+
self.taskName, 'latitudeRangeForZeroBSF')
334+
335+
bsf_vertex = shift_barotropic_streamfunction(
336+
bsf_vertex=bsf_vertex,
337+
lat_range=lat_range,
338+
cells_on_vertex=cells_on_vertex,
339+
lat_vertex=lat_vertex,
340+
logger=logger,
341+
)
342+
320343
logger.info('bsf on vertices computed.')
321344

322345
climatology['barotropicStreamfunction'] = bsf_vertex
@@ -340,234 +363,16 @@ def customize_masked_climatology(self, climatology, season):
340363

341364
lat_range = config.getexpression(
342365
config_section_name, 'latitudeRangeForZeroBSF')
343-
climatology[mpas_field_name] = _shift_bsf(
344-
bsf_vertex, lat_range, ds_mesh.cellsOnVertex - 1,
345-
ds_mesh.latVertex)
366+
climatology[mpas_field_name] = shift_barotropic_streamfunction(
367+
bsf_vertex=bsf_vertex,
368+
lat_range=lat_range,
369+
cells_on_vertex=cells_on_vertex,
370+
lat_vertex=lat_vertex,
371+
logger=logger,
372+
)
346373
climatology[mpas_field_name].attrs['units'] = 'Sv'
347374
climatology[mpas_field_name].attrs['description'] = \
348375
f'barotropic streamfunction at vertices, offset for ' \
349376
f'{grid_suffix} plots'
350377

351378
return climatology
352-
353-
def _compute_vert_integ_velocity(self, ds_mesh, ds):
354-
355-
cells_on_edge = ds_mesh.cellsOnEdge - 1
356-
inner_edges = np.logical_and(cells_on_edge.isel(TWO=0) >= 0,
357-
cells_on_edge.isel(TWO=1) >= 0)
358-
359-
# convert from boolean mask to indices
360-
inner_edges = np.flatnonzero(inner_edges.values)
361-
362-
cell0 = cells_on_edge.isel(nEdges=inner_edges, TWO=0)
363-
cell1 = cells_on_edge.isel(nEdges=inner_edges, TWO=1)
364-
n_vert_levels = ds.sizes['nVertLevels']
365-
366-
layer_thickness = ds.timeMonthly_avg_layerThickness
367-
max_level_cell = ds_mesh.maxLevelCell - 1
368-
369-
vert_index = xr.DataArray.from_dict(
370-
{'dims': ('nVertLevels',), 'data': np.arange(n_vert_levels)})
371-
z_mid = compute_zmid(ds_mesh.bottomDepth, max_level_cell,
372-
layer_thickness)
373-
z_mid_edge = 0.5*(z_mid.isel(nCells=cell0) +
374-
z_mid.isel(nCells=cell1))
375-
376-
normal_velocity = ds.timeMonthly_avg_normalVelocity
377-
if self.include_bolus:
378-
normal_velocity += ds.timeMonthly_avg_normalGMBolusVelocity
379-
if self.include_submesoscale:
380-
normal_velocity += ds.timeMonthly_avg_normalMLEvelocity
381-
normal_velocity = normal_velocity.isel(nEdges=inner_edges)
382-
383-
layer_thickness_edge = 0.5*(layer_thickness.isel(nCells=cell0) +
384-
layer_thickness.isel(nCells=cell1))
385-
mask_bottom = (vert_index <= max_level_cell).T
386-
mask_bottom_edge = np.logical_and(mask_bottom.isel(nCells=cell0),
387-
mask_bottom.isel(nCells=cell1))
388-
masks = [mask_bottom_edge,
389-
z_mid_edge <= self.min_depth,
390-
z_mid_edge >= self.max_depth]
391-
for mask in masks:
392-
normal_velocity = normal_velocity.where(mask)
393-
layer_thickness_edge = layer_thickness_edge.where(mask)
394-
395-
vert_integ_velocity = np.zeros(ds_mesh.dims['nEdges'], dtype=float)
396-
inner_vert_integ_vel = (
397-
(layer_thickness_edge * normal_velocity).sum(dim='nVertLevels'))
398-
vert_integ_velocity[inner_edges] = inner_vert_integ_vel.values
399-
400-
vert_integ_velocity = xr.DataArray(vert_integ_velocity,
401-
dims=('nEdges',))
402-
403-
return vert_integ_velocity
404-
405-
def _compute_edge_sign_on_vertex(self, ds_mesh):
406-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
407-
vertices_on_edge = ds_mesh.verticesOnEdge - 1
408-
409-
nvertices = ds_mesh.sizes['nVertices']
410-
vertex_degree = ds_mesh.sizes['vertexDegree']
411-
412-
edge_sign_on_vertex = np.zeros((nvertices, vertex_degree), dtype=int)
413-
vertices = np.arange(nvertices)
414-
for iedge in range(vertex_degree):
415-
eov = edges_on_vertex.isel(vertexDegree=iedge)
416-
valid_edge = eov >= 0
417-
418-
v0_on_edge = vertices_on_edge.isel(nEdges=eov, TWO=0)
419-
v1_on_edge = vertices_on_edge.isel(nEdges=eov, TWO=1)
420-
valid_edge = np.logical_and(valid_edge, v0_on_edge >= 0)
421-
valid_edge = np.logical_and(valid_edge, v1_on_edge >= 0)
422-
423-
mask = np.logical_and(valid_edge, v0_on_edge == vertices)
424-
edge_sign_on_vertex[mask, iedge] = -1
425-
426-
mask = np.logical_and(valid_edge, v1_on_edge == vertices)
427-
edge_sign_on_vertex[mask, iedge] = 1
428-
429-
return edge_sign_on_vertex
430-
431-
def _compute_vert_integ_vorticity(self, ds_mesh, vert_integ_velocity,
432-
edge_sign_on_vertex):
433-
434-
area_vertex = ds_mesh.areaTriangle
435-
dc_edge = ds_mesh.dcEdge
436-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
437-
438-
vertex_degree = ds_mesh.sizes['vertexDegree']
439-
440-
vert_integ_vorticity = xr.zeros_like(ds_mesh.latVertex)
441-
for iedge in range(vertex_degree):
442-
eov = edges_on_vertex.isel(vertexDegree=iedge)
443-
edge_sign = edge_sign_on_vertex[:, iedge]
444-
dc = dc_edge.isel(nEdges=eov)
445-
vert_integ_vel = vert_integ_velocity.isel(nEdges=eov)
446-
vert_integ_vorticity += (
447-
dc / area_vertex * edge_sign * vert_integ_vel)
448-
449-
return vert_integ_vorticity
450-
451-
def _compute_barotropic_streamfunction_vertex(self, ds_mesh, ds):
452-
edge_sign_on_vertex = self._compute_edge_sign_on_vertex(ds_mesh)
453-
vert_integ_velocity = self._compute_vert_integ_velocity(ds_mesh, ds)
454-
vert_integ_vorticity = self._compute_vert_integ_vorticity(
455-
ds_mesh, vert_integ_velocity, edge_sign_on_vertex)
456-
self.logger.info('vertically integrated vorticity computed.')
457-
458-
config = self.config
459-
lat_range = config.getexpression(
460-
'climatologyMapBSF', 'latitudeRangeForZeroBSF')
461-
462-
nvertices = ds_mesh.sizes['nVertices']
463-
vertex_degree = ds_mesh.sizes['vertexDegree']
464-
465-
cells_on_vertex = ds_mesh.cellsOnVertex - 1
466-
edges_on_vertex = ds_mesh.edgesOnVertex - 1
467-
vertices_on_edge = ds_mesh.verticesOnEdge - 1
468-
area_vertex = ds_mesh.areaTriangle
469-
dc_edge = ds_mesh.dcEdge
470-
dv_edge = ds_mesh.dvEdge
471-
472-
# one equation involving vertex degree + 1 vertices for each vertex
473-
# plus 2 entries for the boundary condition and Lagrange multiplier
474-
ndata = (vertex_degree + 1) * nvertices + 2
475-
indices = np.zeros((2, ndata), dtype=int)
476-
data = np.zeros(ndata, dtype=float)
477-
478-
# the laplacian on the dual mesh of the streamfunction is the
479-
# vertically integrated vorticity
480-
vertices = np.arange(nvertices, dtype=int)
481-
idata = (vertex_degree + 1) * vertices + 1
482-
indices[0, idata] = vertices
483-
indices[1, idata] = vertices
484-
for iedge in range(vertex_degree):
485-
eov = edges_on_vertex.isel(vertexDegree=iedge)
486-
dc = dc_edge.isel(nEdges=eov)
487-
dv = dv_edge.isel(nEdges=eov)
488-
489-
v0 = vertices_on_edge.isel(nEdges=eov, TWO=0)
490-
v1 = vertices_on_edge.isel(nEdges=eov, TWO=1)
491-
492-
edge_sign = edge_sign_on_vertex[:, iedge]
493-
494-
mask = v0 == vertices
495-
# the difference is v1 - v0, so we want to subtract this vertex
496-
# when it is v0 and add it when it is v1
497-
this_vert_sign = np.where(mask, -1., 1.)
498-
# the other vertex is obviously whichever one this is not
499-
other_vert_index = np.where(mask, v1, v0)
500-
# if there are invalid vertices, we need to make sure we don't
501-
# index out of bounds. The edge_sign will mask these out
502-
other_vert_index = np.where(other_vert_index >= 0,
503-
other_vert_index, 0)
504-
505-
idata_other = idata + iedge + 1
506-
507-
indices[0, idata] = vertices
508-
indices[1, idata] = vertices
509-
indices[0, idata_other] = vertices
510-
indices[1, idata_other] = other_vert_index
511-
512-
this_data = this_vert_sign * edge_sign * dc / (dv * area_vertex)
513-
data[idata] += this_data
514-
data[idata_other] = -this_data
515-
516-
# Now, the boundary condition: To begin with, we set the BSF at the
517-
# frist vertext to zero
518-
indices[0, -2] = nvertices
519-
indices[1, -2] = 0
520-
data[-2] = 1.
521-
522-
# The same in the final column
523-
indices[0, -1] = 0
524-
indices[1, -1] = nvertices
525-
data[-1] = 1.
526-
527-
# one extra spot for the Lagrange multiplier
528-
rhs = np.zeros(nvertices + 1, dtype=float)
529-
530-
rhs[0:-1] = vert_integ_vorticity.values
531-
532-
matrix = scipy.sparse.csr_matrix(
533-
(data, indices),
534-
shape=(nvertices + 1, nvertices + 1))
535-
536-
solution = scipy.sparse.linalg.spsolve(matrix, rhs)
537-
538-
# drop the Lagrange multiplier and convert to Sv with the desired sign
539-
# convention
540-
bsf_vertex = xr.DataArray(-1e-6 * solution[0:-1],
541-
dims=('nVertices',))
542-
543-
bsf_vertex = _shift_bsf(bsf_vertex, lat_range, cells_on_vertex,
544-
ds_mesh.latVertex)
545-
546-
return bsf_vertex
547-
548-
549-
def _shift_bsf(bsf_vertex, lat_range, cells_on_vertex, lat_vertex):
550-
"""
551-
Shift the barotropic streamfunction to be zero at the boundary over
552-
the given latitude range
553-
"""
554-
is_boundary_cov = cells_on_vertex == -1
555-
boundary_vertices = is_boundary_cov.sum(dim='vertexDegree') > 0
556-
557-
boundary_vertices = np.logical_and(
558-
boundary_vertices,
559-
lat_vertex >= np.deg2rad(lat_range[0])
560-
)
561-
boundary_vertices = np.logical_and(
562-
boundary_vertices,
563-
lat_vertex <= np.deg2rad(lat_range[1])
564-
)
565-
566-
# convert from boolean mask to indices
567-
boundary_vertices = np.flatnonzero(boundary_vertices.values)
568-
569-
mean_boundary_bsf = bsf_vertex.isel(nVertices=boundary_vertices).mean()
570-
571-
bsf_shifted = bsf_vertex - mean_boundary_bsf
572-
573-
return bsf_shifted

0 commit comments

Comments
 (0)