|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE |
| 3 | + |
| 4 | +import builtins |
| 5 | +import sys |
| 6 | +import types |
| 7 | +import warnings |
| 8 | + |
| 9 | +import pytest |
| 10 | +from cuda.core.experimental import _linker as linker |
| 11 | + |
| 12 | + |
| 13 | +@pytest.fixture(autouse=True) |
| 14 | +def fresh_env(monkeypatch): |
| 15 | + """ |
| 16 | + Put the module under test into a predictable state: |
| 17 | + - neutralize global caches, |
| 18 | + - provide a minimal 'driver' object, |
| 19 | + - make 'handle_return' a passthrough, |
| 20 | + - stabilize platform for deterministic messages. |
| 21 | + """ |
| 22 | + |
| 23 | + class FakeDriver: |
| 24 | + # Something realistic but not used by the logic under test |
| 25 | + def cuDriverGetVersion(self): |
| 26 | + return 12090 |
| 27 | + |
| 28 | + monkeypatch.setattr(linker, "_driver", None, raising=False) |
| 29 | + monkeypatch.setattr(linker, "_nvjitlink", None, raising=False) |
| 30 | + monkeypatch.setattr(linker, "_driver_ver", None, raising=False) |
| 31 | + monkeypatch.setattr(linker, "driver", FakeDriver(), raising=False) |
| 32 | + monkeypatch.setattr(linker, "handle_return", lambda x: x, raising=False) |
| 33 | + |
| 34 | + # Normalize platform-dependent wording (if any) |
| 35 | + monkeypatch.setattr(sys, "platform", "linux", raising=False) |
| 36 | + |
| 37 | + # Ensure a clean sys.modules slate for our synthetic packages |
| 38 | + for modname in list(sys.modules): |
| 39 | + if modname.startswith("cuda.bindings.nvjitlink") or modname == "cuda.bindings" or modname == "cuda": |
| 40 | + sys.modules.pop(modname, None) |
| 41 | + |
| 42 | + yield |
| 43 | + |
| 44 | + # Cleanup any stubs we added |
| 45 | + for modname in list(sys.modules): |
| 46 | + if modname.startswith("cuda.bindings.nvjitlink") or modname == "cuda.bindings" or modname == "cuda": |
| 47 | + sys.modules.pop(modname, None) |
| 48 | + |
| 49 | + |
| 50 | +def _install_public_nvjitlink_stub(): |
| 51 | + """ |
| 52 | + Provide enough structure so that: |
| 53 | + - `from cuda.bindings import nvjitlink` succeeds |
| 54 | + - `from cuda.bindings._internal import nvjitlink as inner_nvjitlink` succeeds |
| 55 | + We don't care about the contents of inner_nvjitlink because tests stub |
| 56 | + `_nvjitlink_has_version_symbol()` directly. |
| 57 | + """ |
| 58 | + # Make 'cuda' a package |
| 59 | + cuda_pkg = sys.modules.get("cuda") or types.ModuleType("cuda") |
| 60 | + cuda_pkg.__path__ = [] # mark as package |
| 61 | + sys.modules["cuda"] = cuda_pkg |
| 62 | + |
| 63 | + # Make 'cuda.bindings' a package |
| 64 | + bindings_pkg = sys.modules.get("cuda.bindings") or types.ModuleType("cuda.bindings") |
| 65 | + bindings_pkg.__path__ = [] # mark as package |
| 66 | + sys.modules["cuda.bindings"] = bindings_pkg |
| 67 | + |
| 68 | + # Public-facing nvjitlink module |
| 69 | + sys.modules["cuda.bindings.nvjitlink"] = types.ModuleType("cuda.bindings.nvjitlink") |
| 70 | + |
| 71 | + # Make 'cuda.bindings._internal' a package |
| 72 | + internal_pkg = sys.modules.get("cuda.bindings._internal") or types.ModuleType("cuda.bindings._internal") |
| 73 | + internal_pkg.__path__ = [] # mark as package |
| 74 | + sys.modules["cuda.bindings._internal"] = internal_pkg |
| 75 | + |
| 76 | + # Dummy inner nvjitlink module (imported but not actually used by tests) |
| 77 | + inner_nvjitlink_mod = types.ModuleType("cuda.bindings._internal.nvjitlink") |
| 78 | + # (optional) a no-op placeholder so attributes exist if accessed accidentally |
| 79 | + inner_nvjitlink_mod._inspect_function_pointer = lambda *_args, **_kw: True |
| 80 | + sys.modules["cuda.bindings._internal.nvjitlink"] = inner_nvjitlink_mod |
| 81 | + |
| 82 | + |
| 83 | +def _collect_runtime_warnings(record): |
| 84 | + return [w for w in record if issubclass(w.category, RuntimeWarning)] |
| 85 | + |
| 86 | + |
| 87 | +def _block_nvjitlink_import(monkeypatch): |
| 88 | + """Force 'from cuda.bindings import nvjitlink' to fail, regardless of sys.path.""" |
| 89 | + real_import = builtins.__import__ |
| 90 | + |
| 91 | + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): |
| 92 | + # Handle both 'from cuda.bindings import nvjitlink' and direct submodule imports |
| 93 | + target = "cuda.bindings.nvjitlink" |
| 94 | + if name == target or (name == "cuda.bindings" and fromlist and "nvjitlink" in fromlist): |
| 95 | + raise ModuleNotFoundError(target) |
| 96 | + return real_import(name, globals, locals, fromlist, level) |
| 97 | + |
| 98 | + monkeypatch.setattr(builtins, "__import__", fake_import) |
| 99 | + |
| 100 | + |
| 101 | +def test_warns_when_python_nvjitlink_missing(monkeypatch): |
| 102 | + """ |
| 103 | + Case 1: 'from cuda.bindings import nvjitlink' fails -> bindings missing. |
| 104 | + Expect a RuntimeWarning stating that cuda.bindings.nvjitlink is not available, |
| 105 | + and that we fall back to cuLink* (function returns True). |
| 106 | + """ |
| 107 | + # Ensure nothing is preloaded and actively block future imports. |
| 108 | + sys.modules.pop("cuda.bindings.nvjitlink", None) |
| 109 | + sys.modules.pop("cuda.bindings", None) |
| 110 | + sys.modules.pop("cuda", None) |
| 111 | + _block_nvjitlink_import(monkeypatch) |
| 112 | + |
| 113 | + with warnings.catch_warnings(record=True) as rec: |
| 114 | + warnings.simplefilter("always") |
| 115 | + ret = linker._decide_nvjitlink_or_driver() |
| 116 | + |
| 117 | + assert ret is True |
| 118 | + warns = _collect_runtime_warnings(rec) |
| 119 | + assert len(warns) == 1 |
| 120 | + msg = str(warns[0].message) |
| 121 | + assert "cuda.bindings.nvjitlink is not available" in msg |
| 122 | + assert "the culink APIs will be used instead" in msg |
| 123 | + assert "recent version of cuda-bindings." in msg |
| 124 | + |
| 125 | + |
| 126 | +def test_warns_when_nvjitlink_symbol_probe_raises(monkeypatch): |
| 127 | + """ |
| 128 | + Case 2: Bindings present, but symbol probe raises RuntimeError -> 'not available'. |
| 129 | + Expect a RuntimeWarning mentioning 'libnvJitLink.so* is not available' and fallback. |
| 130 | + """ |
| 131 | + _install_public_nvjitlink_stub() |
| 132 | + |
| 133 | + def raising_probe(_inner): |
| 134 | + raise RuntimeError("simulated: nvJitLink symbol unavailable") |
| 135 | + |
| 136 | + monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", raising_probe, raising=True) |
| 137 | + |
| 138 | + with warnings.catch_warnings(record=True) as rec: |
| 139 | + warnings.simplefilter("always") |
| 140 | + ret = linker._decide_nvjitlink_or_driver() |
| 141 | + |
| 142 | + assert ret is True |
| 143 | + warns = _collect_runtime_warnings(rec) |
| 144 | + assert len(warns) == 1 |
| 145 | + msg = str(warns[0].message) |
| 146 | + assert "libnvJitLink.so* is not available" in msg |
| 147 | + assert "cuda.bindings.nvjitlink is not usable" in msg |
| 148 | + assert "the culink APIs will be used instead" in msg |
| 149 | + assert "recent version of nvJitLink." in msg |
| 150 | + |
| 151 | + |
| 152 | +def test_warns_when_nvjitlink_too_old(monkeypatch): |
| 153 | + """ |
| 154 | + Case 3: Bindings present, probe returns False -> 'too old (<12.3)'. |
| 155 | + Expect a RuntimeWarning mentioning 'too old (<12.3)' and fallback. |
| 156 | + """ |
| 157 | + _install_public_nvjitlink_stub() |
| 158 | + monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", lambda _inner: False, raising=True) |
| 159 | + |
| 160 | + with warnings.catch_warnings(record=True) as rec: |
| 161 | + warnings.simplefilter("always") |
| 162 | + ret = linker._decide_nvjitlink_or_driver() |
| 163 | + |
| 164 | + assert ret is True |
| 165 | + warns = _collect_runtime_warnings(rec) |
| 166 | + assert len(warns) == 1 |
| 167 | + msg = str(warns[0].message) |
| 168 | + assert "libnvJitLink.so* is too old (<12.3)" in msg |
| 169 | + assert "cuda.bindings.nvjitlink is not usable" in msg |
| 170 | + assert "the culink APIs will be used instead" in msg |
| 171 | + assert "recent version of nvJitLink." in msg |
| 172 | + |
| 173 | + |
| 174 | +def test_uses_nvjitlink_when_available_and_ok(monkeypatch): |
| 175 | + """ |
| 176 | + Sanity: Bindings present and probe returns True → no warning, use nvJitLink. |
| 177 | + """ |
| 178 | + _install_public_nvjitlink_stub() |
| 179 | + monkeypatch.setattr(linker, "_nvjitlink_has_version_symbol", lambda _inner: True, raising=True) |
| 180 | + |
| 181 | + with warnings.catch_warnings(record=True) as rec: |
| 182 | + warnings.simplefilter("always") |
| 183 | + ret = linker._decide_nvjitlink_or_driver() |
| 184 | + |
| 185 | + assert ret is False # do NOT fall back |
| 186 | + warns = _collect_runtime_warnings(rec) |
| 187 | + assert not warns |
0 commit comments