Skip to content

Commit 714b88c

Browse files
committed
Add back _get_numba_CUDA_INCLUDE_PATH from 2279bda (i.e. cuda_paths.py as it was right before re-forking)
1 parent 7dcaa50 commit 714b88c

1 file changed

Lines changed: 40 additions & 4 deletions

File tree

cuda_bindings/cuda/bindings/_path_finder/cuda_paths.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,51 @@
44
import re
55
import site
66
import sys
7+
import traceback
8+
import warnings
79
from collections import namedtuple
810
from pathlib import Path
911

10-
from numba import config
11-
from numba.core.config import IS_WIN32
12-
from numba.misc.findlib import find_lib
12+
from .findlib import find_lib
13+
14+
IS_WIN32 = sys.platform.startswith("win32")
1315

1416
_env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
1517

18+
19+
def _get_numba_CUDA_INCLUDE_PATH():
20+
# From numba/numba/core/config.py
21+
22+
def _readenv(name, ctor, default):
23+
value = os.environ.get(name)
24+
if value is None:
25+
return default() if callable(default) else default
26+
try:
27+
return ctor(value)
28+
except Exception:
29+
warnings.warn( # noqa: B028
30+
f"Environment variable '{name}' is defined but "
31+
f"its associated value '{value}' could not be "
32+
"parsed.\nThe parse failed with exception:\n"
33+
f"{traceback.format_exc()}",
34+
RuntimeWarning,
35+
)
36+
return default
37+
38+
if IS_WIN32:
39+
cuda_path = os.environ.get("CUDA_PATH")
40+
if cuda_path: # noqa: SIM108
41+
default_cuda_include_path = os.path.join(cuda_path, "include")
42+
else:
43+
default_cuda_include_path = "cuda_include_not_found"
44+
else:
45+
default_cuda_include_path = os.path.join(os.sep, "usr", "local", "cuda", "include")
46+
CUDA_INCLUDE_PATH = _readenv("NUMBA_CUDA_INCLUDE_PATH", str, default_cuda_include_path)
47+
return CUDA_INCLUDE_PATH
48+
49+
50+
config_CUDA_INCLUDE_PATH = _get_numba_CUDA_INCLUDE_PATH()
51+
1652
SEARCH_PRIORITY = [
1753
"Conda environment",
1854
"Conda environment (NVIDIA package)",
@@ -502,7 +538,7 @@ def _get_include_dir():
502538
"""Find the root include directory."""
503539
options = [
504540
("Conda environment (NVIDIA package)", get_conda_include_dir()),
505-
("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
541+
("CUDA_INCLUDE_PATH Config Entry", config_CUDA_INCLUDE_PATH),
506542
# TODO: add others
507543
]
508544
by, include_dir = _find_valid_path(options)

0 commit comments

Comments
 (0)