Skip to content

Commit 8d8f29b

Browse files
authored
Add CUDA version compatibility check (#1412)
* Add CUDA major version compatibility check Add warn_if_cuda_major_version_mismatch() to cuda-bindings that warns when cuda-bindings was compiled for a newer CUDA major version than the installed driver supports. Called by cuda.core on first Device access. * Move version check import to local scope Import warn_if_cuda_major_version_mismatch locally in Device.__new__ after cuInit, using try/except/else pattern instead of module-level import with lambda fallback. * Refactor Device.__new__ into helper functions Extract Device.__new__ logic into cdef helper functions: - Device_ensure_cuda_initialized(): cuInit + version check - Device_resolve_device_id(): resolve None to current device or 0 - Device_ensure_tls_devices(): create thread-local singletons Reduces Device.__new__ from ~60 lines to ~12 lines. Helpers placed after Device class following memory module pattern. * test: use monkeypatch to properly save/restore version check flag Replace setup_method/teardown_method with a pytest fixture that uses monkeypatch to properly save and restore the original value of _major_version_compatibility_checked after each test. Minor change to Cython cdef inline helper function signature. * Add environment variables documentation for cuda.core Document runtime environment variables that affect cuda.core behavior: - CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM - CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING Include note linking to cuda-bindings environment variables documentation.
1 parent 13de2c2 commit 8d8f29b

7 files changed

Lines changed: 248 additions & 44 deletions

File tree

cuda_bindings/cuda/bindings/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Callable
44

55
from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
6+
from ._version_check import warn_if_cuda_major_version_mismatch
67

78
_handle_getters: dict[type, Callable[[Any], int]] = {}
89

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import os
5+
import threading
6+
import warnings
7+
8+
# Track whether we've already checked major version compatibility
9+
_major_version_compatibility_checked = False
10+
_lock = threading.Lock()
11+
12+
13+
def warn_if_cuda_major_version_mismatch():
14+
"""Warn if the CUDA driver major version is older than cuda-bindings compile-time version.
15+
16+
This function compares the CUDA major version that cuda-bindings was compiled
17+
against with the CUDA major version supported by the installed driver. If the
18+
compile-time major version is greater than the driver's major version, a warning
19+
is issued.
20+
21+
The check runs only once per process. Subsequent calls are no-ops.
22+
23+
The warning can be suppressed by setting the environment variable
24+
``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1``.
25+
"""
26+
global _major_version_compatibility_checked
27+
if _major_version_compatibility_checked:
28+
return
29+
with _lock:
30+
if _major_version_compatibility_checked:
31+
return
32+
_major_version_compatibility_checked = True
33+
34+
# Allow users to suppress the warning
35+
if os.environ.get("CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING"):
36+
return
37+
38+
# Import here to avoid circular imports and allow lazy loading
39+
from cuda.bindings import driver
40+
41+
# Get compile-time CUDA version from cuda-bindings
42+
compile_version = driver.CUDA_VERSION # e.g., 13010
43+
compile_major = compile_version // 1000
44+
45+
# Get runtime driver version
46+
err, runtime_version = driver.cuDriverGetVersion()
47+
if err != driver.CUresult.CUDA_SUCCESS:
48+
raise RuntimeError(f"Failed to query CUDA driver version: {err}")
49+
50+
runtime_major = runtime_version // 1000
51+
52+
if compile_major > runtime_major:
53+
warnings.warn(
54+
f"cuda-bindings was built for CUDA major version {compile_major}, but the "
55+
f"NVIDIA driver only supports up to CUDA {runtime_major}. Some cuda-bindings "
56+
f"features may not work correctly. Consider updating your NVIDIA driver, "
57+
f"or using a cuda-bindings version built for CUDA {runtime_major}. "
58+
f"(Set CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING=1 to suppress this warning.)",
59+
UserWarning,
60+
stacklevel=3,
61+
)

cuda_bindings/docs/source/environment_variables.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Runtime Environment Variables
99

1010
- ``CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM`` : When set to 1, the default stream is the per-thread default stream. When set to 0, the default stream is the legacy default stream. This defaults to 0, for the legacy default stream. See `Stream Synchronization Behavior <https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html>`_ for an explanation of the legacy and per-thread default streams.
1111

12+
- ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING`` : When set to 1, suppresses warnings about CUDA major version mismatches between ``cuda-bindings`` and the installed driver.
13+
1214

1315
Build-Time Environment Variables
1416
--------------------------------
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import os
5+
import warnings
6+
from unittest import mock
7+
8+
import pytest
9+
from cuda.bindings import driver
10+
from cuda.bindings.utils import _version_check, warn_if_cuda_major_version_mismatch
11+
12+
13+
class TestVersionCompatibilityCheck:
14+
"""Tests for CUDA major version mismatch warning function."""
15+
16+
@pytest.fixture(autouse=True)
17+
def reset_version_check(self, monkeypatch):
18+
"""Reset the version compatibility check flag for each test, restoring after."""
19+
monkeypatch.setattr(_version_check, "_major_version_compatibility_checked", False)
20+
21+
def test_no_warning_when_driver_newer(self):
22+
"""No warning should be issued when driver version >= compile version."""
23+
# Mock compile version 12.9 and driver version 13.0
24+
with (
25+
mock.patch.object(driver, "CUDA_VERSION", 12090),
26+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 13000)),
27+
warnings.catch_warnings(record=True) as w,
28+
):
29+
warnings.simplefilter("always")
30+
warn_if_cuda_major_version_mismatch()
31+
assert len(w) == 0
32+
33+
def test_no_warning_when_same_major_version(self):
34+
"""No warning should be issued when major versions match."""
35+
# Mock compile version 12.9 and driver version 12.8
36+
with (
37+
mock.patch.object(driver, "CUDA_VERSION", 12090),
38+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
39+
warnings.catch_warnings(record=True) as w,
40+
):
41+
warnings.simplefilter("always")
42+
warn_if_cuda_major_version_mismatch()
43+
assert len(w) == 0
44+
45+
def test_warning_when_compile_major_newer(self):
46+
"""Warning should be issued when compile major version > driver major version."""
47+
# Mock compile version 13.0 and driver version 12.8
48+
with (
49+
mock.patch.object(driver, "CUDA_VERSION", 13000),
50+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
51+
warnings.catch_warnings(record=True) as w,
52+
):
53+
warnings.simplefilter("always")
54+
warn_if_cuda_major_version_mismatch()
55+
assert len(w) == 1
56+
assert issubclass(w[0].category, UserWarning)
57+
assert "cuda-bindings was built for CUDA major version 13" in str(w[0].message)
58+
assert "only supports up to CUDA 12" in str(w[0].message)
59+
60+
def test_warning_only_issued_once(self):
61+
"""Warning should only be issued once per process."""
62+
with (
63+
mock.patch.object(driver, "CUDA_VERSION", 13000),
64+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
65+
warnings.catch_warnings(record=True) as w,
66+
):
67+
warnings.simplefilter("always")
68+
warn_if_cuda_major_version_mismatch()
69+
warn_if_cuda_major_version_mismatch()
70+
warn_if_cuda_major_version_mismatch()
71+
# Only one warning despite multiple calls
72+
assert len(w) == 1
73+
74+
def test_warning_suppressed_by_env_var(self):
75+
"""Warning should be suppressed when CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING is set."""
76+
with (
77+
mock.patch.object(driver, "CUDA_VERSION", 13000),
78+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
79+
mock.patch.dict(os.environ, {"CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING": "1"}),
80+
warnings.catch_warnings(record=True) as w,
81+
):
82+
warnings.simplefilter("always")
83+
warn_if_cuda_major_version_mismatch()
84+
assert len(w) == 0
85+
86+
def test_error_when_driver_version_fails(self):
87+
"""Should raise RuntimeError if cuDriverGetVersion fails."""
88+
with (
89+
mock.patch.object(driver, "CUDA_VERSION", 13000),
90+
mock.patch.object(
91+
driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_ERROR_NOT_INITIALIZED, 0)
92+
),
93+
pytest.raises(RuntimeError, match="Failed to query CUDA driver version"),
94+
):
95+
warn_if_cuda_major_version_mismatch()

cuda_core/cuda/core/_device.pyx

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -958,53 +958,12 @@ class Device:
958958
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context")
959959

960960
def __new__(cls, device_id: Device | int | None = None):
961-
# Handle device_id argument.
962961
if isinstance(device_id, Device):
963962
return device_id
964-
else:
965-
device_id = getattr(device_id, 'device_id', device_id)
966-
967-
# Initialize CUDA.
968-
global _is_cuInit
969-
if _is_cuInit is False:
970-
with _lock, nogil:
971-
HANDLE_RETURN(cydriver.cuInit(0))
972-
_is_cuInit = True
973963

974-
# important: creating a Device instance does not initialize the GPU!
975-
cdef cydriver.CUdevice dev
976-
cdef cydriver.CUcontext ctx
977-
if device_id is None:
978-
with nogil:
979-
err = cydriver.cuCtxGetDevice(&dev)
980-
if err == cydriver.CUresult.CUDA_SUCCESS:
981-
device_id = int(dev)
982-
elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
983-
# No context is current - verify and default to device 0 (cudart behavior)
984-
assert cydriver.cuCtxGetCurrent(&ctx) == cydriver.CUresult.CUDA_SUCCESS and ctx == NULL
985-
device_id = 0
986-
else:
987-
HANDLE_RETURN(err)
988-
elif device_id < 0:
989-
raise ValueError(f"device_id must be >= 0, got {device_id}")
990-
991-
# ensure Device is singleton
992-
cdef int total
993-
try:
994-
devices = _tls.devices
995-
except AttributeError:
996-
with nogil:
997-
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
998-
devices = _tls.devices = []
999-
for i in range(total):
1000-
device = super().__new__(cls)
1001-
device._device_id = i
1002-
device._memory_resource = None
1003-
device._has_inited = False
1004-
device._properties = None
1005-
device._uuid = None
1006-
device._context = None
1007-
devices.append(device)
964+
Device_ensure_cuda_initialized()
965+
device_id = Device_resolve_device_id(device_id)
966+
devices = Device_ensure_tls_devices(cls)
1008967

1009968
try:
1010969
return devices[device_id]
@@ -1414,3 +1373,62 @@ class Device:
14141373
"""
14151374
self._check_context_initialized()
14161375
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
1376+
1377+
1378+
cdef inline int Device_ensure_cuda_initialized() except? -1:
1379+
"""Initialize CUDA driver and check version compatibility (once per process)."""
1380+
global _is_cuInit
1381+
if _is_cuInit is False:
1382+
with _lock, nogil:
1383+
HANDLE_RETURN(cydriver.cuInit(0))
1384+
_is_cuInit = True
1385+
try:
1386+
from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
1387+
except ImportError:
1388+
pass
1389+
else:
1390+
warn_if_cuda_major_version_mismatch()
1391+
return 0
1392+
1393+
1394+
cdef inline int Device_resolve_device_id(device_id) except? -1:
1395+
"""Resolve device_id, defaulting to current device or 0."""
1396+
cdef cydriver.CUdevice dev
1397+
cdef cydriver.CUcontext ctx
1398+
cdef cydriver.CUresult err
1399+
if device_id is None:
1400+
with nogil:
1401+
err = cydriver.cuCtxGetDevice(&dev)
1402+
if err == cydriver.CUresult.CUDA_SUCCESS:
1403+
return int(dev)
1404+
elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
1405+
with nogil:
1406+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
1407+
assert <void*>(ctx) == NULL
1408+
return 0 # cudart behavior
1409+
else:
1410+
HANDLE_RETURN(err)
1411+
elif device_id < 0:
1412+
raise ValueError(f"device_id must be >= 0, got {device_id}")
1413+
return device_id
1414+
1415+
1416+
cdef inline list Device_ensure_tls_devices(cls):
1417+
"""Ensure thread-local Device singletons exist, creating if needed."""
1418+
cdef int total
1419+
try:
1420+
return _tls.devices
1421+
except AttributeError:
1422+
with nogil:
1423+
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
1424+
devices = _tls.devices = []
1425+
for dev_id in range(total):
1426+
device = super(Device, cls).__new__(cls)
1427+
device._device_id = dev_id
1428+
device._memory_resource = None
1429+
device._has_inited = False
1430+
device._properties = None
1431+
device._uuid = None
1432+
device._context = None
1433+
devices.append(device)
1434+
return devices
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
.. SPDX-License-Identifier: Apache-2.0
3+
4+
Environment Variables
5+
=====================
6+
7+
.. note::
8+
9+
The ``cuda-bindings`` runtime environment variables also affect ``cuda.core``.
10+
See the `cuda-bindings environment variables documentation
11+
<https://nvidia.github.io/cuda-python/cuda-bindings/latest/environment_variables.html>`_.
12+
13+
Runtime Environment Variables
14+
-----------------------------
15+
16+
- ``CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM`` : When set to 1, the default
17+
stream is the per-thread default stream. When set to 0, the default stream
18+
is the legacy default stream. This defaults to 0, for the legacy default
19+
stream. See `Stream Synchronization Behavior
20+
<https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html>`_
21+
for an explanation of the legacy and per-thread default streams.
22+
23+
- ``CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING`` : When set to 1, suppresses
24+
warnings about CUDA major version mismatches between ``cuda-bindings`` and
25+
the installed driver. This warning occurs when ``cuda-bindings`` was built
26+
for a newer CUDA major version than the installed driver supports.

cuda_core/docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Welcome to the documentation for ``cuda.core``.
1414
install
1515
interoperability
1616
api
17+
environment_variables
1718
contribute
1819

1920
.. toctree::

0 commit comments

Comments
 (0)