Skip to content

Commit 0165e0b

Browse files
committed
Allow layer_plot() to plot on surface or bottom
1 parent 3e49f94 commit 0165e0b

1 file changed

Lines changed: 24 additions & 5 deletions

File tree

layermesh/mesh.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,8 +1402,11 @@ def track_dist(p): return np.linalg.norm(p - line[0])
14021402
return track
14031403

14041404
def layer_plot(self, lay = -1, **kwargs):
1405-
"""Creates a 2-D Matplotlib plot of the mesh at a specified layer. The
1406-
*lay* parameter can be either a layer object or a layer index.
1405+
"""Creates a 2-D Matplotlib plot of the mesh at a specified layer.
1406+
1407+
The *lay* parameter can be a layer object, a layer index or a
1408+
string with value 'surface' or 'bottom', which gives a plot of
1409+
cells at the mesh surface or bottom.
14071410
14081411
Other optional parameters:
14091412
@@ -1436,6 +1439,8 @@ def layer_plot(self, lay = -1, **kwargs):
14361439
if 'axes' in kwargs: ax = kwargs['axes']
14371440
else: fig, ax = plt.subplots()
14381441

1442+
surface_type = None
1443+
14391444
if 'elevation' in kwargs:
14401445
z = kwargs['elevation']
14411446
lay = self.find(z)
@@ -1450,12 +1455,26 @@ def layer_plot(self, lay = -1, **kwargs):
14501455
lay = self.layer[lay]
14511456
except:
14521457
raise Exception('Unknown layer in layer_plot()')
1458+
elif isinstance(lay, str):
1459+
if lay in ['surface', 'bottom']:
1460+
surface_type = lay
1461+
lay = self._complete_layer
1462+
else:
1463+
raise Exception('Unknown layer in layer_plot()')
14531464

14541465
labels = kwargs.get('label', None)
14551466
label_fmt = kwargs.get('label_format', '%g')
14561467
label_colour = kwargs.get('label_colour', 'black')
14571468
verts = []
1458-
for c in lay.cell:
1469+
1470+
if surface_type is None:
1471+
cells = lay.cell
1472+
elif surface_type == 'surface':
1473+
cells = self.surface_cells
1474+
elif surface_type == 'bottom':
1475+
cells = self.bottom_cells
1476+
1477+
for c in cells:
14591478
col = c.column
14601479
poslist = [tuple([p for p in n.pos])
14611480
for n in col.node]
@@ -1484,12 +1503,12 @@ def layer_plot(self, lay = -1, **kwargs):
14841503
vals = kwargs['value']
14851504
if len(vals) >= self.num_cells:
14861505
vals = np.array(kwargs['value'])
1487-
indices = [c.index for c in lay.cell]
1506+
indices = [c.index for c in cells]
14881507
layer_vals = vals[indices]
14891508
polys.set_array(layer_vals)
14901509
self._plot_colourbar(ax, polys, kwargs)
14911510
if labels == 'value':
1492-
for c in lay.cell:
1511+
for c in cells:
14931512
col = c.column
14941513
col_label = label_fmt % vals[c.index]
14951514
ax.text(col.centre[0], col.centre[1], col_label,

0 commit comments

Comments
 (0)