|
4 | 4 |
|
5 | 5 | from cuda.core._version import __version__ |
6 | 6 |
|
7 | | -try: |
| 7 | + |
| 8 | +def _import_versioned_module(): |
| 9 | + import importlib |
| 10 | + |
8 | 11 | from cuda import bindings |
9 | | -except ImportError: |
10 | | - raise ImportError("cuda.bindings 12.x or 13.x must be installed") from None |
11 | | -else: |
12 | | - cuda_major, cuda_minor = bindings.__version__.split(".")[:2] |
| 12 | + |
| 13 | + cuda_major, _ = bindings.__version__.split(".")[:2] |
13 | 14 | if cuda_major not in ("12", "13"): |
14 | 15 | raise ImportError("cuda.bindings 12.x or 13.x must be installed") |
15 | 16 |
|
16 | | -import importlib |
| 17 | + subdir = f"cu{cuda_major}" |
| 18 | + try: |
| 19 | + versioned_mod = importlib.import_module(f".{subdir}", __package__) |
| 20 | + # Import all symbols from the module |
| 21 | + globals().update(versioned_mod.__dict__) |
| 22 | + except ImportError: |
| 23 | + # This is not a wheel build, but a conda or local build, do nothing |
| 24 | + pass |
| 25 | + |
| 26 | + |
| 27 | +_import_versioned_module() |
| 28 | +del _import_versioned_module |
17 | 29 |
|
18 | | -subdir = f"cu{cuda_major}" |
19 | | -try: |
20 | | - versioned_mod = importlib.import_module(f".{subdir}", __package__) |
21 | | - # Import all symbols from the module |
22 | | - globals().update(versioned_mod.__dict__) |
23 | | -except ImportError: |
24 | | - # This is not a wheel build, but a conda or local build, do nothing |
25 | | - pass |
26 | | -else: |
27 | | - del versioned_mod |
28 | | -finally: |
29 | | - del bindings, importlib, subdir, cuda_major, cuda_minor |
30 | 30 |
|
31 | 31 | from cuda.core import system, utils |
32 | 32 | from cuda.core._device import Device |
|
0 commit comments