Skip to content

Commit f8dac8e

Browse files
committed
lint
1 parent ce05731 commit f8dac8e

26 files changed

Lines changed: 1253 additions & 1127 deletions

python_modules/check_comfyui_env.py

Lines changed: 433 additions & 441 deletions
Large diffs are not rendered by default.

python_modules/check_controlnet_aux.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
try:
2-
import controlnet_aux
2+
import controlnet_aux # noqa: F401
3+
34
success = True
45
except:
56
success = False
67

78
if not success:
89
from importlib.metadata import requires
10+
911
try:
1012
invokeai_requires = requires("invokeai")
1113
except:

python_modules/check_cuda_malloc_avaliable.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22
import importlib.util
33
import subprocess
44

5-
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
5+
6+
# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
67
def get_gpu_names():
7-
if os.name == 'nt':
8+
if os.name == "nt":
89
import ctypes
910

1011
# Define necessary C structures and types
1112
class DISPLAY_DEVICEA(ctypes.Structure):
1213
_fields_ = [
13-
('cb', ctypes.c_ulong),
14-
('DeviceName', ctypes.c_char * 32),
15-
('DeviceString', ctypes.c_char * 128),
16-
('StateFlags', ctypes.c_ulong),
17-
('DeviceID', ctypes.c_char * 128),
18-
('DeviceKey', ctypes.c_char * 128)
14+
("cb", ctypes.c_ulong),
15+
("DeviceName", ctypes.c_char * 32),
16+
("DeviceString", ctypes.c_char * 128),
17+
("StateFlags", ctypes.c_ulong),
18+
("DeviceID", ctypes.c_char * 128),
19+
("DeviceKey", ctypes.c_char * 128),
1920
]
2021

2122
# Load user32.dll
@@ -28,26 +29,66 @@ def enum_display_devices():
2829
device_index = 0
2930
gpu_names = set()
3031

31-
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
32+
while user32.EnumDisplayDevicesA(
33+
None, device_index, ctypes.byref(device_info), 0
34+
):
3235
device_index += 1
33-
gpu_names.add(device_info.DeviceString.decode('utf-8'))
36+
gpu_names.add(device_info.DeviceString.decode("utf-8"))
3437
return gpu_names
38+
3539
return enum_display_devices()
3640
else:
3741
gpu_names = set()
38-
out = subprocess.check_output(['nvidia-smi', '-L'])
39-
for l in out.split(b'\n'):
42+
out = subprocess.check_output(["nvidia-smi", "-L"])
43+
for l in out.split(b"\n"):
4044
if len(l) > 0:
41-
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
45+
gpu_names.add(l.decode("utf-8").split(" (UUID")[0])
4246
return gpu_names
4347

44-
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
45-
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
46-
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
47-
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
48-
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
49-
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
50-
}
48+
49+
blacklist = {
50+
"GeForce GTX TITAN X",
51+
"GeForce GTX 980",
52+
"GeForce GTX 970",
53+
"GeForce GTX 960",
54+
"GeForce GTX 950",
55+
"GeForce 945M",
56+
"GeForce 940M",
57+
"GeForce 930M",
58+
"GeForce 920M",
59+
"GeForce 910M",
60+
"GeForce GTX 750",
61+
"GeForce GTX 745",
62+
"Quadro K620",
63+
"Quadro K1200",
64+
"Quadro K2200",
65+
"Quadro M500",
66+
"Quadro M520",
67+
"Quadro M600",
68+
"Quadro M620",
69+
"Quadro M1000",
70+
"Quadro M1200",
71+
"Quadro M2000",
72+
"Quadro M2200",
73+
"Quadro M3000",
74+
"Quadro M4000",
75+
"Quadro M5000",
76+
"Quadro M5500",
77+
"Quadro M6000",
78+
"GeForce MX110",
79+
"GeForce MX130",
80+
"GeForce 830M",
81+
"GeForce 840M",
82+
"GeForce GTX 850M",
83+
"GeForce GTX 860M",
84+
"GeForce GTX 1650",
85+
"GeForce GTX 1630",
86+
"Tesla M4",
87+
"Tesla M6",
88+
"Tesla M10",
89+
"Tesla M40",
90+
"Tesla M60",
91+
}
5192

5293

5394
def cuda_malloc_supported():
@@ -74,7 +115,7 @@ def is_nvidia_device():
74115
return False
75116

76117

77-
def get_pytorch_cuda_alloc_conf(is_cuda = True):
118+
def get_pytorch_cuda_alloc_conf(is_cuda=True):
78119
if is_nvidia_device():
79120
if cuda_malloc_supported():
80121
if is_cuda:
@@ -94,12 +135,14 @@ def main():
94135
for folder in torch_spec.submodule_search_locations:
95136
ver_file = os.path.join(folder, "version.py")
96137
if os.path.isfile(ver_file):
97-
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
138+
spec = importlib.util.spec_from_file_location(
139+
"torch_version_import", ver_file
140+
)
98141
module = importlib.util.module_from_spec(spec)
99142
spec.loader.exec_module(module)
100143
version = module.__version__
101-
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
102-
if "+cu" in version: #only on cuda torch
144+
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
145+
if "+cu" in version: # only on cuda torch
103146
print(get_pytorch_cuda_alloc_conf())
104147
else:
105148
print(get_pytorch_cuda_alloc_conf(False))

python_modules/check_invokeai_installed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
tmp = version("invokeai")
55
print(True)
66
except:
7-
print(False)
7+
print(False)

python_modules/check_onnxruntime_gpu.py

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,32 @@
66

77

88
def get_args() -> argparse.Namespace:
9-
'''获取命令行参数
9+
"""获取命令行参数
1010
1111
:return `argparse.Namespace`: 命令行参数命名空间
12-
'''
12+
"""
1313
parser = argparse.ArgumentParser()
1414

15-
parser.add_argument('--ignore-ort-install', action='store_true', help='忽略 onnxruntime-gpu 未安装的状态, 强制进行检查')
15+
parser.add_argument(
16+
"--ignore-ort-install",
17+
action="store_true",
18+
help="忽略 onnxruntime-gpu 未安装的状态, 强制进行检查",
19+
)
1620

1721
return parser.parse_args()
1822

1923

2024
def get_onnxruntime_version_file() -> Path | None:
21-
'''获取记录 onnxruntime 版本的文件路径
25+
"""获取记录 onnxruntime 版本的文件路径
2226
2327
:return Path | None: 记录 onnxruntime 版本的文件路径
24-
'''
25-
package = 'onnxruntime-gpu'
26-
version_file = 'onnxruntime/capi/version_info.py'
28+
"""
29+
package = "onnxruntime-gpu"
30+
version_file = "onnxruntime/capi/version_info.py"
2731
try:
28-
util = [
29-
p for p in importlib.metadata.files(package)
30-
if version_file in str(p)
31-
][0]
32+
util = [p for p in importlib.metadata.files(package) if version_file in str(p)][
33+
0
34+
]
3235
info_path = Path(util.locate())
3336
except Exception as _:
3437
info_path = None
@@ -37,60 +40,60 @@ def get_onnxruntime_version_file() -> Path | None:
3740

3841

3942
def get_onnxruntime_support_cuda_version() -> tuple[str | None, str | None]:
40-
'''获取 onnxruntime 支持的 CUDA, cuDNN 版本
43+
"""获取 onnxruntime 支持的 CUDA, cuDNN 版本
4144
4245
:return tuple[str | None, str | None]: onnxruntime 支持的 CUDA, cuDNN 版本
43-
'''
46+
"""
4447
ver_path = get_onnxruntime_version_file()
4548
cuda_ver = None
4649
cudnn_ver = None
4750
try:
48-
with open(ver_path, 'r', encoding='utf8') as f:
51+
with open(ver_path, "r", encoding="utf8") as f:
4952
for line in f:
50-
if 'cuda_version' in line:
51-
cuda_ver = get_value_from_variable(line, 'cuda_version')
52-
if 'cudnn_version' in line:
53-
cudnn_ver = get_value_from_variable(line, 'cudnn_version')
53+
if "cuda_version" in line:
54+
cuda_ver = get_value_from_variable(line, "cuda_version")
55+
if "cudnn_version" in line:
56+
cudnn_ver = get_value_from_variable(line, "cudnn_version")
5457
except Exception as _:
5558
pass
5659

5760
return cuda_ver, cudnn_ver
5861

5962

6063
def get_value_from_variable(content: str, var_name: str) -> str | None:
61-
'''从字符串 (Python 代码片段) 中找出指定字符串变量的值
64+
"""从字符串 (Python 代码片段) 中找出指定字符串变量的值
6265
6366
:param content(str): 待查找的内容
6467
:param var_name(str): 待查找的字符串变量
6568
:return str | None: 返回字符串变量的值
66-
'''
67-
pattern = fr'{var_name}\s*=\s*"([^"]+)"'
69+
"""
70+
pattern = rf'{var_name}\s*=\s*"([^"]+)"'
6871
match = re.search(pattern, content)
6972
return match.group(1) if match else None
7073

7174

7275
def compare_versions(version1: str, version2: str) -> int:
73-
'''对比两个版本号大小
76+
"""对比两个版本号大小
7477
7578
:param version1(str): 第一个版本号
7679
:param version2(str): 第二个版本号
7780
:return int: 版本对比结果, 1 为第一个版本号大, -1 为第二个版本号大, 0 为两个版本号一样
78-
'''
81+
"""
7982
# 将版本号拆分成数字列表
8083
try:
8184
nums1 = (
82-
re.sub(r'[a-zA-Z]+', '', version1)
83-
.replace('-', '.')
84-
.replace('_', '.')
85-
.replace('+', '.')
86-
.split('.')
85+
re.sub(r"[a-zA-Z]+", "", version1)
86+
.replace("-", ".")
87+
.replace("_", ".")
88+
.replace("+", ".")
89+
.split(".")
8790
)
8891
nums2 = (
89-
re.sub(r'[a-zA-Z]+', '', version2)
90-
.replace('-', '.')
91-
.replace('_', '.')
92-
.replace('+', '.')
93-
.split('.')
92+
re.sub(r"[a-zA-Z]+", "", version2)
93+
.replace("-", ".")
94+
.replace("_", ".")
95+
.replace("+", ".")
96+
.split(".")
9497
)
9598
except Exception as _:
9699
return 0
@@ -110,12 +113,13 @@ def compare_versions(version1: str, version2: str) -> int:
110113

111114

112115
def get_torch_cuda_ver() -> tuple[str | None, str | None, str | None]:
113-
'''获取 Torch 的本体, CUDA, cuDNN 版本
116+
"""获取 Torch 的本体, CUDA, cuDNN 版本
114117
115118
:return tuple[str | None, str | None, str | None]: Torch, CUDA, cuDNN 版本
116-
'''
119+
"""
117120
try:
118121
import torch
122+
119123
torch_ver = torch.__version__
120124
cuda_ver = torch.version.cuda
121125
cudnn_ver = torch.backends.cudnn.version()
@@ -129,38 +133,35 @@ def get_torch_cuda_ver() -> tuple[str | None, str | None, str | None]:
129133

130134

131135
class OrtType(str, Enum):
132-
'''onnxruntime-gpu 的类型
136+
"""onnxruntime-gpu 的类型
133137
134-
版本说明:
138+
版本说明:
135139
- CU121CUDNN8: CUDA 12.1 + cuDNN8
136140
- CU121CUDNN9: CUDA 12.1 + cuDNN9
137141
- CU118: CUDA 11.8
138-
'''
139-
CU121CUDNN8 = 'cu121cudnn8'
140-
CU121CUDNN9 = 'cu121cudnn9'
141-
CU118 = 'cu118'
142+
"""
143+
144+
CU121CUDNN8 = "cu121cudnn8"
145+
CU121CUDNN9 = "cu121cudnn9"
146+
CU118 = "cu118"
142147

143148
def __str__(self):
144149
return self.value
145150

146151

147152
def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
148-
'''判断需要安装的 onnxruntime 版本
153+
"""判断需要安装的 onnxruntime 版本
149154
150155
:param ignore_ort_install(bool): 当 onnxruntime 未安装时跳过检查
151156
:return OrtType: 需要安装的 onnxruntime-gpu 类型
152-
'''
157+
"""
153158
# 检测是否安装了 Torch
154159
torch_ver, cuda_ver, cuddn_ver = get_torch_cuda_ver()
155160
# 缺少 Torch / CUDA / cuDNN 版本时取消判断
156-
if (
157-
torch_ver is None
158-
or cuda_ver is None
159-
or cuddn_ver is None
160-
):
161+
if torch_ver is None or cuda_ver is None or cuddn_ver is None:
161162
if not ignore_ort_install:
162163
try:
163-
_ = importlib.metadata.version('onnxruntime-gpu')
164+
_ = importlib.metadata.version("onnxruntime-gpu")
164165
except Exception as _:
165166
# onnxruntime-gpu 没有安装时
166167
return OrtType.CU121CUDNN9
@@ -176,11 +177,11 @@ def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
176177
# 当 onnxruntime 已安装
177178

178179
# 判断 Torch 中的 CUDA 版本
179-
if compare_versions(cuda_ver, '12.0') >= 0:
180+
if compare_versions(cuda_ver, "12.0") >= 0:
180181
# CUDA >= 12.0
181182

182183
# 比较 onnxtuntime 支持的 CUDA 版本是否和 Torch 中所带的 CUDA 版本匹配
183-
if compare_versions(ort_support_cuda_ver, '12.0') >= 0:
184+
if compare_versions(ort_support_cuda_ver, "12.0") >= 0:
184185
# CUDA 版本为 12.x, torch 和 ort 的 CUDA 版本匹配
185186

186187
# 判断 Torch 和 onnxruntime 的 cuDNN 是否匹配
@@ -195,29 +196,29 @@ def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
195196
return None
196197
else:
197198
# CUDA 版本非 12.x, 不匹配
198-
if compare_versions(cuddn_ver, '8') > 0:
199+
if compare_versions(cuddn_ver, "8") > 0:
199200
return OrtType.CU121CUDNN9
200201
else:
201202
return OrtType.CU121CUDNN8
202203
else:
203204
# CUDA <= 11.8
204-
if compare_versions(ort_support_cuda_ver, '12.0') < 0:
205+
if compare_versions(ort_support_cuda_ver, "12.0") < 0:
205206
return None
206207
else:
207208
return OrtType.CU118
208209
else:
209210
if ignore_ort_install:
210211
return None
211212

212-
if compare_versions(cuda_ver, '12.0') >= 0:
213-
if compare_versions(cuddn_ver, '8') > 0:
213+
if compare_versions(cuda_ver, "12.0") >= 0:
214+
if compare_versions(cuddn_ver, "8") > 0:
214215
return OrtType.CU121CUDNN9
215216
else:
216217
return OrtType.CU121CUDNN8
217218
else:
218219
return OrtType.CU118
219220

220221

221-
if __name__ == '__main__':
222+
if __name__ == "__main__":
222223
arg = get_args()
223224
print(need_install_ort_ver(not arg.ignore_ort_install))

0 commit comments

Comments
 (0)