Skip to content

Commit 77a67ca

Browse files
committed
OffloadPC subclasses PCBase. Add Offloading test to firedrake-check
1 parent 6c09fe9 commit 77a67ca

7 files changed

Lines changed: 104 additions & 67 deletions

File tree

firedrake/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def init_petsc():
8989
MassInvPC, PCDPC, PatchPC, PlaneSmoother, PatchSNES, P1PC, P1SNES,
9090
LORPC, GTMGPC, PMGPC, PMGSNES, HypreAMS, HypreADS, FDMPC,
9191
PoissonFDMPC, TwoLevelPC, HiptmairPC, FacetSplitPC, BDDCPC,
92-
CovariancePC
92+
CovariancePC, OffloadPC
9393
)
9494
from firedrake.mesh import ( # noqa: F401
9595
Mesh, ExtrudedMesh, VertexOnlyMesh, RelabeledMesh,

firedrake/preconditioners/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AssembledPC, AuxiliaryOperatorPC
1010
)
1111
from firedrake.preconditioners.massinv import MassInvPC # noqa: F401
12+
from firedrake.preconditioners.offload import OffloadPC # noqa: F401
1213
from firedrake.preconditioners.pcd import PCDPC # noqa: F401
1314
from firedrake.preconditioners.patch import ( # noqa: F401
1415
PatchPC, PlaneSmoother, PatchSNES

firedrake/preconditioners/offload.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,20 @@
1-
from firedrake.preconditioners.assembled import AssembledPC
1+
from firedrake.preconditioners.assembled import PCBase
22
from firedrake.petsc import PETSc
3-
from firedrake.utils import device_matrix_type
4-
from firedrake.logging import logger
5-
from functools import cache
6-
import warnings
3+
from firedrake.utils import device_matrix_type, get_device_type
74

85
import firedrake.dmhooks as dmhooks
96

107
__all__ = ("OffloadPC",)
118

129

13-
@cache
14-
def offload_mat_type(pc_comm_rank) -> str | None:
15-
mat_type = device_matrix_type()
16-
if mat_type is None:
17-
if pc_comm_rank == 0:
18-
warnings.warn(
19-
"This installation of Firedrake is not GPU-enabled, therefore OffloadPC"
20-
"will do nothing. For this preconditioner to function correctly PETSc"
21-
"will need to be rebuilt with some GPU capability (e.g. '--with-cuda=1')."
22-
)
23-
return None
24-
try:
25-
dev = PETSc.Device.create()
26-
except PETSc.Error:
27-
if pc_comm_rank == 0:
28-
logger.warning(
29-
"This installation of Firedrake is GPU-enabled, but no GPU device has"
30-
"been detected. OffloadPC will do nothing on this host"
31-
)
32-
return None
33-
if dev.getDeviceType() == "HOST":
34-
raise RuntimeError(
35-
"A GPU-enabled Firedrake build has been detected, and GPU hardware has been"
36-
"detected but a GPU device was unable to be initialised."
37-
)
38-
dev.destroy()
39-
return mat_type
10+
_device_vector_impls = {
11+
"CUDA": {
12+
"createWithArrays": "createCUDAWithArrays",
13+
}
14+
}
4015

4116

42-
class OffloadPC(AssembledPC):
17+
class OffloadPC(PCBase):
4318
"""Offload PC from CPU to GPU and back.
4419
4520
Internally this makes a PETSc PC object that can be controlled by
@@ -50,44 +25,84 @@ class OffloadPC(AssembledPC):
5025

5126
def initialize(self, pc):
5227
# Check if our PETSc installation is GPU enabled
53-
super().initialize(pc)
54-
self.offload_mat_type = offload_mat_type(pc.comm.rank)
55-
if self.offload_mat_type is not None:
28+
29+
A, P = pc.getOperators()
30+
31+
if pc.type != "python":
32+
raise ValueError("Expecting PC type python")
33+
opc = pc
34+
if P.type == "python":
35+
context = P.getPythonContext()
36+
# It only makes sense to precondition/invert a diagonal
37+
# block in general. That's all we're going to allow.
38+
if not context.on_diag:
39+
raise ValueError("Only makes sense to invert diagonal block")
40+
41+
prefix = pc.getOptionsPrefix() or ""
42+
options_prefix = prefix + self._prefix
43+
44+
self.device_mat = device_matrix_type(pc.comm.rank == 0)
45+
dm = opc.getDM()
46+
47+
pc = PETSc.PC().create(comm=opc.comm)
48+
pc.setDM(dm)
49+
pc.setOptionsPrefix(options_prefix)
50+
if self.device_mat is not None:
5651
with PETSc.Log.Event("Event: initialize offload"):
57-
A, P = pc.getOperators()
52+
P_dev = P.convert(mat_type=self.device_mat)
53+
A_dev = A.convert(mat_type=self.device_mat)
54+
P_dev.setNullSpace(P.getNullSpace())
55+
P_dev.setTransposeNullSpace(P.getTransposeNullSpace())
56+
P_dev.setNearNullSpace(P.getNearNullSpace())
57+
self.vector_impls = _device_vector_impls[get_device_type()]
58+
pc.setOperators(A_dev, P_dev)
59+
else:
60+
pc.setOperators(A, P)
5861

59-
# Convert matrix to aijcusparse
60-
with PETSc.Log.Event("Event: matrix offload"):
61-
P_cu = P.convert(self.offload_mat_type) # todo
62+
# Simplest reconstruction we can manage
63+
octx = dmhooks.get_appctx(dm)
64+
self._ctx_ref = octx.reconstruct(
65+
problem=None, mat_type=self.device_mat, pmat_type=self.device_mat
66+
)
67+
self.pc = pc
6268

63-
# Transfer nullspace
64-
P_cu.setNullSpace(P.getNullSpace())
65-
P_cu.setTransposeNullSpace(P.getTransposeNullSpace())
66-
P_cu.setNearNullSpace(P.getNearNullSpace())
69+
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref, save=False):
70+
self.pc.setFromOptions()
6771

68-
# Update preconditioner with GPU matrix
69-
self.pc.setOperators(A, P_cu)
72+
def update(self, pc):
73+
A, P = pc.getOperators()
74+
A_dev, P_dev = self.pc.getOperators()
75+
P.copy(P_dev)
76+
A.copy(A_dev)
7077

7178
# Convert vectors to CUDA, solve, and get the solution back on the CPU
7279
def apply(self, pc, x, y):
73-
if self.offload_mat_type is None:
80+
if self.device_mat is None:
7481
self.pc.apply(x, y)
7582
else:
7683
with PETSc.Log.Event("Event: apply offload"): #
7784
dm = pc.getDM()
7885
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
7986
with PETSc.Log.Event("Event: vectors offload"):
80-
y_cu = PETSc.Vec() # begin
81-
y_cu.createCUDAWithArrays(y)
82-
x_cu = PETSc.Vec()
83-
# Passing a vec into another vec doesn't work because the original is locked
84-
x_cu.createCUDAWithArrays(x.array_r)
87+
# Create the to-be-offloaded vector
88+
y_dev = PETSc.Vec()
89+
# Use device implementation of 'createWithArrays' function
90+
getattr(y_dev, self.vector_impls["createWithArrays"])(
91+
y.array_r, None
92+
)
93+
# Create the to-be-offloaded vector
94+
x_dev = PETSc.Vec()
95+
# Use device implementation of 'createWithArrays' function
96+
getattr(x_dev, self.vector_impls["createWithArrays"])(
97+
x.array_r, None
98+
)
8599
with PETSc.Log.Event("Event: solve"):
86-
self.pc.apply(x_cu, y_cu)
87-
# Calling data to synchronize vector
88-
tmp = y_cu.array_r # noqa: F841
89-
with PETSc.Log.Event("Event: vectors copy back"):
90-
y.copy(y_cu) #
100+
self.pc.apply(x_dev, y_dev)
101+
with PETSc.Log.Event("Event: vectors copy back"):
102+
# y is already designated as host storage for y_dev, so calling
103+
# getArray is sufficient to synchronise the vector on the device
104+
# with y on the host
105+
y_dev.getArray(True)
91106

92107
def applyTranspose(self, pc, X, Y):
93108
raise NotImplementedError

firedrake/solving_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def __init__(self, problem,
279279

280280
def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs):
281281
"""Reconstruct this _SNESContext instance with new arguments."""
282-
problem = problem or self.problem
282+
problem = problem or self._problem
283283
mat_type = mat_type or self.mat_type
284284
pmat_type = pmat_type or self.pmat_type
285285
kwargs.setdefault("sub_mat_type", self.sub_mat_type)

firedrake/utils.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,29 @@
2626
SLATE_SUPPORTS_COMPLEX = False
2727

2828

29+
@cache
30+
def get_device_type() -> str | None:
31+
r"""Get PETSc device type
32+
33+
Attempt to initialise a GPU and return the type of GPU
34+
identified by PETSc
35+
36+
Returns
37+
-------
38+
str | None
39+
The PETSc device type
40+
"""
41+
try:
42+
dev = PETSc.Device.create()
43+
except PETSc.Error:
44+
# Could not initialise device - not a failure condition as this could
45+
# be a GPU-enabled PETSc installation running on a CPU-only host.
46+
return None
47+
dev_type = dev.getDeviceType()
48+
dev.destroy()
49+
return dev_type
50+
51+
2952
@cache
3053
def device_matrix_type(warn: bool = True) -> str | None:
3154
r"""Get device matrix type
@@ -39,7 +62,7 @@ def device_matrix_type(warn: bool = True) -> str | None:
3962
----------
4063
warn
4164
Emit a warning containing the reason a device mat_type
42-
has not been returned. Defaults to False.
65+
has not been returned. Defaults to True.
4366
4467
Raises
4568
------
@@ -55,18 +78,13 @@ def device_matrix_type(warn: bool = True) -> str | None:
5578
5679
"""
5780
_device_mat_type_map = {"HOST": None, "CUDA": "aijcusparse"}
58-
try:
59-
dev = PETSc.Device.create()
60-
except PETSc.Error:
61-
# Could not initialise device - not a failure condition as this could
62-
# be a GPU-enabled PETSc installation running on a CPU-only host.
81+
dev_type = get_device_type()
82+
if dev_type is None:
6383
if warn:
6484
warnings.warn(
6585
"This installation of Firedrake is GPU-enabled, but no GPU device has been detected"
6686
)
6787
return None
68-
dev_type = dev.getDeviceType()
69-
dev.destroy()
7088
if dev_type not in _device_mat_type_map:
7189
raise UnrecognisedDeviceError(
7290
f"Unknown device type: {dev_type} initialised by PETSc. Firedrake "

scripts/firedrake-check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ TESTS = {
2323
"tests/firedrake/regression/test_matrix_free.py::test_fieldsplitting[parameters3-cofunc_rhs-variational]",
2424
# near nullspace
2525
"tests/firedrake/regression/test_nullspace.py::test_near_nullspace",
26+
# GPU offload
27+
"tests/firedrake/offload/test_poisson_offloading_pc.py::test_poisson_offload",
2628
),
2729
2: (
2830
# HDF5/checkpointing

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def extensions():
240240
"tests/firedrake/regression/test_dg_advection.py",
241241
"tests/firedrake/regression/test_interpolate_cross_mesh.py",
242242
"tests/firedrake/output/test_io_function.py",
243+
"tests/firedrake/offload/test_poisson_offloading_pc.py",
243244
)
244245

245246

0 commit comments

Comments
 (0)