Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1a2b249
added new plotting functions for VOM + unit tests
achanbour Apr 14, 2026
8a9a99f
refactored vom plotting functions
achanbour Apr 17, 2026
c4b8892
fixed flake8 errors
achanbour Apr 17, 2026
e9510ec
fixed flake8 errors in test file
achanbour Apr 17, 2026
4c8a72a
updated docstrings to match matlotlib obj refs requirements
achanbour Apr 17, 2026
fa3e1da
first set of fixes to vom plotting methods
achanbour Apr 17, 2026
7f62e01
fixed docstrings and added plotting doc to manual source along with c…
achanbour Apr 17, 2026
7014ffe
renamed pointplot to scatter
achanbour Apr 17, 2026
a7159dc
renamed plotting method in all scripts, updated docs and example script
achanbour Apr 17, 2026
5684217
fixed flake8 errors
achanbour Apr 17, 2026
a52757d
fixed flake8 errors in test file
achanbour Apr 17, 2026
fe861a4
Update firedrake/pyplot/mpl.py
achanbour Apr 17, 2026
d30fac7
added serial run checks and exceptions to plotting methods
achanbour Apr 20, 2026
14bb72d
fixed flake8 errors in mpl.py
achanbour Apr 20, 2026
134f9e1
fixed flake8 errors in exceptions.py
achanbour Apr 20, 2026
5eaf524
Update firedrake/pyplot/mpl.py scatter exception 1
achanbour Apr 20, 2026
a3ed223
Update firedrake/pyplot/mpl.py scatter exception
achanbour Apr 20, 2026
1977670
Update firedrake/pyplot/mpl.py
achanbour Apr 20, 2026
00b7a42
Update firedrake/pyplot/mpl.py scatter exception
achanbour Apr 20, 2026
d59eed8
Update firedrake/pyplot/mpl.py quiver docstring
achanbour Apr 20, 2026
3d41501
updated scatter docstring to link to VertexOnlyMesh
achanbour Apr 20, 2026
db54e99
modified vom argument name in scatter
achanbour Apr 22, 2026
93be5c2
changed scatter arg name to vom_or_function
achanbour Apr 22, 2026
132ca8e
merged with main and fixed conflicts
achanbour Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/source/visualisation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,49 @@ points per element could be specfied to when calling :func:`plot
To install matplotlib_, please look at the installation instructions of
matplotlib.

Visualising a vertex-only mesh
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Firedrake supports the visualisation of point data, represented as a :func:`~.VertexOnlyMesh`, in much the same way as its other plotting routines.
In particular, :func:`scatter <firedrake.pyplot.scatter>` wraps matplotlib's ``scatter`` method and uses the coordinates of the mesh's constituent points to produce a scatter plot.
Although separate from :func:`triplot <firedrake.pyplot.triplot>`, it makes most sense to use it in conjuction with :func:`triplot <firedrake.pyplot.triplot>`
which makes apparent the embedding of the vertex-only mesh inside its parent mesh. As the below code demonstrates, :func:`scatter <firedrake.pyplot.scatter>` gives the user the freedom to pass
Comment thread
achanbour marked this conversation as resolved.
either a :func:`~.VertexOnlyMesh` object or a scalar :class:`~.Function` defined on it, in which case, the values of the function will be used to colour the points.

.. literalinclude:: ../../tests/firedrake/output/test_vom_plotting_manual.py
:language: python3
:dedent:
:start-after: [test_vom_plotting_2d_manual_examples 1]
:end-before: [test_vom_plotting_2d_manual_examples 2]


The visualisation works equally well in 3D. It is advisable, however, to reduce the opacity of the mesh's interior facets
to ensure the points remain visible.

.. literalinclude:: ../../tests/firedrake/output/test_vom_plotting_manual.py
:language: python3
:dedent:
:start-after: [test_vom_plotting_3d_manual_examples 1]
:end-before: [test_vom_plotting_3d_manual_examples 2]

As :func:`scatter <firedrake.pyplot.scatter>` is exposed as a standalone plotting method, it is possible to combine it with any other plots in one single figure.
The example below demonstrates this by superimposing the scatter plot of point data onto a :func:`tripcolor <firedrake.pyplot.tripcolor>` plot of a field
defined on the parent mesh.

.. literalinclude:: ../../tests/firedrake/output/test_vom_plotting_manual.py
:language: python3
:dedent:
:start-after: [test_vom_plotting_2d_manual_examples 3]
:end-before: [test_vom_plotting_2d_manual_examples 4]

Last but not least, vector fields defined on a vertex-only mesh can be visualised using
:func:`quiver <firedrake.pyplot.quiver>` which will show both the magnitude and direction of the vectors at each point.

.. literalinclude:: ../../tests/firedrake/output/test_vom_plotting_manual.py
:language: python3
:dedent:
:start-after: [test_vom_plotting_2d_manual_examples 5]
:end-before: [test_vom_plotting_2d_manual_examples 6]

.. _Paraview: http://www.paraview.org
.. _VTK: http://www.vtk.org
Expand Down
2 changes: 1 addition & 1 deletion firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def init_petsc():
from firedrake.exceptions import ( # noqa: F401
FiredrakeException, ConvergenceError, MismatchingDomainError,
VertexOnlyMeshMissingPointsError, DofNotDefinedError, DofTypeError,
PointNotInDomainError,
SerialExecutionOnlyError, PointNotInDomainError,
)
from firedrake.function import ( # noqa: F401
Function, CoordinatelessFunction, PointEvaluator
Expand Down
5 changes: 5 additions & 0 deletions firedrake/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class UnrecognisedDeviceError(FiredrakeException):
"""


class SerialExecutionOnlyError(FiredrakeException):
"""Raised when calling any Firedrake method that only runs in serial.
"""


class PointNotInDomainError(FiredrakeException):
r"""Raised when attempting to evaluate a function outside its domain,
and no fill value was given.
Expand Down
4 changes: 2 additions & 2 deletions firedrake/pyplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from firedrake.pyplot.mpl import (
plot, triplot, tricontourf, tricontour, trisurf, tripcolor, quiver,
plot, triplot, scatter, tricontourf, tricontour, trisurf, tripcolor, quiver,
streamplot, FunctionPlotter
)
from firedrake.pyplot.pgf import pgfplot

__all__ = [
"plot", "triplot", "tricontourf", "tricontour", "trisurf", "tripcolor",
"plot", "triplot", "scatter", "tricontourf", "tricontour", "trisurf", "tripcolor",
"quiver", "streamplot", "FunctionPlotter", "pgfplot"
]
137 changes: 121 additions & 16 deletions firedrake/pyplot/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
from math import factorial
from firedrake import (interpolate, sqrt, inner, Function, SpatialCoordinate,
FunctionSpace, VectorFunctionSpace, PointNotInDomainError,
Constant, assemble, dx)
from firedrake.mesh import MeshGeometry
SerialExecutionOnlyError, Constant, assemble, dx)
from firedrake.mesh import MeshGeometry, VertexOnlyMeshTopology
from firedrake.petsc import PETSc
from ufl.domain import extract_unique_domain


__all__ = [
"plot", "triplot", "tricontourf", "tricontour", "trisurf", "tripcolor",
"plot", "triplot", "scatter", "tricontourf", "tricontour", "trisurf", "tripcolor",
"quiver", "streamplot", "FunctionPlotter"
]

Expand Down Expand Up @@ -83,6 +83,72 @@ def _get_collection_types(gdim, tdim):
raise ValueError("Geometric dimension must be either 2 or 3!")


@PETSc.Log.EventDecorator()
def scatter(vom_or_function: MeshGeometry | Function, axes: matplotlib.axes.Axes | None = None, **kwargs) -> matplotlib.collections.PathCollection:
r"""Plot a 2D or 3D :func:`.VertexOnlyMesh` as a scatter plot.

Parameters
----------
vom_or_function
A :func:`.VertexOnlyMesh` or a scalar-valued :class:`~.Function` defined on one.
If a :class:`~.Function` is provided, its values are used to colour the points.
axes
The axes on which to plot. If not provided, the current active axes are used.
**kwargs
Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.scatter`.

Returns
-------
matplotlib.collections.PathCollection
The scatter plot artist.
"""
Comment thread
achanbour marked this conversation as resolved.
is_vom = isinstance(vom_or_function, MeshGeometry) and isinstance(vom_or_function.topology, VertexOnlyMeshTopology)
is_function_on_vom = isinstance(vom_or_function, Function) and isinstance(vom_or_function.function_space().mesh().topology, VertexOnlyMeshTopology)

if not (is_vom or is_function_on_vom):
raise TypeError("Expected a VertexOnlyMesh or a Function defined on one.")

if isinstance(vom_or_function, Function):
if len(vom_or_function.ufl_shape) == 0:
# scalar field: colour points by value
kwargs["c"] = vom_or_function.dat.data_ro
elif len(vom_or_function.ufl_shape) == 1:
# vector field: use quiver instead
raise ValueError("Expected a scalar-valued Function. Use quiver to plot vector-valued Functions.")
else:
raise ValueError(
f"Cannot plot a rank-{len(vom_or_function.ufl_shape)} tensor field; "
"only scalar-valued Functions are supported by this method. "
"For vector-valued Functions, use quiver.")
vom = vom_or_function.function_space().mesh()
else:
vom = vom_or_function

if vom.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

gdim = vom.geometric_dimension
coords = toreal(vom.coordinates.dat.data_ro_with_halos, "real")

if axes is None:
fig = plt.figure()
if gdim == 3:
axes = fig.add_subplot(111, projection="3d")
elif gdim == 2:
axes = fig.add_subplot(111)
Comment thread
achanbour marked this conversation as resolved.
else:
raise ValueError("Scatter is only supported for 2D and 3D meshes.")

kwargs.setdefault("zorder", 5) # this makes sure that points are drawn on top of the parent mesh lines
kwargs.setdefault("s", 10) # controls scatter dot size
kwargs.setdefault("c", "red") # default colour if no function provided

collection = axes.scatter(*(coords.T), **kwargs)

_autoscale_view(axes, coords)
return collection


@PETSc.Log.EventDecorator()
def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}):
r"""Plot a mesh colouring marked facet segments
Expand All @@ -100,6 +166,9 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}):
:arg boundary_kw: keyword arguments to apply when plotting the mesh boundary
:return: list of matplotlib :class:`Collection <matplotlib.collections.Collection>` objects
"""
if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

gdim = mesh.geometric_dimension
tdim = mesh.topological_dimension
BoundaryCollection, InteriorCollection = _get_collection_types(gdim, tdim)
Expand Down Expand Up @@ -180,6 +249,7 @@ def facet_data(typ):
if tdim == 3:
boundary_kw["edgecolors"] = boundary_kw.get("edgecolors", "k")
boundary_kw["linewidths"] = boundary_kw.get("linewidths", 1.0)

for marker, color in zip(markers, colors):
vertices = []
for typ in ["interior", "exterior"]:
Expand Down Expand Up @@ -212,6 +282,10 @@ def _plot_2d_field(method_name, function, *args, complex_component="real", **kwa

Q = function.function_space()
mesh = Q.mesh()

if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

if len(function.ufl_shape) == 1:
element = function.ufl_element().sub_elements[0]
Q = FunctionSpace(mesh, element)
Expand Down Expand Up @@ -319,6 +393,10 @@ def trisurf(function, *args, complex_component="real", **kwargs):

Q = function.function_space()
mesh = Q.mesh()

if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

if mesh.geometric_dimension == 3:
return _trisurf_3d(axes, function, *args, complex_component=complex_component, **_kwargs)
_kwargs.update({"shade": False})
Expand All @@ -336,14 +414,23 @@ def trisurf(function, *args, complex_component="real", **kwargs):


@PETSc.Log.EventDecorator()
def quiver(function, *, complex_component="real", **kwargs):
r"""Make a quiver plot of a 2D vector Firedrake :class:`~.Function`

:arg function: the vector field to plot
:kwarg complex_component: If plotting complex data, which
component? (``'real'`` or ``'imag'``). Default is ``'real'``.
:arg kwargs: same as for matplotlib :func:`quiver <matplotlib.pyplot.quiver>`
:return: matplotlib :class:`Quiver <matplotlib.quiver.Quiver>` object
def quiver(function: Function, *, complex_component: str = "real", **kwargs) -> matplotlib.quiver.Quiver:
r"""Make a quiver plot of a 2D vector Firedrake :class:`~.Function`.

Parameters
----------
function
The 2D vector field to plot.
complex_component
Which component to plot if the data is complex. Either ``'real'``
or ``'imag'``. Defaults to ``'real'``.
**kwargs
Additional keyword arguments passed to :func:`matplotlib.pyplot.quiver`.

Returns
-------
matplotlib.quiver.Quiver
The quiver plot artist.
"""
if function.ufl_shape != (2,):
raise ValueError("Quiver plots only defined for 2D vector fields!")
Expand All @@ -353,10 +440,18 @@ def quiver(function, *, complex_component="real", **kwargs):
figure = plt.figure()
axes = figure.add_subplot(111)

coords = toreal(extract_unique_domain(function).coordinates.dat.data_ro, "real")
V = extract_unique_domain(function).coordinates.function_space()
function_interp = assemble(interpolate(function, V))
vals = toreal(function_interp.dat.data_ro, complex_component)
mesh = function.function_space().mesh()
if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

coords = toreal(mesh.coordinates.dat.data_ro, "real")
if isinstance(mesh.topology, VertexOnlyMeshTopology):
vals = toreal(function.dat.data_ro, complex_component)
else:
V = mesh.coordinates.function_space()
function_interp = assemble(interpolate(function, V))
vals = toreal(function_interp.dat.data_ro, complex_component)

C = np.linalg.norm(vals, axis=1)
return axes.quiver(*(coords.T), *(vals.T), C, **kwargs)

Expand Down Expand Up @@ -388,6 +483,9 @@ def streamline(function, point, direction=+1, tolerance=3e-3, loc_tolerance=1e-1
:returns: a generator of the position, velocity, and timestep ``(x, v, dt)``
"""
mesh = extract_unique_domain(function)
if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

cell_sizes = mesh.cell_sizes

x = np.array(point)
Expand Down Expand Up @@ -612,12 +710,15 @@ def streamplot(function, resolution=None, min_length=None, max_time=None,
if function.ufl_shape != (2,):
raise ValueError("Streamplot only defined for 2D vector fields!")

mesh = extract_unique_domain(function)
if mesh.comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

axes = kwargs.pop("axes", None)
if axes is None:
figure = plt.figure()
axes = figure.add_subplot(111)

mesh = extract_unique_domain(function)
if resolution is None:
coords = toreal(mesh.coordinates.dat.data_ro, "real")
resolution = (coords.max(axis=0) - coords.min(axis=0)).max() / 20
Expand Down Expand Up @@ -738,6 +839,7 @@ def plot(function, *args, num_sample_points=10, complex_component="real", **kwar
:arg kwargs: same as for matplotlib :class:`PathPatch <matplotlib.patches.PathPatch>`
:return: list of matplotlib :class:`Line2D <matplotlib.lines.Line2D>`
"""

axes = kwargs.pop("axes", None)
if axes is None:
figure = plt.figure()
Expand All @@ -752,6 +854,9 @@ def plot(function, *args, num_sample_points=10, complex_component="real", **kwar
if isinstance(line, MeshGeometry):
raise TypeError("Expected Function, not Mesh; see firedrake.triplot")

if extract_unique_domain(line).comm.size > 1:
raise SerialExecutionOnlyError("Firedrake plotting functions can only be used in serial.")

if extract_unique_domain(line).geometric_dimension > 1:
raise ValueError("Expected 1D Function; for plotting higher-dimensional fields, "
"see tricontourf, tripcolor, quiver, trisurf")
Expand Down
53 changes: 53 additions & 0 deletions tests/firedrake/output/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ def test_quiver_plot():
fig.colorbar(arrows)


@pytest.mark.skipplot
def test_quiver_plot_vom():
mesh = UnitSquareMesh(10, 10)
vom = VertexOnlyMesh(mesh, [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]])
V = VectorFunctionSpace(vom, "DG", 0)
f = Function(V)
x = SpatialCoordinate(mesh)
f.interpolate(as_vector((-x[1], x[0])))

fig, axes = plt.subplots()
arrows = quiver(f, axes=axes)
assert arrows is not None
fig.colorbar(arrows)


@pytest.mark.skipplot
def test_streamplot():
mesh = UnitSquareMesh(10, 10)
Expand Down Expand Up @@ -371,3 +386,41 @@ def animate(time):

# Use a method of the animation to prevent warning about it being unused
movie.to_jshtml()


@pytest.mark.skipplot
def test_scatter():
mesh = UnitSquareMesh(10, 10)
vom = VertexOnlyMesh(mesh, [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]])

fig, axes = plt.subplots()
sc = scatter(vom, axes=axes)

assert sc is not None
assert len(sc.get_offsets()) == vom.num_vertices()


@pytest.mark.skipplot
def test_scatter_3d():
mesh = UnitCubeMesh(5, 5, 5)
coords_3d = np.random.rand(20, 3)
vom = VertexOnlyMesh(mesh, coords_3d)

fig = plt.figure()
axes = fig.add_subplot(111, projection='3d')
sc = scatter(vom, axes=axes)
assert sc is not None
assert len(sc.get_offsets()) == vom.num_vertices()


@pytest.mark.skipplot
def test_scatter_scalar_field():
mesh = UnitSquareMesh(10, 10)
vom = VertexOnlyMesh(mesh, [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]])
V = FunctionSpace(vom, "DG", 0)
f = Function(V)
f.dat.data[:] = [1.0, 2.0, 3.0]

fig, axes = plt.subplots()
sc = scatter(f, axes=axes)
assert np.allclose(sc.get_array(), f.dat.data_ro)
Loading
Loading