2929
3030# Code from https://github.com/pytorch/pytorch/blob/master/torch/utils/cpp_extension.py
3131
32- CC_COMPATIBILITY_TABLE = [
33- # gencode, code, support_begin, support_end
34- (20 , 20 , 1 , 1.0 ),
35- (30 , 30 , 1 , 11.0 ),
36- (35 , 35 , 1 , 12.0 ),
37- (37 , 37 , 1 , 12.0 ),
38- (50 , 50 , 6.5 , 13.0 ),
39- (52 , 52 , 6.5 , 13.0 ),
40- (60 , 60 , 8.0 , 13.0 ),
41- (61 , 61 , 8.0 , 13.0 ),
42- (70 , 70 , 9.0 , 13.0 ), # V100, Titan V (Volta)
43- (75 , 75 , 10.0 , 999 ), # RTX 20 series, T4, T1000, RTX8000
44- (80 , 80 , 11.0 , 999 ), # A100, A30
45- (86 , 86 , 11.1 , 999 ), # RTX 30 series, A40, A10, A16, A2, RTX A6000
46- (87 , 87 , 11.5 , 999 ), # Jetson AGX Orin, Orin Nano, Orin NX
47- (89 , 89 , 11.8 , 999 ), # RTX 40 series, RTX Ada
48- (90 , 90 , 11.8 , 999 ), # GH200, H200, H100
49- (100 , 100 , 12.8 , 999 ), # GB200, B200
50- (103 , 103 , 12.8 , 999 ), # GB300, B300
51- (110 , 110 , 12.8 , 999 ), # Jetson T5000, T4000
52- (120 , 120 , 12.8 , 999 ), # RTX 50 series, RTX Pro Blackwell
53- (121 , 121 , 12.8 , 999 ), # GB10
54- ]
55-
5632COMPUTE_CAPABILITY_ARGS = [
5733 "--ptxas-options=-v" ,
5834 "-c" ,
@@ -79,6 +55,14 @@ def get_cuda_version(cuda_home):
7955 raise RuntimeError ("Cannot read cuda version file" )
8056
8157
def get_cuda_cc(cuda_home):
    """List the compute capabilities supported by the installed nvcc.

    Runs ``<cuda_home>/bin/nvcc --list-gpu-arch`` and extracts the numeric
    part of every ``compute_<digits>`` entry found in its output.

    Args:
        cuda_home: Root directory of the CUDA toolkit (the directory that
            contains ``bin/nvcc``).

    Returns:
        A list of digit strings (e.g. ``["50", "52", "90"]``) in the order
        nvcc printed them, with repeated entries removed.

    Raises:
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
        OSError: If the nvcc executable cannot be started.
    """
    nvcc = os.path.join(cuda_home, "bin", "nvcc")
    output = subprocess.check_output([nvcc, "--list-gpu-arch"])
    # Decode the bytes instead of parsing their repr: str(bytes) turns
    # non-ASCII bytes into "\xNN" escapes whose digits would be picked up
    # as bogus architectures.
    text = output.decode("utf-8", errors="replace")
    # Anchor on the "compute_" prefix so digits in unrelated text are
    # ignored; dict.fromkeys drops duplicates while preserving order so no
    # -gencode flag is emitted twice.
    return list(dict.fromkeys(re.findall(r"compute_(\d+)", text)))
64+
65+
8266def locate_cuda ():
8367 """Locate the CUDA environment on the system
8468
@@ -141,15 +125,9 @@ def _is_cuda_file(path):
141125 cuda_list = re .findall (r'\d+' , CUDA_VERSION )
142126 cuda_version = float ( str (cuda_list [0 ] + '.' + cuda_list [1 ]))
143127
144- # Insert CUDA arguments depending on the version
145- for item in CC_COMPATIBILITY_TABLE :
146- support_begin = item [2 ]
147- support_end = item [3 ]
148- if cuda_version < support_begin :
149- continue
150- if cuda_version >= support_end :
151- continue
152- str_arg = f"-gencode=arch=compute_{ item [0 ]} ,code=sm_{ item [1 ]} "
128+ # Insert CUDA CC arguments
129+ for item in get_cuda_cc (CUDA ["home" ]):
130+ str_arg = f"-gencode=arch=compute_{ item } ,code=sm_{ item } "
153131 COMPUTE_CAPABILITY_ARGS .insert (0 , str_arg )
154132
155133
0 commit comments