66
77
88def get_args () -> argparse .Namespace :
9- ''' 获取命令行参数
9+ """ 获取命令行参数
1010
1111 :return `argparse.Namespace`: 命令行参数命名空间
12- '''
12+ """
1313 parser = argparse .ArgumentParser ()
1414
15- parser .add_argument ('--ignore-ort-install' , action = 'store_true' , help = '忽略 onnxruntime-gpu 未安装的状态, 强制进行检查' )
15+ parser .add_argument (
16+ "--ignore-ort-install" ,
17+ action = "store_true" ,
18+ help = "忽略 onnxruntime-gpu 未安装的状态, 强制进行检查" ,
19+ )
1620
1721 return parser .parse_args ()
1822
1923
2024def get_onnxruntime_version_file () -> Path | None :
21- ''' 获取记录 onnxruntime 版本的文件路径
25+ """ 获取记录 onnxruntime 版本的文件路径
2226
2327 :return Path | None: 记录 onnxruntime 版本的文件路径
24- '''
25- package = ' onnxruntime-gpu'
26- version_file = ' onnxruntime/capi/version_info.py'
28+ """
29+ package = " onnxruntime-gpu"
30+ version_file = " onnxruntime/capi/version_info.py"
2731 try :
28- util = [
29- p for p in importlib .metadata .files (package )
30- if version_file in str (p )
31- ][0 ]
32+ util = [p for p in importlib .metadata .files (package ) if version_file in str (p )][
33+ 0
34+ ]
3235 info_path = Path (util .locate ())
3336 except Exception as _ :
3437 info_path = None
@@ -37,60 +40,60 @@ def get_onnxruntime_version_file() -> Path | None:
3740
3841
3942def get_onnxruntime_support_cuda_version () -> tuple [str | None , str | None ]:
40- ''' 获取 onnxruntime 支持的 CUDA, cuDNN 版本
43+ """ 获取 onnxruntime 支持的 CUDA, cuDNN 版本
4144
4245 :return tuple[str | None, str | None]: onnxruntime 支持的 CUDA, cuDNN 版本
43- '''
46+ """
4447 ver_path = get_onnxruntime_version_file ()
4548 cuda_ver = None
4649 cudnn_ver = None
4750 try :
48- with open (ver_path , 'r' , encoding = ' utf8' ) as f :
51+ with open (ver_path , "r" , encoding = " utf8" ) as f :
4952 for line in f :
50- if ' cuda_version' in line :
51- cuda_ver = get_value_from_variable (line , ' cuda_version' )
52- if ' cudnn_version' in line :
53- cudnn_ver = get_value_from_variable (line , ' cudnn_version' )
53+ if " cuda_version" in line :
54+ cuda_ver = get_value_from_variable (line , " cuda_version" )
55+ if " cudnn_version" in line :
56+ cudnn_ver = get_value_from_variable (line , " cudnn_version" )
5457 except Exception as _ :
5558 pass
5659
5760 return cuda_ver , cudnn_ver
5861
5962
6063def get_value_from_variable (content : str , var_name : str ) -> str | None :
61- ''' 从字符串 (Python 代码片段) 中找出指定字符串变量的值
64+ """ 从字符串 (Python 代码片段) 中找出指定字符串变量的值
6265
6366 :param content(str): 待查找的内容
6467 :param var_name(str): 待查找的字符串变量
6568 :return str | None: 返回字符串变量的值
66- '''
67- pattern = fr '{ var_name } \s*=\s*"([^"]+)"'
69+ """
70+ pattern = rf '{ var_name } \s*=\s*"([^"]+)"'
6871 match = re .search (pattern , content )
6972 return match .group (1 ) if match else None
7073
7174
7275def compare_versions (version1 : str , version2 : str ) -> int :
73- ''' 对比两个版本号大小
76+ """ 对比两个版本号大小
7477
7578 :param version1(str): 第一个版本号
7679 :param version2(str): 第二个版本号
7780 :return int: 版本对比结果, 1 为第一个版本号大, -1 为第二个版本号大, 0 为两个版本号一样
78- '''
81+ """
7982 # 将版本号拆分成数字列表
8083 try :
8184 nums1 = (
82- re .sub (r' [a-zA-Z]+' , '' , version1 )
83- .replace ('-' , '.' )
84- .replace ('_' , '.' )
85- .replace ('+' , '.' )
86- .split ('.' )
85+ re .sub (r" [a-zA-Z]+" , "" , version1 )
86+ .replace ("-" , "." )
87+ .replace ("_" , "." )
88+ .replace ("+" , "." )
89+ .split ("." )
8790 )
8891 nums2 = (
89- re .sub (r' [a-zA-Z]+' , '' , version2 )
90- .replace ('-' , '.' )
91- .replace ('_' , '.' )
92- .replace ('+' , '.' )
93- .split ('.' )
92+ re .sub (r" [a-zA-Z]+" , "" , version2 )
93+ .replace ("-" , "." )
94+ .replace ("_" , "." )
95+ .replace ("+" , "." )
96+ .split ("." )
9497 )
9598 except Exception as _ :
9699 return 0
@@ -110,12 +113,13 @@ def compare_versions(version1: str, version2: str) -> int:
110113
111114
112115def get_torch_cuda_ver () -> tuple [str | None , str | None , str | None ]:
113- ''' 获取 Torch 的本体, CUDA, cuDNN 版本
116+ """ 获取 Torch 的本体, CUDA, cuDNN 版本
114117
115118 :return tuple[str | None, str | None, str | None]: Torch, CUDA, cuDNN 版本
116- '''
119+ """
117120 try :
118121 import torch
122+
119123 torch_ver = torch .__version__
120124 cuda_ver = torch .version .cuda
121125 cudnn_ver = torch .backends .cudnn .version ()
@@ -129,38 +133,35 @@ def get_torch_cuda_ver() -> tuple[str | None, str | None, str | None]:
129133
130134
131135class OrtType (str , Enum ):
132- ''' onnxruntime-gpu 的类型
136+ """ onnxruntime-gpu 的类型
133137
134- 版本说明:
138+ 版本说明:
135139 - CU121CUDNN8: CUDA 12.1 + cuDNN8
136140 - CU121CUDNN9: CUDA 12.1 + cuDNN9
137141 - CU118: CUDA 11.8
138- '''
139- CU121CUDNN8 = 'cu121cudnn8'
140- CU121CUDNN9 = 'cu121cudnn9'
141- CU118 = 'cu118'
142+ """
143+
144+ CU121CUDNN8 = "cu121cudnn8"
145+ CU121CUDNN9 = "cu121cudnn9"
146+ CU118 = "cu118"
142147
143148 def __str__ (self ):
144149 return self .value
145150
146151
147152def need_install_ort_ver (ignore_ort_install : bool = True ) -> OrtType | None :
148- ''' 判断需要安装的 onnxruntime 版本
153+ """ 判断需要安装的 onnxruntime 版本
149154
150155 :param ignore_ort_install(bool): 当 onnxruntime 未安装时跳过检查
151156 :return OrtType: 需要安装的 onnxruntime-gpu 类型
152- '''
157+ """
153158 # 检测是否安装了 Torch
154159 torch_ver , cuda_ver , cuddn_ver = get_torch_cuda_ver ()
155160 # 缺少 Torch / CUDA / cuDNN 版本时取消判断
156- if (
157- torch_ver is None
158- or cuda_ver is None
159- or cuddn_ver is None
160- ):
161+ if torch_ver is None or cuda_ver is None or cuddn_ver is None :
161162 if not ignore_ort_install :
162163 try :
163- _ = importlib .metadata .version (' onnxruntime-gpu' )
164+ _ = importlib .metadata .version (" onnxruntime-gpu" )
164165 except Exception as _ :
165166 # onnxruntime-gpu 没有安装时
166167 return OrtType .CU121CUDNN9
@@ -176,11 +177,11 @@ def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
176177 # 当 onnxruntime 已安装
177178
178179 # 判断 Torch 中的 CUDA 版本
179- if compare_versions (cuda_ver , ' 12.0' ) >= 0 :
180+ if compare_versions (cuda_ver , " 12.0" ) >= 0 :
180181 # CUDA >= 12.0
181182
182183 # 比较 onnxtuntime 支持的 CUDA 版本是否和 Torch 中所带的 CUDA 版本匹配
183- if compare_versions (ort_support_cuda_ver , ' 12.0' ) >= 0 :
184+ if compare_versions (ort_support_cuda_ver , " 12.0" ) >= 0 :
184185 # CUDA 版本为 12.x, torch 和 ort 的 CUDA 版本匹配
185186
186187 # 判断 Torch 和 onnxruntime 的 cuDNN 是否匹配
@@ -195,29 +196,29 @@ def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
195196 return None
196197 else :
197198 # CUDA 版本非 12.x, 不匹配
198- if compare_versions (cuddn_ver , '8' ) > 0 :
199+ if compare_versions (cuddn_ver , "8" ) > 0 :
199200 return OrtType .CU121CUDNN9
200201 else :
201202 return OrtType .CU121CUDNN8
202203 else :
203204 # CUDA <= 11.8
204- if compare_versions (ort_support_cuda_ver , ' 12.0' ) < 0 :
205+ if compare_versions (ort_support_cuda_ver , " 12.0" ) < 0 :
205206 return None
206207 else :
207208 return OrtType .CU118
208209 else :
209210 if ignore_ort_install :
210211 return None
211212
212- if compare_versions (cuda_ver , ' 12.0' ) >= 0 :
213- if compare_versions (cuddn_ver , '8' ) > 0 :
213+ if compare_versions (cuda_ver , " 12.0" ) >= 0 :
214+ if compare_versions (cuddn_ver , "8" ) > 0 :
214215 return OrtType .CU121CUDNN9
215216 else :
216217 return OrtType .CU121CUDNN8
217218 else :
218219 return OrtType .CU118
219220
220221
221- if __name__ == ' __main__' :
222+ if __name__ == " __main__" :
222223 arg = get_args ()
223224 print (need_install_ort_ver (not arg .ignore_ort_install ))
0 commit comments