Add CUDA process checkpointing helpers #1983
Merged: kkraus14 merged 13 commits into NVIDIA:main from kkraus14:kk/issue-1343-cuda-checkpointing on May 5, 2026.
Commits (13):
d8a2031 Add CUDA process checkpointing helpers (kkraus14)
4992921 Address checkpoint review feedback (kkraus14)
5afd43c Rewrite checkpoint tests: replace mocks with real GPU tests (leofang)
f67a5e6 Accept Device.uuid strings in gpu_mapping; use cuda.core APIs in tests (leofang)
245e7a4 Apply pre-commit formatting fixes (leofang)
7c7f0e5 Restore original device in self_process fixture teardown (leofang)
8192df6 Address checkpoint review follow-ups (kkraus14)
fbb8037 Skip checkpoint lifecycle/migration tests in CI (leofang)
8f798f4 Isolate checkpoint lifecycle tests (kkraus14)
376acc7 Address checkpoint review follow-ups (kkraus14)
7a2e683 Fix checkpoint subprocess imports in CI (kkraus14)
8aeb8e8 Handle checkpoint migration no-op in CI (kkraus14)
b9fa2a1 Simplify checkpoint subprocess tests (kkraus14)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,254 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import ctypes as _ctypes | ||
| from collections.abc import Mapping as _Mapping | ||
| from typing import Any as _Any | ||
|
|
||
| from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return | ||
| from cuda.core._utils.version import binding_version as _binding_version | ||
| from cuda.core._utils.version import driver_version as _driver_version | ||
| from cuda.core.typing import ProcessStateT as _ProcessStateT | ||
|
|
||
| try: | ||
| from cuda.bindings import driver as _driver | ||
| except ImportError: | ||
| from cuda import cuda as _driver | ||
|
|
||
|
|
||
| _PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateT], ...] = ( | ||
| ("CU_PROCESS_STATE_RUNNING", "running"), | ||
| ("CU_PROCESS_STATE_LOCKED", "locked"), | ||
| ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"), | ||
| ("CU_PROCESS_STATE_FAILED", "failed"), | ||
| ) | ||
|
|
||
| _REQUIRED_BINDING_ATTRS = ( | ||
| "cuCheckpointProcessCheckpoint", | ||
| "cuCheckpointProcessGetRestoreThreadId", | ||
| "cuCheckpointProcessGetState", | ||
| "cuCheckpointProcessLock", | ||
| "cuCheckpointProcessRestore", | ||
| "cuCheckpointProcessUnlock", | ||
| "CUcheckpointGpuPair", | ||
| "CUcheckpointLockArgs", | ||
| "CUprocessState", | ||
| "CUcheckpointRestoreArgs", | ||
| ) | ||
| _REQUIRED_DRIVER_VERSION = (12, 8, 0) | ||
| _driver_capability_checked = False | ||
|
|
||
|
|
||
| class Process: | ||
| """ | ||
| CUDA process that can be locked, checkpointed, restored, and unlocked. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pid : int | ||
| Process ID of the CUDA process. | ||
| """ | ||
|
|
||
| __slots__ = ("_pid",) | ||
|
|
||
| def __init__(self, pid: int): | ||
| self._pid = _check_pid(pid) | ||
|
|
||
| @property | ||
| def pid(self) -> int: | ||
| """ | ||
| Process ID of the CUDA process. | ||
| """ | ||
| return self._pid | ||
|
|
||
| @property | ||
| def state(self) -> _ProcessStateT: | ||
| """ | ||
| CUDA checkpoint state for this process. | ||
| """ | ||
| driver = _get_driver() | ||
| state = _call_driver(driver, driver.cuCheckpointProcessGetState, self._pid) | ||
| state_names = _get_process_state_names(driver) | ||
| try: | ||
| return state_names[state] | ||
| except KeyError as e: | ||
| state_value = int(state) | ||
| raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e | ||
|
|
||
| @property | ||
| def restore_thread_id(self) -> int: | ||
| """ | ||
| CUDA restore thread ID for this process. | ||
| """ | ||
| driver = _get_driver() | ||
| return _call_driver(driver, driver.cuCheckpointProcessGetRestoreThreadId, self._pid) | ||
|
|
||
| def lock(self, timeout_ms: int = 0) -> None: | ||
| """ | ||
| Lock this process, blocking further CUDA API calls. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| timeout_ms : int, optional | ||
| Timeout in milliseconds. A value of 0 indicates no timeout. | ||
| """ | ||
| driver = _get_driver() | ||
| args = driver.CUcheckpointLockArgs() | ||
| args.timeoutMs = _check_timeout_ms(timeout_ms) | ||
| _call_driver(driver, driver.cuCheckpointProcessLock, self._pid, args) | ||
|
|
||
| def checkpoint(self) -> None: | ||
| """ | ||
| Checkpoint the GPU memory contents of this locked process. | ||
| """ | ||
| driver = _get_driver() | ||
| _call_driver(driver, driver.cuCheckpointProcessCheckpoint, self._pid, None) | ||
|
|
||
| def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: | ||
| """ | ||
| Restore this checkpointed process. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| gpu_mapping : mapping, optional | ||
| GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID | ||
| to restore onto. For migration workflows, provide mappings for | ||
| every GPU visible to the kernel-mode driver. User-space masking | ||
| such as ``CUDA_VISIBLE_DEVICES`` does not reduce this mapping | ||
| requirement. | ||
| """ | ||
| driver = _get_driver() | ||
| args = _make_restore_args(driver, gpu_mapping) | ||
| _call_driver(driver, driver.cuCheckpointProcessRestore, self._pid, args) | ||
|
|
||
| def unlock(self) -> None: | ||
| """ | ||
| Unlock this locked process so it can resume CUDA API calls. | ||
| """ | ||
| driver = _get_driver() | ||
| _call_driver(driver, driver.cuCheckpointProcessUnlock, self._pid, None) | ||
|
|
||
|
|
||
| def _get_driver(): | ||
| global _driver_capability_checked | ||
| if _driver_capability_checked: | ||
| return _driver | ||
|
|
||
| binding_ver = _binding_version() | ||
| if not _binding_version_supports_checkpoint(binding_ver): | ||
| raise RuntimeError( | ||
| "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. " | ||
| f"Found cuda.bindings {'.'.join(str(part) for part in binding_ver[:3])}." | ||
| ) | ||
|
|
||
| missing = [name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name)] | ||
| if missing: | ||
| raise RuntimeError( | ||
| f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}" | ||
| ) | ||
|
|
||
| driver_ver = _driver_version() | ||
| if driver_ver < _REQUIRED_DRIVER_VERSION: | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) | ||
|
|
||
| _driver_capability_checked = True | ||
| return _driver | ||
|
|
||
|
|
||
| def _binding_version_supports_checkpoint(version) -> bool: | ||
| major, minor, patch = version[:3] | ||
| return (major == 12 and (minor, patch) >= (8, 0)) or (major == 13 and (minor, patch) >= (0, 2)) or major > 13 | ||
|
|
||
|
|
||
| def _get_process_state_names(driver) -> dict[_Any, _ProcessStateT]: | ||
| return {getattr(driver.CUprocessState, attr): state_name for attr, state_name in _PROCESS_STATE_NAME_ATTRS} | ||
|
|
||
|
|
||
| def _call_driver(driver, func, *args): | ||
| try: | ||
| result = func(*args) | ||
| except RuntimeError as e: | ||
| if "cuCheckpointProcess" in str(e) and "not found" in str(e): | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) from e | ||
| raise | ||
|
|
||
| err = result[0] | ||
| not_supported_errors = ( | ||
| getattr(driver.CUresult, "CUDA_ERROR_NOT_FOUND", None), | ||
| getattr(driver.CUresult, "CUDA_ERROR_NOT_SUPPORTED", None), | ||
| ) | ||
| if err in not_supported_errors: | ||
| raise RuntimeError( | ||
| "CUDA checkpointing is not supported by the installed NVIDIA driver. " | ||
| "Upgrade to a driver version with CUDA checkpoint API support." | ||
| ) | ||
|
|
||
| return _handle_cuda_return(result) | ||
|
|
||
|
|
||
| def _check_pid(pid: int) -> int: | ||
| if isinstance(pid, bool) or not isinstance(pid, int): | ||
| raise TypeError("pid must be an int") | ||
| if pid <= 0: | ||
| raise ValueError("pid must be a positive int") | ||
| return pid | ||
|
|
||
|
|
||
| def _check_timeout_ms(timeout_ms: int) -> int: | ||
| if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int): | ||
| raise TypeError("timeout_ms must be an int") | ||
| if timeout_ms < 0: | ||
| raise ValueError("timeout_ms must be >= 0") | ||
| return timeout_ms | ||
|
|
||
|
|
||
| def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): | ||
| if gpu_mapping is None: | ||
| return None | ||
| if not isinstance(gpu_mapping, _Mapping): | ||
| raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") | ||
|
|
||
| pairs = [] | ||
| for old_uuid, new_uuid in gpu_mapping.items(): | ||
| pair = driver.CUcheckpointGpuPair() | ||
| buffers = [] | ||
| pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers) | ||
| pair.newUuid = _as_cuuuid(driver, new_uuid, buffers) | ||
| pairs.append(pair) | ||
|
|
||
| if not pairs: | ||
| return None | ||
|
|
||
| args = driver.CUcheckpointRestoreArgs() | ||
| args.gpuPairs = pairs | ||
| args.gpuPairsCount = len(pairs) | ||
| return args | ||
|
|
||
|
|
||
| def _as_cuuuid(driver, value, buffers): | ||
| """Convert *value* to a ``CUuuid``. | ||
|
|
||
| Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in | ||
| the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by | ||
| :attr:`Device.uuid`. | ||
| """ | ||
| if isinstance(value, str): | ||
| raw = bytes.fromhex(value.replace("-", "")) | ||
| if len(raw) != 16: | ||
| raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}") | ||
| buf = _ctypes.create_string_buffer(raw, 16) | ||
| buffers.append(buf) | ||
| return driver.CUuuid(_ctypes.addressof(buf)) | ||
| return value | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Process", | ||
| ] | ||
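
The four `Process` methods in the diff are meant to be called in order: `lock`, `checkpoint`, `restore`, `unlock`. Below is a minimal sketch of that sequence. It relies only on the method names and the `state` strings that appear in the diff. The helper `checkpoint_roundtrip` and the stand-in `FakeProcess` are hypothetical; the stand-in exists so the sketch runs without a GPU, and the state transitions it models (running, locked, checkpointed, locked, running) are an assumption about how the driver behaves, not something this diff states.

```python
def checkpoint_roundtrip(process, timeout_ms=5000, gpu_mapping=None):
    """Lock, checkpoint, restore, and unlock `process`.

    Returns the list of states observed before and after each step.
    Error handling is deliberately omitted to keep the sketch short.
    """
    seen = [process.state]
    process.lock(timeout_ms=timeout_ms)
    seen.append(process.state)
    process.checkpoint()
    seen.append(process.state)
    process.restore(gpu_mapping=gpu_mapping)
    seen.append(process.state)
    process.unlock()
    seen.append(process.state)
    return seen


class FakeProcess:
    """Hypothetical in-memory stand-in for the Process class in the diff."""

    def __init__(self, pid):
        self.pid = pid
        self.state = "running"

    def _move(self, expected, new):
        # Reject calls made in the wrong state, then apply the transition.
        if self.state != expected:
            raise RuntimeError(f"expected state {expected!r}, got {self.state!r}")
        self.state = new

    def lock(self, timeout_ms=0):
        self._move("running", "locked")

    def checkpoint(self):
        self._move("locked", "checkpointed")

    def restore(self, gpu_mapping=None):
        self._move("checkpointed", "locked")

    def unlock(self):
        self._move("locked", "running")


if __name__ == "__main__":
    print(checkpoint_roundtrip(FakeProcess(1234)))
```

With a real `Process(pid)` in place of `FakeProcess`, the same helper would drive the driver calls shown in the diff; a non-zero `timeout_ms` is used here so that `lock` cannot wait indefinitely.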
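
The string branch of `_as_cuuuid` strips hyphens, decodes the remaining hex digits, and requires exactly 16 bytes. The sketch below isolates that parsing step so it can be tried without the driver bindings. The function name `uuid_string_to_bytes` is hypothetical, and the sample UUID is made up for illustration.

```python
def uuid_string_to_bytes(value: str) -> bytes:
    """Decode a hyphenated or plain hex GPU UUID string into 16 raw bytes."""
    # bytes.fromhex raises ValueError on non-hex characters or odd length.
    raw = bytes.fromhex(value.replace("-", ""))
    if len(raw) != 16:
        raise ValueError(
            f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
        )
    return raw


if __name__ == "__main__":
    raw = uuid_string_to_bytes("01234567-89ab-cdef-0123-456789abcdef")
    print(len(raw), raw.hex())
```

Both the hyphenated and the plain 32-character forms decode to the same bytes; anything shorter, longer, or non-hexadecimal raises `ValueError` before a `CUuuid` would be built.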
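
The version gate in `_binding_version_supports_checkpoint` compares `(minor, patch)` tuples lexicographically within each major release. The sketch below is a standalone copy of that predicate under the hypothetical name `supports_checkpoint`, followed by a few versions chosen for illustration so the accepted ranges can be checked by hand.

```python
def supports_checkpoint(version) -> bool:
    """Return True for binding versions 12.8.0+, 13.0.2+, or any major above 13."""
    major, minor, patch = version[:3]
    return (
        (major == 12 and (minor, patch) >= (8, 0))
        or (major == 13 and (minor, patch) >= (0, 2))
        or major > 13
    )


if __name__ == "__main__":
    # Versions picked to sit on either side of each boundary.
    for ver in [(11, 8, 0), (12, 6, 2), (12, 8, 0), (13, 0, 0), (13, 0, 2), (13, 1, 0), (14, 0, 0)]:
        print(ver, supports_checkpoint(ver))
```

Note that 13.0.0 and 13.0.1 are rejected even though 12.8.0 is accepted, because the 13.x threshold is set separately at 13.0.2.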