|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import ctypes as _ctypes |
| 6 | +from collections.abc import Mapping as _Mapping |
| 7 | +from typing import Any as _Any |
| 8 | + |
| 9 | +from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return |
| 10 | +from cuda.core._utils.version import binding_version as _binding_version |
| 11 | +from cuda.core._utils.version import driver_version as _driver_version |
| 12 | +from cuda.core.typing import ProcessStateT as _ProcessStateT |
| 13 | + |
| 14 | +try: |
| 15 | + from cuda.bindings import driver as _driver |
| 16 | +except ImportError: |
| 17 | + from cuda import cuda as _driver |
| 18 | + |
| 19 | + |
| 20 | +_PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateT], ...] = ( |
| 21 | + ("CU_PROCESS_STATE_RUNNING", "running"), |
| 22 | + ("CU_PROCESS_STATE_LOCKED", "locked"), |
| 23 | + ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"), |
| 24 | + ("CU_PROCESS_STATE_FAILED", "failed"), |
| 25 | +) |
| 26 | + |
| 27 | +_REQUIRED_BINDING_ATTRS = ( |
| 28 | + "cuCheckpointProcessCheckpoint", |
| 29 | + "cuCheckpointProcessGetRestoreThreadId", |
| 30 | + "cuCheckpointProcessGetState", |
| 31 | + "cuCheckpointProcessLock", |
| 32 | + "cuCheckpointProcessRestore", |
| 33 | + "cuCheckpointProcessUnlock", |
| 34 | + "CUcheckpointGpuPair", |
| 35 | + "CUcheckpointLockArgs", |
| 36 | + "CUprocessState", |
| 37 | + "CUcheckpointRestoreArgs", |
| 38 | +) |
| 39 | +_REQUIRED_DRIVER_VERSION = (12, 8, 0) |
| 40 | +_driver_capability_checked = False |
| 41 | + |
| 42 | + |
| 43 | +class Process: |
| 44 | + """ |
| 45 | + CUDA process that can be locked, checkpointed, restored, and unlocked. |
| 46 | +
|
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + pid : int |
| 50 | + Process ID of the CUDA process. |
| 51 | + """ |
| 52 | + |
| 53 | + __slots__ = ("_pid",) |
| 54 | + |
| 55 | + def __init__(self, pid: int): |
| 56 | + self._pid = _check_pid(pid) |
| 57 | + |
| 58 | + @property |
| 59 | + def pid(self) -> int: |
| 60 | + """ |
| 61 | + Process ID of the CUDA process. |
| 62 | + """ |
| 63 | + return self._pid |
| 64 | + |
| 65 | + @property |
| 66 | + def state(self) -> _ProcessStateT: |
| 67 | + """ |
| 68 | + CUDA checkpoint state for this process. |
| 69 | + """ |
| 70 | + driver = _get_driver() |
| 71 | + state = _call_driver(driver, driver.cuCheckpointProcessGetState, self._pid) |
| 72 | + state_names = _get_process_state_names(driver) |
| 73 | + try: |
| 74 | + return state_names[state] |
| 75 | + except KeyError as e: |
| 76 | + state_value = int(state) |
| 77 | + raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e |
| 78 | + |
| 79 | + @property |
| 80 | + def restore_thread_id(self) -> int: |
| 81 | + """ |
| 82 | + CUDA restore thread ID for this process. |
| 83 | + """ |
| 84 | + driver = _get_driver() |
| 85 | + return _call_driver(driver, driver.cuCheckpointProcessGetRestoreThreadId, self._pid) |
| 86 | + |
| 87 | + def lock(self, timeout_ms: int = 0) -> None: |
| 88 | + """ |
| 89 | + Lock this process, blocking further CUDA API calls. |
| 90 | +
|
| 91 | + Parameters |
| 92 | + ---------- |
| 93 | + timeout_ms : int, optional |
| 94 | + Timeout in milliseconds. A value of 0 indicates no timeout. |
| 95 | + """ |
| 96 | + driver = _get_driver() |
| 97 | + args = driver.CUcheckpointLockArgs() |
| 98 | + args.timeoutMs = _check_timeout_ms(timeout_ms) |
| 99 | + _call_driver(driver, driver.cuCheckpointProcessLock, self._pid, args) |
| 100 | + |
| 101 | + def checkpoint(self) -> None: |
| 102 | + """ |
| 103 | + Checkpoint the GPU memory contents of this locked process. |
| 104 | + """ |
| 105 | + driver = _get_driver() |
| 106 | + _call_driver(driver, driver.cuCheckpointProcessCheckpoint, self._pid, None) |
| 107 | + |
| 108 | + def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: |
| 109 | + """ |
| 110 | + Restore this checkpointed process. |
| 111 | +
|
| 112 | + Parameters |
| 113 | + ---------- |
| 114 | + gpu_mapping : mapping, optional |
| 115 | + GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID |
| 116 | + to restore onto. For migration workflows, provide mappings for |
| 117 | + every GPU visible to the kernel-mode driver. User-space masking |
| 118 | + such as ``CUDA_VISIBLE_DEVICES`` does not reduce this mapping |
| 119 | + requirement. |
| 120 | + """ |
| 121 | + driver = _get_driver() |
| 122 | + args = _make_restore_args(driver, gpu_mapping) |
| 123 | + _call_driver(driver, driver.cuCheckpointProcessRestore, self._pid, args) |
| 124 | + |
| 125 | + def unlock(self) -> None: |
| 126 | + """ |
| 127 | + Unlock this locked process so it can resume CUDA API calls. |
| 128 | + """ |
| 129 | + driver = _get_driver() |
| 130 | + _call_driver(driver, driver.cuCheckpointProcessUnlock, self._pid, None) |
| 131 | + |
| 132 | + |
| 133 | +def _get_driver(): |
| 134 | + global _driver_capability_checked |
| 135 | + if _driver_capability_checked: |
| 136 | + return _driver |
| 137 | + |
| 138 | + binding_ver = _binding_version() |
| 139 | + if not _binding_version_supports_checkpoint(binding_ver): |
| 140 | + raise RuntimeError( |
| 141 | + "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. " |
| 142 | + f"Found cuda.bindings {'.'.join(str(part) for part in binding_ver[:3])}." |
| 143 | + ) |
| 144 | + |
| 145 | + missing = [name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name)] |
| 146 | + if missing: |
| 147 | + raise RuntimeError( |
| 148 | + f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}" |
| 149 | + ) |
| 150 | + |
| 151 | + driver_ver = _driver_version() |
| 152 | + if driver_ver < _REQUIRED_DRIVER_VERSION: |
| 153 | + raise RuntimeError( |
| 154 | + "CUDA checkpointing is not supported by the installed NVIDIA driver. " |
| 155 | + "Upgrade to a driver version with CUDA checkpoint API support." |
| 156 | + ) |
| 157 | + |
| 158 | + _driver_capability_checked = True |
| 159 | + return _driver |
| 160 | + |
| 161 | + |
| 162 | +def _binding_version_supports_checkpoint(version) -> bool: |
| 163 | + major, minor, patch = version[:3] |
| 164 | + return (major == 12 and (minor, patch) >= (8, 0)) or (major == 13 and (minor, patch) >= (0, 2)) or major > 13 |
| 165 | + |
| 166 | + |
| 167 | +def _get_process_state_names(driver) -> dict[_Any, _ProcessStateT]: |
| 168 | + return {getattr(driver.CUprocessState, attr): state_name for attr, state_name in _PROCESS_STATE_NAME_ATTRS} |
| 169 | + |
| 170 | + |
| 171 | +def _call_driver(driver, func, *args): |
| 172 | + try: |
| 173 | + result = func(*args) |
| 174 | + except RuntimeError as e: |
| 175 | + if "cuCheckpointProcess" in str(e) and "not found" in str(e): |
| 176 | + raise RuntimeError( |
| 177 | + "CUDA checkpointing is not supported by the installed NVIDIA driver. " |
| 178 | + "Upgrade to a driver version with CUDA checkpoint API support." |
| 179 | + ) from e |
| 180 | + raise |
| 181 | + |
| 182 | + err = result[0] |
| 183 | + not_supported_errors = ( |
| 184 | + getattr(driver.CUresult, "CUDA_ERROR_NOT_FOUND", None), |
| 185 | + getattr(driver.CUresult, "CUDA_ERROR_NOT_SUPPORTED", None), |
| 186 | + ) |
| 187 | + if err in not_supported_errors: |
| 188 | + raise RuntimeError( |
| 189 | + "CUDA checkpointing is not supported by the installed NVIDIA driver. " |
| 190 | + "Upgrade to a driver version with CUDA checkpoint API support." |
| 191 | + ) |
| 192 | + |
| 193 | + return _handle_cuda_return(result) |
| 194 | + |
| 195 | + |
| 196 | +def _check_pid(pid: int) -> int: |
| 197 | + if isinstance(pid, bool) or not isinstance(pid, int): |
| 198 | + raise TypeError("pid must be an int") |
| 199 | + if pid <= 0: |
| 200 | + raise ValueError("pid must be a positive int") |
| 201 | + return pid |
| 202 | + |
| 203 | + |
| 204 | +def _check_timeout_ms(timeout_ms: int) -> int: |
| 205 | + if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int): |
| 206 | + raise TypeError("timeout_ms must be an int") |
| 207 | + if timeout_ms < 0: |
| 208 | + raise ValueError("timeout_ms must be >= 0") |
| 209 | + return timeout_ms |
| 210 | + |
| 211 | + |
| 212 | +def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): |
| 213 | + if gpu_mapping is None: |
| 214 | + return None |
| 215 | + if not isinstance(gpu_mapping, _Mapping): |
| 216 | + raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") |
| 217 | + |
| 218 | + pairs = [] |
| 219 | + for old_uuid, new_uuid in gpu_mapping.items(): |
| 220 | + pair = driver.CUcheckpointGpuPair() |
| 221 | + buffers = [] |
| 222 | + pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers) |
| 223 | + pair.newUuid = _as_cuuuid(driver, new_uuid, buffers) |
| 224 | + pairs.append(pair) |
| 225 | + |
| 226 | + if not pairs: |
| 227 | + return None |
| 228 | + |
| 229 | + args = driver.CUcheckpointRestoreArgs() |
| 230 | + args.gpuPairs = pairs |
| 231 | + args.gpuPairsCount = len(pairs) |
| 232 | + return args |
| 233 | + |
| 234 | + |
| 235 | +def _as_cuuuid(driver, value, buffers): |
| 236 | + """Convert *value* to a ``CUuuid``. |
| 237 | +
|
| 238 | + Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in |
| 239 | + the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by |
| 240 | + :attr:`Device.uuid`. |
| 241 | + """ |
| 242 | + if isinstance(value, str): |
| 243 | + raw = bytes.fromhex(value.replace("-", "")) |
| 244 | + if len(raw) != 16: |
| 245 | + raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}") |
| 246 | + buf = _ctypes.create_string_buffer(raw, 16) |
| 247 | + buffers.append(buf) |
| 248 | + return driver.CUuuid(_ctypes.addressof(buf)) |
| 249 | + return value |
| 250 | + |
| 251 | + |
| 252 | +__all__ = [ |
| 253 | + "Process", |
| 254 | +] |
0 commit comments