Skip to content

Commit 05a964b

Browse files
committed
重构:拆分
1 parent feb06a9 commit 05a964b

6 files changed

Lines changed: 173 additions & 151 deletions

File tree

build.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,16 @@ def build_executable():
7575
os.environ['PYTHONIOENCODING'] = 'utf-8'
7676

7777
PyInstaller.__main__.run([
78-
'model_manager.py',
78+
'main.py',
7979
f'--name={output_name}',
8080
'--onefile',
8181
'--hidden-import=uvicorn.logging',
8282
'--hidden-import=uvicorn.lifespan.on',
8383
'--hidden-import=uvicorn.lifespan',
84+
'--hidden-import=src.core.model_manager',
85+
'--hidden-import=src.api.model_api',
86+
'--hidden-import=src.utils.file_utils',
87+
'--hidden-import=src.utils.hash_utils',
8488
'--add-data=templates;templates',
8589
'--add-data=static/favicon.svg;static',
8690
])

main.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import argparse
2+
import time
3+
import webbrowser
4+
import threading
5+
import uvicorn
6+
7+
from src.core.model_manager import ModelManager
8+
from src.api.model_api import create_api
9+
from src.utils.file_utils import find_free_port
10+
11+
def open_browser(port: int):
12+
"""延迟一秒后打开浏览器"""
13+
time.sleep(1)
14+
webbrowser.open(f'http://127.0.0.1:{port}')
15+
16+
if __name__ == "__main__":
17+
parser = argparse.ArgumentParser(description='模型管理器')
18+
parser.add_argument('--port', type=int, default=None, help='Web界面端口')
19+
args = parser.parse_args()
20+
21+
# 获取可用端口
22+
port = args.port or find_free_port()
23+
24+
# 创建 ModelManager 实例
25+
manager = ModelManager()
26+
27+
# 加载已有的模型信息
28+
manager.load_models_info()
29+
30+
# 创建 FastAPI 应用
31+
app = create_api(manager)
32+
33+
# 在新线程中打开浏览器
34+
threading.Thread(target=open_browser, args=(port,), daemon=True).start()
35+
36+
# 启动 FastAPI 服务器
37+
uvicorn.run(app, host="127.0.0.1", port=port)

src/api/model_api.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from fastapi import FastAPI, HTTPException
2+
from fastapi.staticfiles import StaticFiles
3+
from fastapi.responses import FileResponse, StreamingResponse
4+
from fastapi.middleware.cors import CORSMiddleware
5+
from pydantic import BaseModel
6+
import os
7+
from src.utils.file_utils import select_directory
8+
9+
class PathUpdate(BaseModel):
10+
path: str
11+
12+
def create_api(manager):
13+
"""创建并配置FastAPI应用"""
14+
app = FastAPI(title="Stable Diffusion 模型管理器")
15+
16+
# 配置 CORS
17+
app.add_middleware(
18+
CORSMiddleware,
19+
allow_origins=["*"],
20+
allow_credentials=True,
21+
allow_methods=["*"],
22+
allow_headers=["*"],
23+
)
24+
25+
# 挂载静态文件
26+
app.mount("/static", StaticFiles(directory="static"), name="static")
27+
28+
@app.get("/")
29+
async def read_root():
30+
return FileResponse("templates/index.html")
31+
32+
@app.get("/favicon.svg")
33+
async def get_favicon():
34+
return FileResponse("templates/favicon.svg")
35+
36+
@app.get("/api/models")
37+
async def get_models():
38+
"""获取所有模型信息"""
39+
return manager.get_all_models_info()
40+
41+
@app.post("/api/path")
42+
async def update_path(path_update: PathUpdate):
43+
"""更新模型路径"""
44+
if not os.path.exists(path_update.path):
45+
raise HTTPException(status_code=400, detail="路径不存在")
46+
manager.update_models_path(path_update.path)
47+
return {"message": "路径已更新"}
48+
49+
@app.get("/api/scan")
50+
async def scan_models_endpoint():
51+
"""扫描模型"""
52+
if not manager.models_path or not os.path.exists(manager.models_path):
53+
raise HTTPException(status_code=400, detail="请先设置有效的模型目录路径")
54+
try:
55+
return StreamingResponse(
56+
manager.scan_models(),
57+
media_type="text/event-stream"
58+
)
59+
except Exception as e:
60+
raise HTTPException(status_code=500, detail=str(e))
61+
62+
@app.get("/api/config")
63+
async def get_config():
64+
"""获取当前配置"""
65+
return {
66+
"models_path": str(manager.models_path) if manager.models_path else "",
67+
"is_path_valid": os.path.exists(manager.models_path) if manager.models_path else False
68+
}
69+
70+
@app.post("/api/select_directory")
71+
async def select_directory_endpoint():
72+
"""选择目录"""
73+
path = await select_directory()
74+
return {"path": path}
75+
76+
return app
Lines changed: 5 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,13 @@
22
import json
33
import requests
44
from pathlib import Path
5-
import hashlib
6-
import argparse
75
from typing import Dict, Any
8-
from fastapi import FastAPI, HTTPException
9-
from fastapi.staticfiles import StaticFiles
10-
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
11-
from pydantic import BaseModel
12-
import uvicorn
13-
from fastapi.middleware.cors import CORSMiddleware
146
import aiohttp
157
import aiofiles
168
from urllib.parse import urlparse
179
import time
18-
import asyncio
19-
from concurrent.futures import ThreadPoolExecutor
20-
import webbrowser
21-
import threading
22-
import socket
23-
import tkinter as tk
24-
from tkinter import filedialog
2510

26-
class PathUpdate(BaseModel):
27-
path: str
28-
29-
async def select_directory() -> str:
30-
"""使用文件对话框选择目录"""
31-
root = tk.Tk()
32-
root.withdraw() # 隐藏主窗口
33-
root.attributes('-topmost', True) # 确保对话框在最前面
34-
path = filedialog.askdirectory()
35-
root.destroy() # 完全清理 Tk 实例
36-
return path if path else ""
11+
from src.utils.hash_utils import HashUtils
3712

3813
class ModelManager:
3914
def __init__(self, config_file="config.json"):
@@ -44,7 +19,7 @@ def __init__(self, config_file="config.json"):
4419
self.models_info: Dict[str, Any] = {}
4520
self.images_path = Path("static/images") # 添加图片保存路径
4621
self.images_path.mkdir(parents=True, exist_ok=True) # 确保目录存在
47-
self.thread_pool = ThreadPoolExecutor()
22+
self.hash_utils = HashUtils()
4823

4924
def load_config(self) -> dict:
5025
"""加载配置文件"""
@@ -64,23 +39,6 @@ def update_models_path(self, path: str):
6439
self.config["models_path"] = str(self.models_path)
6540
self.save_config()
6641

67-
def calculate_model_hash(self, file_path):
68-
"""计算模型文件的SHA256哈希值"""
69-
sha256_hash = hashlib.sha256()
70-
with open(file_path, "rb") as f:
71-
for byte_block in iter(lambda: f.read(4096), b""):
72-
sha256_hash.update(byte_block)
73-
return sha256_hash.hexdigest()
74-
75-
async def calculate_model_hash_async(self, file_path):
76-
"""异步计算模型文件的哈希值"""
77-
loop = asyncio.get_event_loop()
78-
return await loop.run_in_executor(self.thread_pool, self.calculate_model_hash, file_path)
79-
80-
def _get_file_mtime(self, file_path: Path) -> float:
81-
"""获取文件的修改时间戳"""
82-
return os.path.getmtime(file_path)
83-
8442
async def scan_models(self):
8543
"""扫描指定目录下的所有.safetensors文件"""
8644
print(f"开始扫描目录: {self.models_path}")
@@ -105,7 +63,7 @@ async def scan_models(self):
10563
for file_path in safetensors_files:
10664
print(f"\n处理文件: {file_path.name}")
10765
try:
108-
current_mtime = self._get_file_mtime(file_path)
66+
current_mtime = os.path.getmtime(file_path)
10967
existing_info = self.models_info.get(str(file_path), {})
11068

11169
# 检查文件是否已经扫描过且未修改
@@ -115,7 +73,7 @@ async def scan_models(self):
11573
yield f"data: {json.dumps({'progress': processed / total, 'message': f'跳过: {file_path.name}'})}\n\n"
11674
continue
11775

118-
model_hash = await self.calculate_model_hash_async(file_path)
76+
model_hash = await self.hash_utils.calculate_model_hash_async(file_path)
11977
print(f"计算得到哈希值: {model_hash}")
12078
await self.fetch_model_info(model_hash, file_path, current_mtime)
12179
self.save_models_info() # 每次获取新信息后保存
@@ -259,107 +217,4 @@ def get_all_models_info(self) -> list:
259217
}
260218
for model_path in self.models_info.keys()
261219
if str(model_path).startswith(current_path)
262-
]
263-
264-
# 创建 FastAPI 应用
265-
app = FastAPI(title="Stable Diffusion 模型管理器")
266-
267-
# 配置 CORS
268-
app.add_middleware(
269-
CORSMiddleware,
270-
allow_origins=["*"],
271-
allow_credentials=True,
272-
allow_methods=["*"],
273-
allow_headers=["*"],
274-
)
275-
276-
# 创建 ModelManager 实例
277-
manager = ModelManager()
278-
279-
# 在创建 FastAPI 应用后添加
280-
app.mount("/static", StaticFiles(directory="static"), name="static")
281-
282-
@app.get("/")
283-
async def read_root():
284-
return FileResponse("templates/index.html")
285-
286-
@app.get("/favicon.svg")
287-
async def get_favicon():
288-
return FileResponse("templates/favicon.svg")
289-
290-
@app.get("/api/models")
291-
async def get_models():
292-
"""获取所有模型信息"""
293-
return manager.get_all_models_info()
294-
295-
@app.post("/api/path")
296-
async def update_path(path_update: PathUpdate):
297-
"""更新模型路径"""
298-
if not os.path.exists(path_update.path):
299-
raise HTTPException(status_code=400, detail="路径不存在")
300-
manager.update_models_path(path_update.path)
301-
return {"message": "路径已更新"}
302-
303-
@app.get("/api/scan")
304-
async def scan_models_endpoint():
305-
"""扫描模型"""
306-
if not manager.models_path or not os.path.exists(manager.models_path):
307-
raise HTTPException(status_code=400, detail="请先设置有效的模型目录路径")
308-
try:
309-
return StreamingResponse(
310-
manager.scan_models(),
311-
media_type="text/event-stream"
312-
)
313-
except Exception as e:
314-
raise HTTPException(status_code=500, detail=str(e))
315-
316-
@app.get("/api/config")
317-
async def get_config():
318-
"""获取当前配置"""
319-
return {
320-
"models_path": str(manager.models_path) if manager.models_path else "",
321-
"is_path_valid": os.path.exists(manager.models_path) if manager.models_path else False
322-
}
323-
324-
@app.post("/api/select_directory")
325-
async def select_directory_endpoint():
326-
"""选择目录"""
327-
path = await select_directory()
328-
return {"path": path}
329-
330-
def find_free_port(start_port=8080, max_tries=100):
331-
"""查找可用的端口号"""
332-
for port in range(start_port, start_port + max_tries):
333-
try:
334-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
335-
s.bind(('127.0.0.1', port))
336-
return port
337-
except OSError:
338-
continue
339-
raise RuntimeError('无法找到可用的端口')
340-
341-
def open_browser(port: int):
342-
"""延迟一秒后打开浏览器"""
343-
time.sleep(1)
344-
webbrowser.open(f'http://127.0.0.1:{port}')
345-
346-
if __name__ == "__main__":
347-
parser = argparse.ArgumentParser(description='模型管理器')
348-
parser.add_argument('--port', type=int, default=None, help='Web界面端口')
349-
args = parser.parse_args()
350-
351-
# 获取可用端口
352-
port = args.port or find_free_port()
353-
354-
# 加载已有的模型信息
355-
manager.load_models_info()
356-
357-
# 在新线程中打开浏览器
358-
threading.Thread(target=open_browser, args=(port,), daemon=True).start()
359-
360-
# 启动 FastAPI 服务器
361-
uvicorn.run(
362-
app,
363-
host="127.0.0.1",
364-
port=port
365-
)
220+
]

src/utils/file_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
import socket
3+
import tkinter as tk
4+
from tkinter import filedialog
5+
from pathlib import Path
6+
7+
async def select_directory() -> str:
8+
"""使用文件对话框选择目录"""
9+
root = tk.Tk()
10+
root.withdraw() # 隐藏主窗口
11+
root.attributes('-topmost', True) # 确保对话框在最前面
12+
path = filedialog.askdirectory()
13+
root.destroy() # 完全清理 Tk 实例
14+
return path if path else ""
15+
16+
def find_free_port(start_port=8080, max_tries=100):
17+
"""查找可用的端口号"""
18+
for port in range(start_port, start_port + max_tries):
19+
try:
20+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
21+
s.bind(('127.0.0.1', port))
22+
return port
23+
except OSError:
24+
continue
25+
raise RuntimeError('无法找到可用的端口')
26+
27+
def get_file_mtime(file_path: Path) -> float:
28+
"""获取文件的修改时间戳"""
29+
return os.path.getmtime(file_path)

src/utils/hash_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import hashlib
2+
from pathlib import Path
3+
import asyncio
4+
from concurrent.futures import ThreadPoolExecutor
5+
6+
class HashUtils:
7+
def __init__(self):
8+
self.thread_pool = ThreadPoolExecutor()
9+
10+
def calculate_model_hash(self, file_path: Path) -> str:
11+
"""计算模型文件的SHA256哈希值"""
12+
sha256_hash = hashlib.sha256()
13+
with open(file_path, "rb") as f:
14+
for byte_block in iter(lambda: f.read(4096), b""):
15+
sha256_hash.update(byte_block)
16+
return sha256_hash.hexdigest()
17+
18+
async def calculate_model_hash_async(self, file_path: Path) -> str:
19+
"""异步计算模型文件的哈希值"""
20+
loop = asyncio.get_event_loop()
21+
return await loop.run_in_executor(self.thread_pool, self.calculate_model_hash, file_path)

0 commit comments

Comments
 (0)