Skip to content

Commit 19020b2

Browse files
committed
support loading from the versioned module if any exists
1 parent 021e0f3 commit 19020b2

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

cuda_core/cuda/core/experimental/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,29 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
try:
6+
import cuda.bindings
7+
except ImportError as e:
8+
raise ImportError("cuda.bindings 12.x or 13.x must be installed")
9+
else:
10+
cuda_major, cuda_minor = cuda.bindings.__version__.split(".")[:2]
11+
if cuda_major not in ("12", "13"):
12+
raise ImportError("cuda.bindings 12.x or 13.x must be installed")
13+
14+
import importlib
15+
subdir = f"cu{cuda_major}"
16+
try:
17+
verioned_mod = importlib.import_module(f".{subdir}", __package__)
18+
# Import all symbols from the module
19+
globals().update(verioned_mod.__dict__)
20+
except ImportError:
21+
# This is not a wheel build, but a conda or local build, do nothing
22+
pass
23+
else:
24+
del verioned_mod
25+
finally:
26+
del cuda.bindings, importlib, subdir, cuda_major, cuda_minor
27+
528
from cuda.core.experimental import utils
629
from cuda.core.experimental._device import Device
730
from cuda.core.experimental._event import Event, EventOptions

0 commit comments

Comments
 (0)