Skip to content

Commit 82f816c

Browse files
committed
Add CUDA process checkpointing helpers
1 parent 11347ff commit 82f816c

6 files changed

Lines changed: 412 additions & 3 deletions

File tree

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _import_versioned_module():
2828
del _import_versioned_module
2929

3030

31-
from cuda.core import system, utils
31+
from cuda.core import checkpoint, system, utils
3232
from cuda.core._device import Device
3333
from cuda.core._event import Event, EventOptions
3434
from cuda.core._graphics import GraphicsResource

cuda_core/cuda/core/checkpoint.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from __future__ import annotations
6+
7+
from dataclasses import dataclass
8+
from enum import IntEnum
9+
from typing import Any, Iterable
10+
11+
12+
class ProcessState(IntEnum):
13+
"""
14+
CUDA checkpoint state for a process.
15+
"""
16+
17+
RUNNING = 0
18+
LOCKED = 1
19+
CHECKPOINTED = 2
20+
FAILED = 3
21+
22+
23+
@dataclass(frozen=True)
24+
class GpuPair:
25+
"""
26+
GPU UUID remapping pair used while restoring a checkpointed CUDA process.
27+
28+
Attributes
29+
----------
30+
old_uuid
31+
UUID of the GPU that was checkpointed.
32+
new_uuid
33+
UUID of the GPU to restore onto.
34+
"""
35+
36+
old_uuid: Any
37+
new_uuid: Any
38+
39+
40+
@dataclass(frozen=True)
41+
class Process:
42+
"""
43+
CUDA process that can be locked, checkpointed, restored, and unlocked.
44+
45+
Parameters
46+
----------
47+
pid : int
48+
Process ID of the CUDA process.
49+
"""
50+
51+
pid: int
52+
53+
def __post_init__(self):
54+
_check_pid(self.pid)
55+
56+
@property
57+
def state(self) -> ProcessState:
58+
"""
59+
CUDA checkpoint state for this process.
60+
"""
61+
driver = _get_driver()
62+
state = _handle_return(driver.cuCheckpointProcessGetState(self.pid))
63+
return ProcessState(int(state))
64+
65+
@property
66+
def restore_thread_id(self) -> int:
67+
"""
68+
CUDA restore thread ID for this process.
69+
"""
70+
driver = _get_driver()
71+
return _handle_return(driver.cuCheckpointProcessGetRestoreThreadId(self.pid))
72+
73+
def lock(self, timeout_ms: int = 0) -> None:
74+
"""
75+
Lock this process, blocking further CUDA API calls.
76+
77+
Parameters
78+
----------
79+
timeout_ms : int, optional
80+
Timeout in milliseconds. A value of 0 indicates no timeout.
81+
"""
82+
driver = _get_driver()
83+
args = driver.CUcheckpointLockArgs()
84+
args.timeoutMs = _check_timeout_ms(timeout_ms)
85+
_handle_return(driver.cuCheckpointProcessLock(self.pid, args))
86+
87+
def checkpoint(self) -> None:
88+
"""
89+
Checkpoint the GPU memory contents of this locked process.
90+
"""
91+
driver = _get_driver()
92+
_handle_return(driver.cuCheckpointProcessCheckpoint(self.pid, None))
93+
94+
def restore(self, gpu_pairs: Iterable[GpuPair] | None = None) -> None:
95+
"""
96+
Restore this checkpointed process.
97+
98+
Parameters
99+
----------
100+
gpu_pairs : iterable of GpuPair, optional
101+
GPU UUID remapping pairs. If provided, the array must contain every
102+
checkpointed GPU.
103+
"""
104+
driver = _get_driver()
105+
args = _make_restore_args(driver, gpu_pairs)
106+
_handle_return(driver.cuCheckpointProcessRestore(self.pid, args))
107+
108+
def unlock(self) -> None:
109+
"""
110+
Unlock this locked process so it can resume CUDA API calls.
111+
"""
112+
driver = _get_driver()
113+
_handle_return(driver.cuCheckpointProcessUnlock(self.pid, None))
114+
115+
116+
def _get_driver():
117+
try:
118+
from cuda.bindings import driver
119+
except ImportError:
120+
from cuda import cuda as driver
121+
122+
required = (
123+
"cuCheckpointProcessCheckpoint",
124+
"cuCheckpointProcessGetRestoreThreadId",
125+
"cuCheckpointProcessGetState",
126+
"cuCheckpointProcessLock",
127+
"cuCheckpointProcessRestore",
128+
"cuCheckpointProcessUnlock",
129+
"CUcheckpointGpuPair",
130+
"CUcheckpointLockArgs",
131+
"CUcheckpointRestoreArgs",
132+
)
133+
missing = [name for name in required if not hasattr(driver, name)]
134+
if missing:
135+
raise RuntimeError(
136+
f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}"
137+
)
138+
return driver
139+
140+
141+
def _handle_return(result):
142+
from cuda.core._utils.cuda_utils import handle_return
143+
144+
return handle_return(result)
145+
146+
147+
def _check_pid(pid: int) -> int:
148+
if isinstance(pid, bool) or not isinstance(pid, int):
149+
raise TypeError("pid must be an int")
150+
if pid <= 0:
151+
raise ValueError("pid must be a positive int")
152+
return pid
153+
154+
155+
def _check_timeout_ms(timeout_ms: int) -> int:
156+
if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int):
157+
raise TypeError("timeout_ms must be an int")
158+
if timeout_ms < 0:
159+
raise ValueError("timeout_ms must be >= 0")
160+
return timeout_ms
161+
162+
163+
def _make_restore_args(driver, gpu_pairs: Iterable[GpuPair] | None):
164+
if gpu_pairs is None:
165+
return None
166+
167+
pairs = []
168+
for gpu_pair in gpu_pairs:
169+
if isinstance(gpu_pair, driver.CUcheckpointGpuPair):
170+
pair = gpu_pair
171+
elif isinstance(gpu_pair, GpuPair):
172+
pair = driver.CUcheckpointGpuPair()
173+
pair.oldUuid = gpu_pair.old_uuid
174+
pair.newUuid = gpu_pair.new_uuid
175+
else:
176+
raise TypeError("gpu_pairs must contain GpuPair or cuda.bindings.driver.CUcheckpointGpuPair objects")
177+
pairs.append(pair)
178+
179+
if not pairs:
180+
return None
181+
182+
args = driver.CUcheckpointRestoreArgs()
183+
args.gpuPairs = pairs
184+
args.gpuPairsCount = len(pairs)
185+
return args
186+
187+
188+
__all__ = [
189+
"GpuPair",
190+
"Process",
191+
"ProcessState",
192+
]

cuda_core/cuda/core/system/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

cuda_core/docs/source/api.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,29 @@ CUDA compilation toolchain
174174
LinkerOptions
175175

176176

177+
CUDA process checkpointing
178+
--------------------------
179+
180+
.. autosummary::
181+
:toctree: generated/
182+
183+
:template: class.rst
184+
185+
checkpoint.Process
186+
187+
.. autosummary::
188+
:toctree: generated/
189+
190+
:template: dataclass.rst
191+
192+
checkpoint.GpuPair
193+
194+
.. autosummary::
195+
:toctree: generated/
196+
197+
checkpoint.ProcessState
198+
199+
177200
CUDA system information and NVIDIA Management Library (NVML)
178201
------------------------------------------------------------
179202

cuda_core/docs/source/release/1.0.0-notes.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ Highlights
1616
New features
1717
------------
1818

19-
- TBD
19+
- Added the :mod:`cuda.core.checkpoint` module for CUDA process checkpointing,
20+
including process state queries, lock/checkpoint/restore/unlock operations,
21+
and GPU UUID remapping support for restore.
22+
(`#1343 <https://github.com/NVIDIA/cuda-python/issues/1343>`__)
2023

2124

2225
Fixes and enhancements

0 commit comments

Comments
 (0)