Skip to content

Commit 6909d36

Browse files
authored
refactor(pathfinder): privatize optional_cuda_import (NVIDIA#1742)
* refactor(pathfinder): privatize `optional_cuda_import` Rename `optional_cuda_import` to `_optional_cuda_import` to remove it from the public API surface. It is an internal helper used only by `cuda.core` and should not be part of the stable public interface. Made-with: Cursor * docs: trim release notes for private helper Made-with: Cursor * fix: stop re-exporting private `_optional_cuda_import` from `__init__` The re-export shadowed the `_optional_cuda_import` submodule with the function of the same name, breaking `import ... as` in tests. Import directly from the submodule in consumers instead. Made-with: Cursor
1 parent ec99964 commit 6909d36

File tree

8 files changed

+29
-32
lines changed

8 files changed

+29
-32
lines changed

cuda_core/cuda/core/_linker.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ from dataclasses import dataclass
2929
from typing import Union
3030
from warnings import warn
3131

32-
from cuda.pathfinder import optional_cuda_import
32+
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
3333
from cuda.core._device import Device
3434
from cuda.core._module import ObjectCode
3535
from cuda.core._utils.clear_error_support import assert_type
@@ -650,7 +650,7 @@ def _decide_nvjitlink_or_driver() -> bool:
650650
" For best results, consider upgrading to a recent version of"
651651
)
652652

653-
nvjitlink_module = optional_cuda_import(
653+
nvjitlink_module = _optional_cuda_import(
654654
"cuda.bindings.nvjitlink",
655655
probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load
656656
)

cuda_core/cuda/core/_program.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import threading
1414
from warnings import warn
1515

1616
from cuda.bindings import driver, nvrtc
17-
from cuda.pathfinder import optional_cuda_import
17+
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
1818

1919
from libcpp.vector cimport vector
2020

@@ -485,7 +485,7 @@ def _get_nvvm_module():
485485
"Please update cuda-bindings to use NVVM features."
486486
)
487487

488-
nvvm = optional_cuda_import(
488+
nvvm = _optional_cuda_import(
489489
"cuda.bindings.nvvm",
490490
probe_function=lambda module: module.version(), # probe triggers libnvvm load
491491
)

cuda_core/tests/test_optional_dependency_imports.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def _patch_driver_version(monkeypatch, version=13000):
4747
def test_get_nvvm_module_reraises_nested_module_not_found(monkeypatch):
4848
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))
4949

50-
def fake_optional_cuda_import(modname, probe_function=None):
50+
def fake__optional_cuda_import(modname, probe_function=None):
5151
assert modname == "cuda.bindings.nvvm"
5252
assert probe_function is not None
5353
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
5454
err.name = "not_a_real_dependency"
5555
raise err
5656

57-
monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
57+
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)
5858

5959
with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
6060
_program._get_nvvm_module()
@@ -64,12 +64,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
6464
def test_get_nvvm_module_reports_missing_nvvm_module(monkeypatch):
6565
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))
6666

67-
def fake_optional_cuda_import(modname, probe_function=None):
67+
def fake__optional_cuda_import(modname, probe_function=None):
6868
assert modname == "cuda.bindings.nvvm"
6969
assert probe_function is not None
7070
return None
7171

72-
monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
72+
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)
7373

7474
with pytest.raises(RuntimeError, match="cuda.bindings.nvvm"):
7575
_program._get_nvvm_module()
@@ -78,12 +78,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
7878
def test_get_nvvm_module_handles_missing_libnvvm(monkeypatch):
7979
monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9))
8080

81-
def fake_optional_cuda_import(modname, probe_function=None):
81+
def fake__optional_cuda_import(modname, probe_function=None):
8282
assert modname == "cuda.bindings.nvvm"
8383
assert probe_function is not None
8484
return None
8585

86-
monkeypatch.setattr(_program, "optional_cuda_import", fake_optional_cuda_import)
86+
monkeypatch.setattr(_program, "_optional_cuda_import", fake__optional_cuda_import)
8787

8888
with pytest.raises(RuntimeError, match="libnvvm"):
8989
_program._get_nvvm_module()
@@ -92,14 +92,14 @@ def fake_optional_cuda_import(modname, probe_function=None):
9292
def test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch):
9393
_patch_driver_version(monkeypatch)
9494

95-
def fake_optional_cuda_import(modname, probe_function=None):
95+
def fake__optional_cuda_import(modname, probe_function=None):
9696
assert modname == "cuda.bindings.nvjitlink"
9797
assert probe_function is not None
9898
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
9999
err.name = "not_a_real_dependency"
100100
raise err
101101

102-
monkeypatch.setattr(_linker, "optional_cuda_import", fake_optional_cuda_import)
102+
monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import)
103103

104104
with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
105105
_linker._decide_nvjitlink_or_driver()
@@ -109,12 +109,12 @@ def fake_optional_cuda_import(modname, probe_function=None):
109109
def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch):
110110
_patch_driver_version(monkeypatch)
111111

112-
def fake_optional_cuda_import(modname, probe_function=None):
112+
def fake__optional_cuda_import(modname, probe_function=None):
113113
assert modname == "cuda.bindings.nvjitlink"
114114
assert probe_function is not None
115115
return None
116116

117-
monkeypatch.setattr(_linker, "optional_cuda_import", fake_optional_cuda_import)
117+
monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import)
118118

119119
with pytest.warns(RuntimeWarning, match="cuda.bindings.nvjitlink is not available"):
120120
use_driver_backend = _linker._decide_nvjitlink_or_driver()

cuda_pathfinder/cuda/pathfinder/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
locate_nvidia_header_directory as locate_nvidia_header_directory,
2626
)
2727
from cuda.pathfinder._headers.supported_nvidia_headers import SUPPORTED_HEADERS_CTK as _SUPPORTED_HEADERS_CTK
28-
from cuda.pathfinder._optional_cuda_import import optional_cuda_import as optional_cuda_import
2928
from cuda.pathfinder._static_libs.find_bitcode_lib import (
3029
SUPPORTED_BITCODE_LIBS as _SUPPORTED_BITCODE_LIBS,
3130
)

cuda_pathfinder/cuda/pathfinder/_optional_cuda_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from cuda.pathfinder._dynamic_libs.load_dl_common import DynamicLibNotFoundError
1111

1212

13-
def optional_cuda_import(
13+
def _optional_cuda_import(
1414
fully_qualified_modname: str,
1515
*,
1616
probe_function: Callable[[ModuleType], object] | None = None,

cuda_pathfinder/docs/source/api.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ locating NVIDIA C/C++ header directories, and finding CUDA binary utilities.
1414

1515
SUPPORTED_NVIDIA_LIBNAMES
1616
load_nvidia_dynamic_lib
17-
optional_cuda_import
1817
LoadedDL
1918
DynamicLibNotFoundError
2019
DynamicLibUnknownError

cuda_pathfinder/docs/source/release/1.4.2-notes.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,5 @@
99
Highlights
1010
----------
1111

12-
* Add ``optional_cuda_import()`` to support robust optional imports of CUDA
13-
Python modules. It returns ``None`` when the requested module is absent or a
14-
probe hits ``DynamicLibNotFoundError``, while still re-raising unrelated
15-
``ModuleNotFoundError`` exceptions (for missing transitive dependencies).
12+
* Privatize ``optional_cuda_import()`` (renamed to ``_optional_cuda_import()``)
13+
to remove it from the public API surface.

cuda_pathfinder/tests/test_optional_cuda_import.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,33 @@
66
import pytest
77

88
import cuda.pathfinder._optional_cuda_import as optional_import_mod
9-
from cuda.pathfinder import DynamicLibNotFoundError, optional_cuda_import
9+
from cuda.pathfinder import DynamicLibNotFoundError
10+
from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
1011

1112

12-
def test_optional_cuda_import_returns_module_when_available(monkeypatch):
13+
def test__optional_cuda_import_returns_module_when_available(monkeypatch):
1314
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
1415
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)
1516

16-
result = optional_cuda_import("cuda.bindings.nvvm")
17+
result = _optional_cuda_import("cuda.bindings.nvvm")
1718

1819
assert result is fake_module
1920

2021

21-
def test_optional_cuda_import_returns_none_when_module_missing(monkeypatch):
22+
def test__optional_cuda_import_returns_none_when_module_missing(monkeypatch):
2223
def fake_import_module(name):
2324
err = ModuleNotFoundError("No module named 'cuda.bindings.nvvm'")
2425
err.name = name
2526
raise err
2627

2728
monkeypatch.setattr(optional_import_mod.importlib, "import_module", fake_import_module)
2829

29-
result = optional_cuda_import("cuda.bindings.nvvm")
30+
result = _optional_cuda_import("cuda.bindings.nvvm")
3031

3132
assert result is None
3233

3334

34-
def test_optional_cuda_import_reraises_nested_module_not_found(monkeypatch):
35+
def test__optional_cuda_import_reraises_nested_module_not_found(monkeypatch):
3536
def fake_import_module(_name):
3637
err = ModuleNotFoundError("No module named 'not_a_real_dependency'")
3738
err.name = "not_a_real_dependency"
@@ -40,28 +41,28 @@ def fake_import_module(_name):
4041
monkeypatch.setattr(optional_import_mod.importlib, "import_module", fake_import_module)
4142

4243
with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo:
43-
optional_cuda_import("cuda.bindings.nvvm")
44+
_optional_cuda_import("cuda.bindings.nvvm")
4445
assert excinfo.value.name == "not_a_real_dependency"
4546

4647

47-
def test_optional_cuda_import_returns_none_when_probe_finds_missing_dynamic_lib(monkeypatch):
48+
def test__optional_cuda_import_returns_none_when_probe_finds_missing_dynamic_lib(monkeypatch):
4849
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
4950
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)
5051

5152
def probe(_module):
5253
raise DynamicLibNotFoundError("libnvvm missing")
5354

54-
result = optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
55+
result = _optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
5556

5657
assert result is None
5758

5859

59-
def test_optional_cuda_import_reraises_non_pathfinder_probe_error(monkeypatch):
60+
def test__optional_cuda_import_reraises_non_pathfinder_probe_error(monkeypatch):
6061
fake_module = types.SimpleNamespace(__name__="cuda.bindings.nvvm")
6162
monkeypatch.setattr(optional_import_mod.importlib, "import_module", lambda _name: fake_module)
6263

6364
def probe(_module):
6465
raise RuntimeError("unexpected probe failure")
6566

6667
with pytest.raises(RuntimeError, match="unexpected probe failure"):
67-
optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)
68+
_optional_cuda_import("cuda.bindings.nvvm", probe_function=probe)

0 commit comments

Comments
 (0)