Skip to content

Commit 6c09fe9

Browse files
committed
Restart offload_pc from main_20260417
1 parent 74ecb96 commit 6c09fe9

2 files changed

Lines changed: 163 additions & 0 deletions

File tree

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from firedrake.preconditioners.assembled import AssembledPC
2+
from firedrake.petsc import PETSc
3+
from firedrake.utils import device_matrix_type
4+
from firedrake.logging import logger
5+
from functools import cache
6+
import warnings
7+
8+
import firedrake.dmhooks as dmhooks
9+
10+
__all__ = ("OffloadPC",)
11+
12+
13+
@cache
14+
def offload_mat_type(pc_comm_rank) -> str | None:
15+
mat_type = device_matrix_type()
16+
if mat_type is None:
17+
if pc_comm_rank == 0:
18+
warnings.warn(
19+
"This installation of Firedrake is not GPU-enabled, therefore OffloadPC"
20+
"will do nothing. For this preconditioner to function correctly PETSc"
21+
"will need to be rebuilt with some GPU capability (e.g. '--with-cuda=1')."
22+
)
23+
return None
24+
try:
25+
dev = PETSc.Device.create()
26+
except PETSc.Error:
27+
if pc_comm_rank == 0:
28+
logger.warning(
29+
"This installation of Firedrake is GPU-enabled, but no GPU device has"
30+
"been detected. OffloadPC will do nothing on this host"
31+
)
32+
return None
33+
if dev.getDeviceType() == "HOST":
34+
raise RuntimeError(
35+
"A GPU-enabled Firedrake build has been detected, and GPU hardware has been"
36+
"detected but a GPU device was unable to be initialised."
37+
)
38+
dev.destroy()
39+
return mat_type
40+
41+
42+
class OffloadPC(AssembledPC):
43+
"""Offload PC from CPU to GPU and back.
44+
45+
Internally this makes a PETSc PC object that can be controlled by
46+
options using the extra options prefix ``offload_``.
47+
"""
48+
49+
_prefix = "offload_"
50+
51+
def initialize(self, pc):
52+
# Check if our PETSc installation is GPU enabled
53+
super().initialize(pc)
54+
self.offload_mat_type = offload_mat_type(pc.comm.rank)
55+
if self.offload_mat_type is not None:
56+
with PETSc.Log.Event("Event: initialize offload"):
57+
A, P = pc.getOperators()
58+
59+
# Convert matrix to ajicusparse
60+
with PETSc.Log.Event("Event: matrix offload"):
61+
P_cu = P.convert(self.offload_mat_type) # todo
62+
63+
# Transfer nullspace
64+
P_cu.setNullSpace(P.getNullSpace())
65+
P_cu.setTransposeNullSpace(P.getTransposeNullSpace())
66+
P_cu.setNearNullSpace(P.getNearNullSpace())
67+
68+
# Update preconditioner with GPU matrix
69+
self.pc.setOperators(A, P_cu)
70+
71+
# Convert vectors to CUDA, solve and get solution on CPU back
72+
def apply(self, pc, x, y):
73+
if self.offload_mat_type is None:
74+
self.pc.apply(x, y)
75+
else:
76+
with PETSc.Log.Event("Event: apply offload"): #
77+
dm = pc.getDM()
78+
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
79+
with PETSc.Log.Event("Event: vectors offload"):
80+
y_cu = PETSc.Vec() # begin
81+
y_cu.createCUDAWithArrays(y)
82+
x_cu = PETSc.Vec()
83+
# Passing a vec into another vec doesnt work because original is locked
84+
x_cu.createCUDAWithArrays(x.array_r)
85+
with PETSc.Log.Event("Event: solve"):
86+
self.pc.apply(x_cu, y_cu)
87+
# Calling data to synchronize vector
88+
tmp = y_cu.array_r # noqa: F841
89+
with PETSc.Log.Event("Event: vectors copy back"):
90+
y.copy(y_cu) #
91+
92+
def applyTranspose(self, pc, X, Y):
93+
raise NotImplementedError
94+
95+
def view(self, pc, viewer=None):
96+
super().view(pc, viewer)
97+
if hasattr(self, "pc"):
98+
viewer.printfASCII("PC to solve on GPU\n")
99+
self.pc.view(viewer)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from firedrake import *
2+
import pytest
3+
4+
5+
# TODO: add marker for cuda pytests and something to check if cuda memory was really used
6+
@pytest.mark.skipnogpu
7+
@pytest.mark.parametrize(
8+
"ksp_type, pc_type", [("cg", "sor"), ("cg", "gamg"), ("preonly", "lu")]
9+
)
10+
def test_poisson_offload(ksp_type, pc_type):
11+
12+
# Different tests for poisson: cg and pctype sor, --ksp_type=cg --pc_type=gamg
13+
print(f"Using ksp_type = {ksp_type}, and pc_type = {pc_type}.", flush=True)
14+
15+
nested_parameters = {
16+
"pc_type": "ksp",
17+
"ksp": {
18+
"ksp_type": ksp_type,
19+
"ksp_max_it": 50,
20+
"ksp_view": None,
21+
"ksp_rtol": "1e-10",
22+
"ksp_monitor": None,
23+
"pc_type": pc_type,
24+
},
25+
}
26+
parameters = {
27+
"ksp_type": "preonly",
28+
"pc_type": "python",
29+
"pc_python_type": "firedrake.OffloadPC",
30+
"offload": nested_parameters,
31+
}
32+
33+
mesh = UnitSquareMesh(10, 10)
34+
V = FunctionSpace(mesh, "CG", 1)
35+
u = TrialFunction(V)
36+
v = TestFunction(V)
37+
38+
f = Function(V)
39+
x, y = SpatialCoordinate(mesh)
40+
f.interpolate(2 * pi**2 * sin(pi * x) * sin(pi * y))
41+
42+
# Equations
43+
L = inner(grad(u), grad(v)) * dx
44+
45+
# Dirichlet boundary on all sides to 0
46+
bcs = DirichletBC(V, 0, "on_boundary")
47+
48+
# Exact solution
49+
sol = Function(V)
50+
R = action(L, sol)
51+
52+
# Solution function
53+
u_f = Function(V)
54+
55+
problem = LinearVariationalProblem(L, R, u_f, bcs=bcs)
56+
solver = LinearVariationalSolver(problem, solver_parameters=parameters)
57+
solver.solve()
58+
error = errornorm(u_f, sol)
59+
print(f"Error norm = {error}", flush=True)
60+
assert error < 1.0e-9
61+
62+
63+
if __name__ == "__main__":
64+
test_poisson_offload("cg", "gamg")

0 commit comments

Comments
 (0)