Skip to content

Commit 4c9aefa

Browse files
committed
Rework device_matrix_type
1 parent 68b22b2 commit 4c9aefa

1 file changed

Lines changed: 52 additions & 0 deletions

File tree

firedrake/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from pyop2.datatypes import IntType # noqa: F401
88
from pyop2.datatypes import as_ctypes # noqa: F401
99
from pyop2.mpi import MPI
10+
from petsc4py import PETSc
11+
from functools import cache
1012
import petsctools
1113

1214

@@ -23,6 +25,56 @@
2325
SLATE_SUPPORTS_COMPLEX = False
2426

2527

28+
@cache
29+
def device_matrix_type(warn: bool = False) -> str | None:
30+
"""Get device matrix type
31+
32+
Attempt to initialise a GPU device and return the PETSc mat_type
33+
compatible with that device, or None if no device is detected
34+
35+
Args:
36+
warn: Emit a warning containing the reason a device mat_type
37+
has not been returned. Defaults to False.
38+
39+
Raises:
40+
RuntimeError: Raised when PETSc initialises a GPU device that
41+
Firedrake does not understand
42+
43+
Returns:
44+
The PETSc mat_type compatible with the GPU device detected on
45+
this system or None
46+
47+
Typical Usage Example:
48+
mat_type = device_matrix_type(pc.comm.rank == 0)
49+
50+
"""
51+
_device_mat_type_map = {"CUDA": "aijcusparse"}
52+
try:
53+
dev = PETSc.Device.create()
54+
except PETSc.Error:
55+
# Could not initialise device - not a failure condition as this could
56+
# be a GPU-enabled PETSc installation running on a CPU-only host.
57+
if warn:
58+
warnings.warn(
59+
"This installation of Firedrake is GPU-enabled, but no GPU device has been detected"
60+
)
61+
return None
62+
dev_type = dev.getDeviceType()
63+
dev.destroy()
64+
if dev_type not in _device_mat_type_map:
65+
raise RuntimeError(f"Unknown device type: {dev_type} initialised by PETSc")
66+
67+
if dev_type == "HOST":
68+
if warn:
69+
warnings.warn(
70+
"This installation of Firedrake is not GPU-enabled, to enable GPU functionality "
71+
"PETSc will need to be rebuilt with some GPU capability appropriate for this system "
72+
"(e.g. '--with-cuda=1')."
73+
)
74+
return None
75+
return _device_mat_type_map[dev_type]
76+
77+
2678
def _new_uid(comm):
2779
uid = comm.Get_attr(FIREDRAKE_UID)
2880
if uid is None:

0 commit comments

Comments
 (0)