22import json
33import requests
44from pathlib import Path
5- import hashlib
6- import argparse
75from typing import Dict , Any
8- from fastapi import FastAPI , HTTPException
9- from fastapi .staticfiles import StaticFiles
10- from fastapi .responses import JSONResponse , FileResponse , StreamingResponse
11- from pydantic import BaseModel
12- import uvicorn
13- from fastapi .middleware .cors import CORSMiddleware
146import aiohttp
157import aiofiles
168from urllib .parse import urlparse
179import time
18- import asyncio
19- from concurrent .futures import ThreadPoolExecutor
20- import webbrowser
21- import threading
22- import socket
23- import tkinter as tk
24- from tkinter import filedialog
2510
26- class PathUpdate (BaseModel ):
27- path : str
28-
29- async def select_directory () -> str :
30- """使用文件对话框选择目录"""
31- root = tk .Tk ()
32- root .withdraw () # 隐藏主窗口
33- root .attributes ('-topmost' , True ) # 确保对话框在最前面
34- path = filedialog .askdirectory ()
35- root .destroy () # 完全清理 Tk 实例
36- return path if path else ""
11+ from src .utils .hash_utils import HashUtils
3712
3813class ModelManager :
3914 def __init__ (self , config_file = "config.json" ):
@@ -44,7 +19,7 @@ def __init__(self, config_file="config.json"):
4419 self .models_info : Dict [str , Any ] = {}
4520 self .images_path = Path ("static/images" ) # 添加图片保存路径
4621 self .images_path .mkdir (parents = True , exist_ok = True ) # 确保目录存在
47- self .thread_pool = ThreadPoolExecutor ()
22+ self .hash_utils = HashUtils ()
4823
4924 def load_config (self ) -> dict :
5025 """加载配置文件"""
@@ -64,23 +39,6 @@ def update_models_path(self, path: str):
6439 self .config ["models_path" ] = str (self .models_path )
6540 self .save_config ()
6641
67- def calculate_model_hash (self , file_path ):
68- """计算模型文件的SHA256哈希值"""
69- sha256_hash = hashlib .sha256 ()
70- with open (file_path , "rb" ) as f :
71- for byte_block in iter (lambda : f .read (4096 ), b"" ):
72- sha256_hash .update (byte_block )
73- return sha256_hash .hexdigest ()
74-
75- async def calculate_model_hash_async (self , file_path ):
76- """异步计算模型文件的哈希值"""
77- loop = asyncio .get_event_loop ()
78- return await loop .run_in_executor (self .thread_pool , self .calculate_model_hash , file_path )
79-
80- def _get_file_mtime (self , file_path : Path ) -> float :
81- """获取文件的修改时间戳"""
82- return os .path .getmtime (file_path )
83-
8442 async def scan_models (self ):
8543 """扫描指定目录下的所有.safetensors文件"""
8644 print (f"开始扫描目录: { self .models_path } " )
@@ -105,7 +63,7 @@ async def scan_models(self):
10563 for file_path in safetensors_files :
10664 print (f"\n 处理文件: { file_path .name } " )
10765 try :
108- current_mtime = self . _get_file_mtime (file_path )
66+ current_mtime = os . path . getmtime (file_path )
10967 existing_info = self .models_info .get (str (file_path ), {})
11068
11169 # 检查文件是否已经扫描过且未修改
@@ -115,7 +73,7 @@ async def scan_models(self):
11573 yield f"data: { json .dumps ({'progress' : processed / total , 'message' : f'跳过: { file_path .name } ' })} \n \n "
11674 continue
11775
118- model_hash = await self .calculate_model_hash_async (file_path )
76+ model_hash = await self .hash_utils . calculate_model_hash_async (file_path )
11977 print (f"计算得到哈希值: { model_hash } " )
12078 await self .fetch_model_info (model_hash , file_path , current_mtime )
12179 self .save_models_info () # 每次获取新信息后保存
@@ -259,107 +217,4 @@ def get_all_models_info(self) -> list:
259217 }
260218 for model_path in self .models_info .keys ()
261219 if str (model_path ).startswith (current_path )
262- ]
263-
264- # 创建 FastAPI 应用
265- app = FastAPI (title = "Stable Diffusion 模型管理器" )
266-
267- # 配置 CORS
268- app .add_middleware (
269- CORSMiddleware ,
270- allow_origins = ["*" ],
271- allow_credentials = True ,
272- allow_methods = ["*" ],
273- allow_headers = ["*" ],
274- )
275-
276- # 创建 ModelManager 实例
277- manager = ModelManager ()
278-
279- # 在创建 FastAPI 应用后添加
280- app .mount ("/static" , StaticFiles (directory = "static" ), name = "static" )
281-
282- @app .get ("/" )
283- async def read_root ():
284- return FileResponse ("templates/index.html" )
285-
286- @app .get ("/favicon.svg" )
287- async def get_favicon ():
288- return FileResponse ("templates/favicon.svg" )
289-
290- @app .get ("/api/models" )
291- async def get_models ():
292- """获取所有模型信息"""
293- return manager .get_all_models_info ()
294-
295- @app .post ("/api/path" )
296- async def update_path (path_update : PathUpdate ):
297- """更新模型路径"""
298- if not os .path .exists (path_update .path ):
299- raise HTTPException (status_code = 400 , detail = "路径不存在" )
300- manager .update_models_path (path_update .path )
301- return {"message" : "路径已更新" }
302-
303- @app .get ("/api/scan" )
304- async def scan_models_endpoint ():
305- """扫描模型"""
306- if not manager .models_path or not os .path .exists (manager .models_path ):
307- raise HTTPException (status_code = 400 , detail = "请先设置有效的模型目录路径" )
308- try :
309- return StreamingResponse (
310- manager .scan_models (),
311- media_type = "text/event-stream"
312- )
313- except Exception as e :
314- raise HTTPException (status_code = 500 , detail = str (e ))
315-
316- @app .get ("/api/config" )
317- async def get_config ():
318- """获取当前配置"""
319- return {
320- "models_path" : str (manager .models_path ) if manager .models_path else "" ,
321- "is_path_valid" : os .path .exists (manager .models_path ) if manager .models_path else False
322- }
323-
324- @app .post ("/api/select_directory" )
325- async def select_directory_endpoint ():
326- """选择目录"""
327- path = await select_directory ()
328- return {"path" : path }
329-
330- def find_free_port (start_port = 8080 , max_tries = 100 ):
331- """查找可用的端口号"""
332- for port in range (start_port , start_port + max_tries ):
333- try :
334- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
335- s .bind (('127.0.0.1' , port ))
336- return port
337- except OSError :
338- continue
339- raise RuntimeError ('无法找到可用的端口' )
340-
341- def open_browser (port : int ):
342- """延迟一秒后打开浏览器"""
343- time .sleep (1 )
344- webbrowser .open (f'http://127.0.0.1:{ port } ' )
345-
346- if __name__ == "__main__" :
347- parser = argparse .ArgumentParser (description = '模型管理器' )
348- parser .add_argument ('--port' , type = int , default = None , help = 'Web界面端口' )
349- args = parser .parse_args ()
350-
351- # 获取可用端口
352- port = args .port or find_free_port ()
353-
354- # 加载已有的模型信息
355- manager .load_models_info ()
356-
357- # 在新线程中打开浏览器
358- threading .Thread (target = open_browser , args = (port ,), daemon = True ).start ()
359-
360- # 启动 FastAPI 服务器
361- uvicorn .run (
362- app ,
363- host = "127.0.0.1" ,
364- port = port
365- )
220+ ]
0 commit comments