Skip to content

Commit d8a2031

Browse files
committed
Add CUDA process checkpointing helpers
1 parent 11347ff commit d8a2031

5 files changed

Lines changed: 406 additions & 2 deletions

File tree

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _import_versioned_module():
2828
del _import_versioned_module
2929

3030

31-
from cuda.core import system, utils
31+
from cuda.core import checkpoint, system, utils
3232
from cuda.core._device import Device
3333
from cuda.core._event import Event, EventOptions
3434
from cuda.core._graphics import GraphicsResource

cuda_core/cuda/core/checkpoint.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from collections.abc import Mapping as _Mapping
6+
from dataclasses import dataclass as _dataclass
7+
from enum import IntEnum as _IntEnum
8+
from typing import Any as _Any
9+
10+
from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return
11+
12+
try:
13+
from cuda.bindings import driver as _driver
14+
except ImportError:
15+
from cuda import cuda as _driver
16+
17+
18+
class ProcessState(_IntEnum):
19+
"""
20+
CUDA checkpoint state for a process.
21+
"""
22+
23+
RUNNING = 0
24+
LOCKED = 1
25+
CHECKPOINTED = 2
26+
FAILED = 3
27+
28+
29+
@_dataclass(frozen=True)
30+
class Process:
31+
"""
32+
CUDA process that can be locked, checkpointed, restored, and unlocked.
33+
34+
Parameters
35+
----------
36+
pid : int
37+
Process ID of the CUDA process.
38+
"""
39+
40+
pid: int
41+
42+
def __post_init__(self):
43+
_check_pid(self.pid)
44+
45+
@property
46+
def state(self) -> ProcessState:
47+
"""
48+
CUDA checkpoint state for this process.
49+
"""
50+
driver = _get_driver()
51+
state = _handle_return(driver, driver.cuCheckpointProcessGetState(self.pid))
52+
return ProcessState(int(state))
53+
54+
@property
55+
def restore_thread_id(self) -> int:
56+
"""
57+
CUDA restore thread ID for this process.
58+
"""
59+
driver = _get_driver()
60+
return _handle_return(driver, driver.cuCheckpointProcessGetRestoreThreadId(self.pid))
61+
62+
def lock(self, timeout_ms: int = 0) -> None:
63+
"""
64+
Lock this process, blocking further CUDA API calls.
65+
66+
Parameters
67+
----------
68+
timeout_ms : int, optional
69+
Timeout in milliseconds. A value of 0 indicates no timeout.
70+
"""
71+
driver = _get_driver()
72+
args = driver.CUcheckpointLockArgs()
73+
args.timeoutMs = _check_timeout_ms(timeout_ms)
74+
_handle_return(driver, driver.cuCheckpointProcessLock(self.pid, args))
75+
76+
def checkpoint(self) -> None:
77+
"""
78+
Checkpoint the GPU memory contents of this locked process.
79+
"""
80+
driver = _get_driver()
81+
_handle_return(driver, driver.cuCheckpointProcessCheckpoint(self.pid, None))
82+
83+
def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None:
84+
"""
85+
Restore this checkpointed process.
86+
87+
Parameters
88+
----------
89+
gpu_mapping : mapping, optional
90+
GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID
91+
to restore onto. If provided, the mapping must contain every
92+
checkpointed GPU UUID.
93+
"""
94+
driver = _get_driver()
95+
args = _make_restore_args(driver, gpu_mapping)
96+
_handle_return(driver, driver.cuCheckpointProcessRestore(self.pid, args))
97+
98+
def unlock(self) -> None:
99+
"""
100+
Unlock this locked process so it can resume CUDA API calls.
101+
"""
102+
driver = _get_driver()
103+
_handle_return(driver, driver.cuCheckpointProcessUnlock(self.pid, None))
104+
105+
106+
def _get_driver():
107+
required = (
108+
"cuCheckpointProcessCheckpoint",
109+
"cuCheckpointProcessGetRestoreThreadId",
110+
"cuCheckpointProcessGetState",
111+
"cuCheckpointProcessLock",
112+
"cuCheckpointProcessRestore",
113+
"cuCheckpointProcessUnlock",
114+
"CUcheckpointGpuPair",
115+
"CUcheckpointLockArgs",
116+
"CUcheckpointRestoreArgs",
117+
)
118+
missing = [name for name in required if not hasattr(_driver, name)]
119+
if missing:
120+
raise RuntimeError(
121+
f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}"
122+
)
123+
return _driver
124+
125+
126+
def _handle_return(driver, result):
127+
err = result[0]
128+
not_supported_errors = (
129+
getattr(driver.CUresult, "CUDA_ERROR_NOT_FOUND", None),
130+
getattr(driver.CUresult, "CUDA_ERROR_NOT_SUPPORTED", None),
131+
)
132+
if err in not_supported_errors:
133+
raise RuntimeError(
134+
"CUDA checkpointing is not supported by the installed NVIDIA driver. "
135+
"Upgrade to a driver version with CUDA checkpoint API support."
136+
)
137+
138+
return _handle_cuda_return(result)
139+
140+
141+
def _check_pid(pid: int) -> int:
142+
if isinstance(pid, bool) or not isinstance(pid, int):
143+
raise TypeError("pid must be an int")
144+
if pid <= 0:
145+
raise ValueError("pid must be a positive int")
146+
return pid
147+
148+
149+
def _check_timeout_ms(timeout_ms: int) -> int:
150+
if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int):
151+
raise TypeError("timeout_ms must be an int")
152+
if timeout_ms < 0:
153+
raise ValueError("timeout_ms must be >= 0")
154+
return timeout_ms
155+
156+
157+
def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None):
158+
if gpu_mapping is None:
159+
return None
160+
if not isinstance(gpu_mapping, _Mapping):
161+
raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID")
162+
163+
pairs = []
164+
for old_uuid, new_uuid in gpu_mapping.items():
165+
pair = driver.CUcheckpointGpuPair()
166+
pair.oldUuid = old_uuid
167+
pair.newUuid = new_uuid
168+
pairs.append(pair)
169+
170+
if not pairs:
171+
return None
172+
173+
args = driver.CUcheckpointRestoreArgs()
174+
args.gpuPairs = pairs
175+
args.gpuPairsCount = len(pairs)
176+
return args
177+
178+
179+
__all__ = [
180+
"Process",
181+
"ProcessState",
182+
]

cuda_core/docs/source/api.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,22 @@ CUDA compilation toolchain
174174
LinkerOptions
175175

176176

177+
CUDA process checkpointing
178+
--------------------------
179+
180+
.. autosummary::
181+
:toctree: generated/
182+
183+
:template: class.rst
184+
185+
checkpoint.Process
186+
187+
.. autosummary::
188+
:toctree: generated/
189+
190+
checkpoint.ProcessState
191+
192+
177193
CUDA system information and NVIDIA Management Library (NVML)
178194
------------------------------------------------------------
179195

cuda_core/docs/source/release/1.0.0-notes.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ Highlights
1616
New features
1717
------------
1818

19-
- TBD
19+
- Added the :mod:`cuda.core.checkpoint` module for CUDA process checkpointing,
20+
including process state queries, lock/checkpoint/restore/unlock operations,
21+
and GPU UUID remapping support for restore.
22+
(`#1343 <https://github.com/NVIDIA/cuda-python/issues/1343>`__)
2023

2124

2225
Fixes and enhancements

0 commit comments

Comments
 (0)