Skip to content

Commit 396a2ca

Browse files
committed
Add CUDA process checkpointing helpers
1 parent 11347ff commit 396a2ca

5 files changed

Lines changed: 412 additions & 2 deletions

File tree

cuda_core/cuda/core/system/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -11,13 +11,22 @@
1111

1212
__all__ = [
1313
"CUDA_BINDINGS_NVML_IS_COMPATIBLE",
14+
"CudaCheckpointGpuPair",
15+
"CudaProcessState",
16+
"checkpoint_cuda_process",
17+
"get_cuda_process_restore_thread_id",
18+
"get_cuda_process_state",
1419
"get_driver_version",
1520
"get_driver_version_full",
1621
"get_num_devices",
1722
"get_process_name",
23+
"lock_cuda_process",
24+
"restore_cuda_process",
25+
"unlock_cuda_process",
1826
]
1927

2028

29+
from ._checkpoint import *
2130
from ._system import *
2231

2332
if CUDA_BINDINGS_NVML_IS_COMPATIBLE:
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from __future__ import annotations
6+
7+
from dataclasses import dataclass
8+
from enum import IntEnum
9+
from typing import Any, Iterable
10+
11+
12+
class CudaProcessState(IntEnum):
    """
    Checkpoint-related state of a CUDA process.

    Members are plain integers; :func:`get_cuda_process_state` builds one from
    the integer state code returned by the driver.
    """

    # Not locked or checkpointed.
    RUNNING = 0
    # Locked via lock_cuda_process: further CUDA API calls are blocked.
    LOCKED = 1
    # GPU memory contents have been checkpointed via checkpoint_cuda_process.
    CHECKPOINTED = 2
    # NOTE(review): presumably an unrecoverable checkpoint/restore error -
    # confirm against the CUDA driver documentation.
    FAILED = 3
21+
22+
23+
@dataclass(frozen=True)
class CudaCheckpointGpuPair:
    """
    Immutable (old GPU, new GPU) UUID pair for restoring a checkpointed process.

    Instances are accepted by :func:`restore_cuda_process`, which copies the two
    UUIDs into the driver's ``CUcheckpointGpuPair`` structure.

    Attributes
    ----------
    old_uuid
        UUID of the GPU that was checkpointed.
    new_uuid
        UUID of the GPU to restore onto.
    """

    # The UUID values are forwarded to the driver bindings as-is.
    old_uuid: Any
    new_uuid: Any
38+
39+
40+
def _get_driver():
    """
    Import the CUDA driver bindings and confirm they expose the checkpoint API.

    ``cuda.bindings.driver`` is tried first; if that import fails, the
    ``cuda.cuda`` module is used instead.

    Returns
    -------
    module
        The driver bindings module.

    Raises
    ------
    RuntimeError
        If any checkpoint function or structure used by this module is absent
        from the bindings. The message lists the missing names.
    """
    try:
        from cuda.bindings import driver as bindings
    except ImportError:
        from cuda import cuda as bindings

    # Collect every absent symbol so the error reports all of them at once.
    missing = []
    for symbol in (
        "cuCheckpointProcessCheckpoint",
        "cuCheckpointProcessGetRestoreThreadId",
        "cuCheckpointProcessGetState",
        "cuCheckpointProcessLock",
        "cuCheckpointProcessRestore",
        "cuCheckpointProcessUnlock",
        "CUcheckpointGpuPair",
        "CUcheckpointLockArgs",
        "CUcheckpointRestoreArgs",
    ):
        if not hasattr(bindings, symbol):
            missing.append(symbol)

    if not missing:
        return bindings

    raise RuntimeError(
        f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}"
    )
63+
64+
65+
def _handle_return(result):
    """
    Pass a raw driver call result through cuda.core's ``handle_return`` helper.

    The helper module is imported when this function is called rather than when
    this module is imported.
    """
    from cuda.core._utils import cuda_utils

    return cuda_utils.handle_return(result)
69+
70+
71+
def _check_pid(pid: int) -> int:
    """
    Validate a process ID before it is handed to the CUDA driver.

    Parameters
    ----------
    pid : int
        Candidate process ID.

    Returns
    -------
    int
        ``pid`` unchanged, when it is valid.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int``, or is a ``bool``.
    ValueError
        If ``pid`` is zero or negative.
    """
    # bool is a subclass of int, so without the explicit check ``True`` would
    # pass validation and silently target PID 1.
    if isinstance(pid, bool) or not isinstance(pid, int):
        raise TypeError("pid must be an int")
    if pid <= 0:
        raise ValueError("pid must be a positive int")
    return pid
77+
78+
79+
def _check_timeout_ms(timeout_ms: int) -> int:
    """
    Validate a lock timeout expressed in milliseconds.

    Parameters
    ----------
    timeout_ms : int
        Candidate timeout in milliseconds.

    Returns
    -------
    int
        ``timeout_ms`` unchanged, when it is valid.

    Raises
    ------
    TypeError
        If ``timeout_ms`` is not an ``int``, or is a ``bool``.
    ValueError
        If ``timeout_ms`` is negative.
    """
    # bool is a subclass of int, so without the explicit check ``True`` would
    # be accepted and silently become a 1 ms timeout.
    if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int):
        raise TypeError("timeout_ms must be an int")
    if timeout_ms < 0:
        raise ValueError("timeout_ms must be >= 0")
    return timeout_ms
85+
86+
87+
def _make_restore_args(driver, gpu_pairs: Iterable[CudaCheckpointGpuPair] | None):
    """
    Build the driver's restore-arguments structure from GPU remapping pairs.

    Parameters
    ----------
    driver
        Driver bindings module providing ``CUcheckpointGpuPair`` and
        ``CUcheckpointRestoreArgs``.
    gpu_pairs : iterable of CudaCheckpointGpuPair, optional
        Remapping pairs. Entries that already are ``driver.CUcheckpointGpuPair``
        objects are used unchanged.

    Returns
    -------
    driver.CUcheckpointRestoreArgs or None
        ``None`` when ``gpu_pairs`` is ``None`` or yields no entries; otherwise a
        structure whose ``gpuPairs`` and ``gpuPairsCount`` fields are filled in.

    Raises
    ------
    TypeError
        If an entry is neither a ``CudaCheckpointGpuPair`` nor a
        ``driver.CUcheckpointGpuPair``.
    """
    if gpu_pairs is None:
        return None

    native_pairs = []
    for entry in gpu_pairs:
        # Driver-native pairs pass straight through.
        if isinstance(entry, driver.CUcheckpointGpuPair):
            native_pairs.append(entry)
            continue
        if not isinstance(entry, CudaCheckpointGpuPair):
            raise TypeError(
                "gpu_pairs must contain CudaCheckpointGpuPair or cuda.bindings.driver.CUcheckpointGpuPair objects"
            )
        native = driver.CUcheckpointGpuPair()
        native.oldUuid = entry.old_uuid
        native.newUuid = entry.new_uuid
        native_pairs.append(native)

    # An empty iterable is treated the same as passing no pairs at all.
    if len(native_pairs) == 0:
        return None

    restore_args = driver.CUcheckpointRestoreArgs()
    restore_args.gpuPairs = native_pairs
    restore_args.gpuPairsCount = len(native_pairs)
    return restore_args
112+
113+
114+
def get_cuda_process_state(pid: int) -> CudaProcessState:
    """
    Query the CUDA checkpoint state of a process.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Returns
    -------
    CudaProcessState
        The state reported by ``cuCheckpointProcessGetState``, converted to the
        enum.

    Raises
    ------
    TypeError, ValueError
        If ``pid`` is not a positive ``int``.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()
    query_state = cuda_driver.cuCheckpointProcessGetState
    raw_state = _handle_return(query_state(_check_pid(pid)))
    # Normalise whatever the bindings return into this module's enum.
    return CudaProcessState(int(raw_state))
126+
127+
128+
def get_cuda_process_restore_thread_id(pid: int) -> int:
    """
    Query the CUDA restore thread ID of a process.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Returns
    -------
    int
        The value reported by ``cuCheckpointProcessGetRestoreThreadId``.

    Raises
    ------
    TypeError, ValueError
        If ``pid`` is not a positive ``int``.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()
    query_thread_id = cuda_driver.cuCheckpointProcessGetRestoreThreadId
    thread_id = _handle_return(query_thread_id(_check_pid(pid)))
    return thread_id
139+
140+
141+
def lock_cuda_process(pid: int, timeout_ms: int = 0) -> None:
    """
    Lock a running CUDA process so that it cannot make further CUDA API calls.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.
    timeout_ms : int, optional
        Timeout in milliseconds. A value of 0 indicates no timeout.

    Raises
    ------
    TypeError, ValueError
        If ``timeout_ms`` is not a non-negative ``int`` or ``pid`` is not a
        positive ``int``.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()

    lock_args = cuda_driver.CUcheckpointLockArgs()
    lock_args.timeoutMs = _check_timeout_ms(timeout_ms)

    target_pid = _check_pid(pid)
    status = cuda_driver.cuCheckpointProcessLock(target_pid, lock_args)
    _handle_return(status)
156+
157+
158+
def checkpoint_cuda_process(pid: int) -> None:
    """
    Checkpoint the GPU memory contents of a locked CUDA process.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Raises
    ------
    TypeError, ValueError
        If ``pid`` is not a positive ``int``.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()
    # No checkpoint arguments are supplied; the args parameter is passed as None.
    status = cuda_driver.cuCheckpointProcessCheckpoint(_check_pid(pid), None)
    _handle_return(status)
169+
170+
171+
def restore_cuda_process(pid: int, gpu_pairs: Iterable[CudaCheckpointGpuPair] | None = None) -> None:
    """
    Restore a checkpointed CUDA process.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.
    gpu_pairs : iterable of CudaCheckpointGpuPair, optional
        GPU UUID remapping pairs. If provided, the pairs must cover every
        checkpointed GPU. When omitted or empty, no restore arguments are
        passed to the driver.

    Raises
    ------
    TypeError
        If ``gpu_pairs`` contains an unsupported object or ``pid`` is not an
        ``int``.
    ValueError
        If ``pid`` is not positive.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()
    restore_args = _make_restore_args(cuda_driver, gpu_pairs)
    target_pid = _check_pid(pid)
    status = cuda_driver.cuCheckpointProcessRestore(target_pid, restore_args)
    _handle_return(status)
186+
187+
188+
def unlock_cuda_process(pid: int) -> None:
    """
    Unlock a locked CUDA process, allowing it to resume CUDA API calls.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Raises
    ------
    TypeError, ValueError
        If ``pid`` is not a positive ``int``.
    RuntimeError
        If the installed CUDA bindings lack the checkpoint API.
    """
    cuda_driver = _get_driver()
    # No unlock arguments are supplied; the args parameter is passed as None.
    status = cuda_driver.cuCheckpointProcessUnlock(_check_pid(pid), None)
    _handle_return(status)
199+
200+
201+
# Public names of this module: the only names a star import of it exposes
# (the package __init__ pulls them in with ``from ._checkpoint import *``).
__all__ = [
    "CudaCheckpointGpuPair",
    "CudaProcessState",
    "checkpoint_cuda_process",
    "get_cuda_process_restore_thread_id",
    "get_cuda_process_state",
    "lock_cuda_process",
    "restore_cuda_process",
    "unlock_cuda_process",
]

cuda_core/docs/source/api.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ Basic functions
191191
system.get_process_name
192192
system.get_topology_common_ancestor
193193
system.get_p2p_status
194+
system.get_cuda_process_state
195+
system.get_cuda_process_restore_thread_id
196+
system.lock_cuda_process
197+
system.checkpoint_cuda_process
198+
system.restore_cuda_process
199+
system.unlock_cuda_process
194200

195201
Events
196202
``````
@@ -227,13 +233,18 @@ Enums
227233
system.TemperatureThresholds
228234
system.ThermalController
229235
system.ThermalTarget
236+
system.CudaProcessState
230237

231238
Types
232239
`````
233240

234241
.. autosummary::
235242
:toctree: generated/
236243

244+
:template: dataclass.rst
245+
246+
system.CudaCheckpointGpuPair
247+
237248
:template: autosummary/cyclass.rst
238249

239250
system.Device

cuda_core/docs/source/release/1.0.0-notes.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ Highlights
1616
New features
1717
------------
1818

19-
- TBD
19+
- Added CUDA process checkpointing helpers to :mod:`cuda.core.system`, including
20+
process state queries, lock/checkpoint/restore/unlock operations, and GPU UUID
21+
remapping support for restore. (`#1343 <https://github.com/NVIDIA/cuda-python/issues/1343>`__)
2022

2123

2224
Fixes and enhancements

0 commit comments

Comments
 (0)