Skip to content

Commit 689bed5

Browse files
cpcloudclaude
andcommitted
test(core): add GPU-free tests for Linker.backend classmethod
Tests monkeypatch module-level flags in cuda.core._linker to verify the classmethod without requiring a GPU. Covers nvJitLink path, driver path, probe invocation when not memoised, classmethod descriptor check, and idempotency. Refs #714 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fe9db0e commit 689bed5

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4+
5+
"""GPU-free tests for Linker.backend classmethod.
6+
7+
These live in a separate file from test_linker.py because that module calls
8+
Device() at import time, which requires a GPU. These tests use monkeypatch
9+
to set module-level flags and never touch CUDA devices.
10+
"""
11+
12+
import inspect
13+
14+
import pytest
15+
16+
import cuda.core._linker as _linker
17+
from cuda.core._linker import Linker
18+
19+
20+
class TestBackendClassmethod:
21+
def test_backend_returns_nvjitlink(self, monkeypatch):
22+
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", True)
23+
assert Linker.backend() == "nvJitLink"
24+
25+
def test_backend_returns_driver(self, monkeypatch):
26+
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", False)
27+
assert Linker.backend() == "driver"
28+
29+
def test_backend_invokes_probe_when_not_memoised(self, monkeypatch):
30+
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", None)
31+
called = []
32+
33+
def fake_decide():
34+
called.append(True)
35+
return False # False = not falling back to driver = nvJitLink
36+
37+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", fake_decide)
38+
result = Linker.backend()
39+
assert result == "nvJitLink"
40+
assert called, "_decide_nvjitlink_or_driver was not called"
41+
42+
def test_backend_is_classmethod(self):
43+
attr = inspect.getattr_static(Linker, "backend")
44+
assert isinstance(attr, classmethod)
45+
46+
def test_backend_idempotent(self, monkeypatch):
47+
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", True)
48+
results = [Linker.backend() for _ in range(3)]
49+
assert results == ["nvJitLink", "nvJitLink", "nvJitLink"]

0 commit comments

Comments
 (0)