@@ -1402,8 +1402,11 @@ def track_dist(p): return np.linalg.norm(p - line[0])
14021402 return track
14031403
14041404 def layer_plot (self , lay = - 1 , ** kwargs ):
1405- """Creates a 2-D Matplotlib plot of the mesh at a specified layer. The
1406- *lay* parameter can be either a layer object or a layer index.
1405+ """Creates a 2-D Matplotlib plot of the mesh at a specified layer.
1406+
1407+ The *lay* parameter can be a layer object, a layer index or a
1408+ string with value 'surface' or 'bottom', which gives a plot of
1409+ cells at the mesh surface or bottom.
14071410
14081411 Other optional parameters:
14091412
@@ -1436,6 +1439,8 @@ def layer_plot(self, lay = -1, **kwargs):
14361439 if 'axes' in kwargs : ax = kwargs ['axes' ]
14371440 else : fig , ax = plt .subplots ()
14381441
1442+ surface_type = None
1443+
14391444 if 'elevation' in kwargs :
14401445 z = kwargs ['elevation' ]
14411446 lay = self .find (z )
@@ -1450,12 +1455,26 @@ def layer_plot(self, lay = -1, **kwargs):
14501455 lay = self .layer [lay ]
14511456 except :
14521457 raise Exception ('Unknown layer in layer_plot()' )
1458+ elif isinstance (lay , str ):
1459+ if lay in ['surface' , 'bottom' ]:
1460+ surface_type = lay
1461+ lay = self ._complete_layer
1462+ else :
1463+ raise Exception ('Unknown layer in layer_plot()' )
14531464
14541465 labels = kwargs .get ('label' , None )
14551466 label_fmt = kwargs .get ('label_format' , '%g' )
14561467 label_colour = kwargs .get ('label_colour' , 'black' )
14571468 verts = []
1458- for c in lay .cell :
1469+
1470+ if surface_type is None :
1471+ cells = lay .cell
1472+ elif surface_type == 'surface' :
1473+ cells = self .surface_cells
1474+ elif surface_type == 'bottom' :
1475+ cells = self .bottom_cells
1476+
1477+ for c in cells :
14591478 col = c .column
14601479 poslist = [tuple ([p for p in n .pos ])
14611480 for n in col .node ]
@@ -1484,12 +1503,12 @@ def layer_plot(self, lay = -1, **kwargs):
14841503 vals = kwargs ['value' ]
14851504 if len (vals ) >= self .num_cells :
14861505 vals = np .array (kwargs ['value' ])
1487- indices = [c .index for c in lay . cell ]
1506+ indices = [c .index for c in cells ]
14881507 layer_vals = vals [indices ]
14891508 polys .set_array (layer_vals )
14901509 self ._plot_colourbar (ax , polys , kwargs )
14911510 if labels == 'value' :
1492- for c in lay . cell :
1511+ for c in cells :
14931512 col = c .column
14941513 col_label = label_fmt % vals [c .index ]
14951514 ax .text (col .centre [0 ], col .centre [1 ], col_label ,
0 commit comments