Skip to content

Commit 6630d8e

Browse files
committed
Add test_linker_warnings.py
As generated by ChatGPT 5, with minor manual tweaks.
1 parent 4724769 commit 6630d8e

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import builtins
5+
import sys
6+
import types
7+
import warnings
8+
9+
import pytest
10+
from cuda.core.experimental import _linker as linker
11+
12+
13+
@pytest.fixture(autouse=True)
14+
def fresh_env(monkeypatch):
15+
"""
16+
Put the module under test into a predictable state:
17+
- neutralize global caches,
18+
- provide a minimal 'driver' object,
19+
- make 'handle_return' a passthrough,
20+
- stabilize platform for deterministic messages.
21+
"""
22+
23+
class FakeDriver:
24+
# Something realistic but not used by the logic under test
25+
def cuDriverGetVersion(self):
26+
return 12090
27+
28+
monkeypatch.setattr(linker, "_driver", None, raising=False)
29+
monkeypatch.setattr(linker, "_nvjitlink", None, raising=False)
30+
monkeypatch.setattr(linker, "_driver_ver", None, raising=False)
31+
monkeypatch.setattr(linker, "driver", FakeDriver(), raising=False)
32+
monkeypatch.setattr(linker, "handle_return", lambda x: x, raising=False)
33+
34+
# Normalize platform-dependent wording (if any)
35+
monkeypatch.setattr(sys, "platform", "linux", raising=False)
36+
37+
# Ensure a clean sys.modules slate for our synthetic packages
38+
for modname in list(sys.modules):
39+
if modname.startswith("cuda.bindings.nvjitlink") or modname == "cuda.bindings" or modname == "cuda":
40+
sys.modules.pop(modname, None)
41+
42+
yield
43+
44+
# Cleanup any stubs we added
45+
for modname in list(sys.modules):
46+
if modname.startswith("cuda.bindings.nvjitlink") or modname == "cuda.bindings" or modname == "cuda":
47+
sys.modules.pop(modname, None)
48+
49+
50+
def _install_public_nvjitlink_stub():
51+
"""
52+
Provide enough structure so that:
53+
- `from cuda.bindings import nvjitlink` succeeds
54+
- `from cuda.bindings._internal import nvjitlink as inner_nvjitlink` succeeds
55+
We don't care about the contents of inner_nvjitlink because tests stub
56+
`_nvjitlink_has_version_symbol()` directly.
57+
"""
58+
# Make 'cuda' a package
59+
cuda_pkg = sys.modules.get("cuda") or types.ModuleType("cuda")
60+
cuda_pkg.__path__ = [] # mark as package
61+
sys.modules["cuda"] = cuda_pkg
62+
63+
# Make 'cuda.bindings' a package
64+
bindings_pkg = sys.modules.get("cuda.bindings") or types.ModuleType("cuda.bindings")
65+
bindings_pkg.__path__ = [] # mark as package
66+
sys.modules["cuda.bindings"] = bindings_pkg
67+
68+
# Public-facing nvjitlink module
69+
sys.modules["cuda.bindings.nvjitlink"] = types.ModuleType("cuda.bindings.nvjitlink")
70+
71+
# Make 'cuda.bindings._internal' a package
72+
internal_pkg = sys.modules.get("cuda.bindings._internal") or types.ModuleType("cuda.bindings._internal")
73+
internal_pkg.__path__ = [] # mark as package
74+
sys.modules["cuda.bindings._internal"] = internal_pkg
75+
76+
# Dummy inner nvjitlink module (imported but not actually used by tests)
77+
inner_nvjitlink_mod = types.ModuleType("cuda.bindings._internal.nvjitlink")
78+
# (optional) a no-op placeholder so attributes exist if accessed accidentally
79+
inner_nvjitlink_mod._inspect_function_pointer = lambda *_args, **_kw: True
80+
sys.modules["cuda.bindings._internal.nvjitlink"] = inner_nvjitlink_mod
81+
82+
83+
def _collect_runtime_warnings(record):
84+
return [w for w in record if issubclass(w.category, RuntimeWarning)]
85+
86+
87+
def _block_nvjitlink_import(monkeypatch):
88+
"""Force 'from cuda.bindings import nvjitlink' to fail, regardless of sys.path."""
89+
real_import = builtins.__import__
90+
91+
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
92+
# Handle both 'from cuda.bindings import nvjitlink' and direct submodule imports
93+
target = "cuda.bindings.nvjitlink"
94+
if name == target or (name == "cuda.bindings" and fromlist and "nvjitlink" in fromlist):
95+
raise ModuleNotFoundError(target)
96+
return real_import(name, globals, locals, fromlist, level)
97+
98+
monkeypatch.setattr(builtins, "__import__", fake_import)
99+
100+
101+
def test_warns_when_python_nvjitlink_missing(monkeypatch):
102+
"""
103+
Case 1: 'from cuda.bindings import nvjitlink' fails -> bindings missing.
104+
Expect a RuntimeWarning stating that cuda.bindings.nvjitlink is not available,
105+
and that we fall back to cuLink* (function returns True).
106+
"""
107+
# Ensure nothing is preloaded and actively block future imports.
108+
sys.modules.pop("cuda.bindings.nvjitlink", None)
109+
sys.modules.pop("cuda.bindings", None)
110+
sys.modules.pop("cuda", None)
111+
_block_nvjitlink_import(monkeypatch)
112+
113+
with warnings.catch_warnings(record=True) as rec:
114+
warnings.simplefilter("always")
115+
ret = linker._decide_nvjitlink_or_driver()
116+
117+
assert ret is True
118+
warns = _collect_runtime_warnings(rec)
119+
assert len(warns) == 1
120+
msg = str(warns[0].message)
121+
assert "cuda.bindings.nvjitlink is not available" in msg
122+
assert "the culink APIs will be used instead" in msg
123+
assert "recent version of cuda-bindings." in msg
124+
125+
126+
def test_warns_when_nvjitlink_symbol_probe_raises(monkeypatch):
127+
"""
128+
Case 2: Bindings present, but symbol probe raises RuntimeError -> 'not available'.
129+
Expect a RuntimeWarning mentioning 'libnvJitLink.so* is not available' and fallback.
130+
"""
131+
_install_public_nvjitlink_stub()
132+
133+
def raising_probe(_inner):
134+
raise RuntimeError("simulated: nvJitLink symbol unavailable")
135+
136+
monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", raising_probe, raising=True)
137+
138+
with warnings.catch_warnings(record=True) as rec:
139+
warnings.simplefilter("always")
140+
ret = linker._decide_nvjitlink_or_driver()
141+
142+
assert ret is True
143+
warns = _collect_runtime_warnings(rec)
144+
assert len(warns) == 1
145+
msg = str(warns[0].message)
146+
assert "libnvJitLink.so* is not available" in msg
147+
assert "cuda.bindings.nvjitlink is not usable" in msg
148+
assert "the culink APIs will be used instead" in msg
149+
assert "recent version of nvJitLink." in msg
150+
151+
152+
def test_warns_when_nvjitlink_too_old(monkeypatch):
153+
"""
154+
Case 3: Bindings present, probe returns False -> 'too old (<12.3)'.
155+
Expect a RuntimeWarning mentioning 'too old (<12.3)' and fallback.
156+
"""
157+
_install_public_nvjitlink_stub()
158+
monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", lambda _inner: False, raising=True)
159+
160+
with warnings.catch_warnings(record=True) as rec:
161+
warnings.simplefilter("always")
162+
ret = linker._decide_nvjitlink_or_driver()
163+
164+
assert ret is True
165+
warns = _collect_runtime_warnings(rec)
166+
assert len(warns) == 1
167+
msg = str(warns[0].message)
168+
assert "libnvJitLink.so* is too old (<12.3)" in msg
169+
assert "cuda.bindings.nvjitlink is not usable" in msg
170+
assert "the culink APIs will be used instead" in msg
171+
assert "recent version of nvJitLink." in msg
172+
173+
174+
def test_uses_nvjitlink_when_available_and_ok(monkeypatch):
175+
"""
176+
Sanity: Bindings present and probe returns True → no warning, use nvJitLink.
177+
"""
178+
_install_public_nvjitlink_stub()
179+
monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", lambda _inner: True, raising=True)
180+
181+
with warnings.catch_warnings(record=True) as rec:
182+
warnings.simplefilter("always")
183+
ret = linker._decide_nvjitlink_or_driver()
184+
185+
assert ret is False # do NOT fall back
186+
warns = _collect_runtime_warnings(rec)
187+
assert not warns

0 commit comments

Comments
 (0)