|
21 | 21 | __all__ = [ # noqa: RUF022 |
22 | 22 | 'platform_registry', 'get_cpu_info', 'get_gpu_info', 'get_visible_devices', |
23 | 23 | 'get_nvidia_cc', 'get_cuda_path', 'get_cuda_version', 'get_hip_path', |
24 | | - 'check_cuda_runtime', 'get_m1_llvm_path', 'get_advisor_path', 'Platform', |
25 | | - 'Cpu64', 'Intel64', 'IntelSkylake', 'Amd', 'Arm', 'Power', 'Device', |
| 24 | + 'check_cuda_runtime', 'load_cudart', 'get_m1_llvm_path', 'get_advisor_path', |
| 25 | + 'Platform', 'Cpu64', 'Intel64', 'IntelSkylake', 'Amd', 'Arm', 'Power', |
| 26 | + 'Device', |
26 | 27 | 'NvidiaDevice', 'AmdDevice', 'IntelDevice', |
27 | 28 | # Brand-agnostic |
28 | 29 | 'ANYCPU', 'ANYGPU', |
@@ -646,13 +647,25 @@ def get_m1_llvm_path(language): |
646 | 647 |
|
647 | 648 |
|
648 | 649 | @memoized_func |
649 | | -def check_cuda_runtime(): |
| 650 | +def load_cudart(): |
| 651 | + """ |
| 652 | + Load the CUDA runtime library. |
| 653 | + """ |
650 | 654 | libname = ctypes.util.find_library("cudart") |
651 | 655 | if not libname: |
| 656 | + raise RuntimeError("Unable to find CUDA runtime library `libcudart`") |
| 657 | + |
| 658 | + return ctypes.CDLL(libname) |
| 659 | + |
| 660 | + |
| 661 | +@memoized_func |
| 662 | +def check_cuda_runtime(): |
| 663 | + try: |
| 664 | + cuda = load_cudart() |
| 665 | + except RuntimeError: |
652 | 666 | warning("Unable to check compatibility of NVidia driver and runtime") |
653 | 667 | return |
654 | 668 |
|
655 | | - cuda = ctypes.CDLL(libname) |
656 | 669 | driver_version = ctypes.c_int() |
657 | 670 | runtime_version = ctypes.c_int() |
658 | 671 |
|
@@ -1115,11 +1128,10 @@ def max_shm_per_block(self): |
1115 | 1128 | """ |
1116 | 1129 | Get the maximum amount of shared memory per thread block |
1117 | 1130 | """ |
1118 | | - # Load libcudart |
1119 | | - libname = ctypes.util.find_library("cudart") |
1120 | | - if not libname: |
| 1131 | + try: |
| 1132 | + lib = load_cudart() |
| 1133 | + except RuntimeError: |
1121 | 1134 | return 64 * 1024 # 64 KB default |
1122 | | - lib = ctypes.CDLL(libname) |
1123 | 1135 |
|
1124 | 1136 | cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 |
1125 | 1137 | # get current device |
|
0 commit comments