Skip to content

Commit ba1a57e

Browse files
rwgkcursoragent
andcommitted
pathfinder: simplify NVML init/shutdown handling and document ref-count contract
Collapse the duplicated ``nvmlShutdown()`` calls in ``_query_driver_release_version_text`` into a single ``try/finally`` so the cleanup always runs in one place. The asymmetric error-precedence rule is preserved via ``sys.exc_info()[1]``: when both the NVML body and shutdown fail, the body's error wins (Python keeps the shutdown error on ``__context__`` for debugging); when only shutdown fails, the shutdown error surfaces. Add comments above the matched ``nvmlInit_v2()`` / ``nvmlShutdown()`` pair noting that NVML's init/shutdown is reference-counted, so this balanced pair is safe even when the caller has already initialized NVML elsewhere in the process. Pre-empts a question raised in review on PR #2000. Add two focused tests filling out the cleanup matrix: - ``test_query_driver_release_version_text_raises_when_only_shutdown_fails`` asserts a non-zero shutdown status surfaces when the body succeeded. - ``test_query_driver_release_version_text_body_error_wins_when_both_fail`` locks in the body-error-wins precedence when both calls fail. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent c30a373 commit ba1a57e

2 files changed

Lines changed: 42 additions & 8 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ctypes
77
import functools
88
import re
9+
import sys
910
from collections.abc import Callable
1011
from dataclasses import dataclass
1112
from typing import cast
@@ -140,6 +141,9 @@ def _query_driver_release_version_text() -> str:
140141
nvml_shutdown.argtypes = []
141142
nvml_shutdown.restype = ctypes.c_int
142143

144+
# NVML's init/shutdown pair is reference-counted (see "Initialization and
145+
# Cleanup" in the NVML API docs), so this balanced pair is safe even when
146+
# the caller has already initialized NVML elsewhere in the process.
143147
init_status = nvml_init_v2()
144148
if init_status != _NVML_SUCCESS:
145149
raise RuntimeError(f"Failed to initialize NVML via nvmlInit_v2() (status={init_status}).")
@@ -152,13 +156,13 @@ def _query_driver_release_version_text() -> str:
152156
f"Failed to query driver release version via nvmlSystemGetDriverVersion() (status={status})."
153157
)
154158
release_version = version_buffer.value.decode()
155-
except BaseException as exc:
159+
finally:
160+
# Balance the init_v2() above unconditionally. If the body already
161+
# raised, let that error win; a non-zero shutdown status here would
162+
# only mask the more useful root cause (Python keeps it on
163+
# ``__context__`` for debugging). ``sys.exc_info()[1]`` is the
164+
# currently-propagating exception inside the finally, or None.
156165
shutdown_status = nvml_shutdown()
157-
if shutdown_status != _NVML_SUCCESS:
158-
raise RuntimeError(f"Failed to shut down NVML via nvmlShutdown() (status={shutdown_status}).") from exc
159-
raise
160-
161-
shutdown_status = nvml_shutdown()
162-
if shutdown_status != _NVML_SUCCESS:
163-
raise RuntimeError(f"Failed to shut down NVML via nvmlShutdown() (status={shutdown_status}).")
166+
if shutdown_status != _NVML_SUCCESS and sys.exc_info()[1] is None:
167+
raise RuntimeError(f"Failed to shut down NVML via nvmlShutdown() (status={shutdown_status}).")
164168
return release_version

cuda_pathfinder/tests/test_utils_driver_info.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,36 @@ def test_query_driver_release_version_text_raises_when_nvml_call_fails(monkeypat
155155
assert fake_nvml_lib.shutdown_calls == 1
156156

157157

158+
def test_query_driver_release_version_text_raises_when_only_shutdown_fails(monkeypatch):
159+
fake_nvml_lib = _FakeNvmlLib(shutdown_statuses=(2,))
160+
161+
monkeypatch.setattr(
162+
driver_info,
163+
"_load_nvidia_dynamic_lib",
164+
lambda _libname: _loaded_nvml("/usr/lib/libnvidia-ml.so.1"),
165+
)
166+
monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: fake_nvml_lib)
167+
168+
with pytest.raises(RuntimeError, match=r"nvmlShutdown\(\) \(status=2\)"):
169+
driver_info._query_driver_release_version_text()
170+
assert fake_nvml_lib.shutdown_calls == 1
171+
172+
173+
def test_query_driver_release_version_text_body_error_wins_when_both_fail(monkeypatch):
174+
fake_nvml_lib = _FakeNvmlLib(query_status=1, shutdown_statuses=(2,))
175+
176+
monkeypatch.setattr(
177+
driver_info,
178+
"_load_nvidia_dynamic_lib",
179+
lambda _libname: _loaded_nvml("/usr/lib/libnvidia-ml.so.1"),
180+
)
181+
monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: fake_nvml_lib)
182+
183+
with pytest.raises(RuntimeError, match=r"nvmlSystemGetDriverVersion\(\) \(status=1\)"):
184+
driver_info._query_driver_release_version_text()
185+
assert fake_nvml_lib.shutdown_calls == 1
186+
187+
158188
def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
159189
fake_driver_lib = _FakeDriverLib(status=0, version=12080)
160190
loaded_paths: list[str] = []

0 commit comments

Comments
 (0)