Skip to content

Commit fc7215b

Browse files
authored
Merge branch 'firedrakeproject:main' into dsroberts/offload-pc
2 parents 221193a + c3b1325 commit fc7215b

5 files changed

Lines changed: 146 additions & 1 deletion

File tree

firedrake/ensemble/ensemble.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from functools import wraps
22
import weakref
3+
from contextlib import contextmanager
34
from itertools import zip_longest
5+
from types import SimpleNamespace
46

57
from firedrake.petsc import PETSc
68
from firedrake.function import Function
@@ -584,3 +586,93 @@ def isendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0,
584586
requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag)
585587
for dat in frecv.dat])
586588
return requests
589+
590+
@contextmanager
591+
def sequential(self, *, synchronise: bool = False, reverse: bool = False, **kwargs):
592+
"""
593+
Context manager for executing code on each ensemble
594+
member consecutively (ordered by increasing
595+
:attr:`~.Ensemble.ensemble_rank`).
596+
597+
Any data in ``kwargs`` will be made available in the returned
598+
context and will be communicated forward after each ensemble
599+
member exits. :class:`.Function` or :class:`.Cofunction`
600+
``kwargs`` will be sent with the corresponding Ensemble methods.
601+
602+
For example:
603+
604+
.. code-block:: python3
605+
606+
with ensemble.sequential(index=0) as ctx:
607+
print(ensemble.ensemble_rank, ctx.index)
608+
ctx.index += 2
609+
610+
Would print:
611+
612+
.. code-block::
613+
614+
0 0
615+
1 2
616+
2 4
617+
3 6
618+
...
619+
620+
If ``reverse is True`` then the ensemble ranks will be looped through
621+
in decreasing order i.e. ``ensemble_rank == (ensemble_size - 1)`` will
622+
run first, then ``ensemble_rank == (ensemble_size - 2)`` etc.
623+
624+
Parameters
625+
----------
626+
synchronise :
627+
If True then MPI_Barrier will be called on the ``global_comm``
628+
at the beginning and end of this method.
629+
630+
reverse :
631+
If True then will iterate through spatial comms in order of
632+
decreasing ``ensemble_rank``.
633+
634+
kwargs :
635+
Data to be passed forward by each rank and made available
636+
in the returned ``ctx``.
637+
"""
638+
rank = self.ensemble_rank
639+
if reverse: # send backwards
640+
src = rank + 1
641+
dst = rank - 1
642+
first_rank = (rank == self.ensemble_size - 1)
643+
last_rank = (rank == 0)
644+
else: # send forwards
645+
src = rank - 1
646+
dst = rank + 1
647+
first_rank = (rank == 0)
648+
last_rank = (rank == self.ensemble_size - 1)
649+
650+
if synchronise:
651+
self.global_comm.Barrier()
652+
653+
if not first_rank:
654+
for i, (k, v) in enumerate(kwargs.items()):
655+
if isinstance(v, (Function, Cofunction)):
656+
# Functions are sent in-place, everything else is pickled
657+
recv_args = [kwargs[k]]
658+
else:
659+
recv_args = []
660+
kwargs[k] = self.recv(*recv_args, source=src, tag=rank+i*100)
661+
662+
ctx = SimpleNamespace(**kwargs)
663+
yield ctx
664+
665+
if not last_rank:
666+
for i, v in enumerate((getattr(ctx, k)
667+
for k in kwargs.keys())):
668+
try:
669+
self.send(v, dest=dst, tag=dst+i*100)
670+
except Exception as error:
671+
raise TypeError(
672+
"Failed to send object of type {type(v)__name__}. kwargs for"
673+
" Ensemble.sequential must be Functions, Cofunctions,"
674+
" or acceptable arguments to mpi4py.MPI.Comm.send."
675+
) from error
676+
677+
if synchronise:
678+
self.global_comm.Barrier()

tests/firedrake/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
os.environ["FIREDRAKE_DISABLE_OPTIONS_LEFT"] = "1"
1010

1111
import pytest
12+
from mpi4py import MPI
1213
from petsctools import get_external_packages
1314
from pyadjoint.tape import (
1415
annotate_tape, get_working_tape, set_working_tape,
@@ -264,3 +265,10 @@ def __exit__(self, exc_type, exc_val, traceback):
264265
def petsc_raises():
265266
# This function is needed because pytest does not support classes as fixtures.
266267
return _petsc_raises
268+
269+
270+
@pytest.fixture
271+
def garbage_cleanup():
272+
"""Fixture that runs the parallel garbage collector."""
273+
yield
274+
PETSc.garbage_cleanup(MPI.COMM_WORLD)

tests/firedrake/ensemble/test_ensemble.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,43 @@ def test_ensemble_solvers(ensemble, W, urank, urank_sum):
380380
ensemble.allreduce(u_separate, usum)
381381

382382
parallel_assert(errornorm(u_combined, usum) < 1e-8)
383+
384+
385+
@pytest.mark.parallel(nprocs=6)
386+
@pytest.mark.parametrize("direction", ["forward", "reverse"])
387+
def test_ensemble_sequential(ensemble, direction):
388+
"""
389+
Test that the sequential context manager sends forward
390+
the correct values after each rank has executed, for both
391+
intrinsic types (float) and Firedrake types (Function).
392+
"""
393+
394+
rank = ensemble.ensemble_rank
395+
mesh = UnitIntervalMesh(1, comm=ensemble.comm)
396+
R = FunctionSpace(mesh, "R", 0)
397+
398+
reverse = direction == "reverse"
399+
400+
idx_i = 0
401+
idx_f = Function(R).zero()
402+
two = Function(R).assign(2)
403+
404+
with ensemble.sequential(reverse=reverse, idx_i=idx_i, idx_f=idx_f) as ctx:
405+
recv_i = float(ctx.idx_i)
406+
recv_f = float(ctx.idx_f)
407+
408+
ctx.idx_i += 2
409+
ctx.idx_f += two
410+
411+
if reverse:
412+
expected = 2*(ensemble.ensemble_size - 1 - rank)
413+
else:
414+
expected = 2*rank
415+
416+
parallel_assert(
417+
recv_i == expected,
418+
msg=f"Failed to send int properly. Expecting {expected} but received {recv_i}")
419+
420+
parallel_assert(
421+
abs(float(recv_f)-expected) < 1e-12,
422+
msg=f"Failed to send Function properly. Expecting {expected} but received {float(recv_f)}")

tests/firedrake/output/test_io_function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
func_name = "f"
1717

1818

19+
@pytest.fixture(autouse=True)
20+
def autouse_garbage_cleanup(garbage_cleanup):
21+
pass
22+
23+
1924
def _initialise_function(f, _f, method):
2025
if method == "project":
2126
getattr(f, method)(_f, solver_parameters={"ksp_type": "cg", "pc_type": "sor", "ksp_rtol": 1.e-16})

tests/firedrake/regression/test_covariance_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def rng():
3939
@pytest.mark.parametrize("family", ("CG", "DG"))
4040
@pytest.mark.parametrize("mesh_type", ("interval", "square"))
4141
@pytest.mark.parametrize("backend_type", (PyOP2NoiseBackend, PetscNoiseBackend), ids=("pyop2", "petsc"))
42-
def test_white_noise(family, degree, mesh_type, dim, backend_type, rng):
42+
def test_white_noise(family, degree, mesh_type, dim, backend_type, rng, garbage_cleanup):
4343
"""Test that white noise generator converges to a mass matrix covariance.
4444
"""
4545

0 commit comments

Comments
 (0)