Skip to content

Commit dc0b22f

Browse files
author
yinsu.zs
committed
refactor
Change-Id: I7a5940740cd62ea924248a701e18e81826ee99b1
1 parent a3f7962 commit dc0b22f

21 files changed

Lines changed: 2604 additions & 2657 deletions

src/code/agent/main.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,3 @@
1-
# 第一步:全局替换 print 函数(必须在所有其他导入之前)
2-
import builtins
3-
from datetime import datetime
4-
5-
_original_print = builtins.print
6-
7-
def timestamped_print(*args, **kwargs):
8-
"""带时间戳的 print 函数"""
9-
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
10-
message = ' '.join(str(arg) for arg in args)
11-
timestamped_message = f"{timestamp} {message}"
12-
_original_print(timestamped_message, **kwargs)
13-
14-
builtins.print = timestamped_print
15-
16-
# 第二步:初始化日志系统
17-
from utils.logger import init_logging
18-
init_logging()
19-
20-
# 第三步:导入其他模块
211
from routes.routes import Routes
222

233
r = Routes()
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import json
2+
import os
3+
import time
4+
import traceback
5+
6+
from flask import Blueprint, Flask, jsonify, request
7+
from flask_sock import Sock
8+
9+
from services.management_service import ManagementService, BackendStatus
10+
from utils.logger import log
11+
from services.gateway import (
12+
CpuGatewayService,
13+
HistoryGatewayService,
14+
get_task_queue
15+
)
16+
from services.process.websocket.websocket_manager import ws_manager
17+
from services.serverlessapi.serverless_api_service import ServerlessApiService
18+
19+
20+
class CpuRoutes:
21+
"""CPU模式路由:处理任务队列和异步转发"""
22+
23+
def __init__(self):
24+
# HTTP 路由使用 /api 前缀
25+
self.bp = Blueprint("cpu_routes", __name__, url_prefix="/api")
26+
# WebSocket 路由使用根路径(保持 ComfyUI 兼容性)
27+
self.ws_bp = Blueprint("cpu_ws", __name__)
28+
self.service = ManagementService() # 单例模式,直接创建实例
29+
self.sock = Sock()
30+
self.sock.bp = self.ws_bp # 将 WebSocket 绑定到单独的 Blueprint
31+
self.setup_routes()
32+
33+
def register(self, app: Flask):
34+
app.register_blueprint(self.bp)
35+
app.register_blueprint(self.ws_bp)
36+
37+
def setup_routes(self):
38+
"""设置所有路由"""
39+
self._register_websocket()
40+
self._register_queue_handler()
41+
self._register_prompt_handler()
42+
self._register_serverless_run_handler()
43+
self._register_history_handler()
44+
# 通过环境变量控制是否禁用 userdata 保存
45+
# DISABLE_USERDATA_SAVE=true 时禁用
46+
disable_userdata = os.environ.get('DISABLE_USERDATA_SAVE', '').lower() in ('true', '1', 'yes')
47+
if disable_userdata:
48+
self._register_userdata_handler()
49+
50+
def _check_backend_status(self):
51+
"""
52+
检查后端服务状态
53+
54+
Returns:
55+
tuple: (is_valid, error_response)
56+
is_valid为True时error_response为None
57+
is_valid为False时error_response为错误响应
58+
"""
59+
backend_status = self.service.status
60+
if backend_status not in (BackendStatus.RUNNING, BackendStatus.SAVING):
61+
return False, (jsonify({
62+
"status": "failed",
63+
"message": "Please start your comfyui/sd service first"
64+
}), 500)
65+
return True, None
66+
67+
def _register_websocket(self):
68+
@self.sock.route("/ws")
69+
def comfyui_compatible_ws(ws):
70+
"""
71+
CPU函数接收ComfyUI原生的WebSocket连接
72+
保持与ComfyUI前端完全兼容,但推送的是基于任务队列和状态轮询的真实状态
73+
74+
支持重连机制:
75+
- 客户端可通过 ?clientId=xxx 参数传递已有的 client_id
76+
- 重连时会复用相同的 client_id,确保能接收到之前任务的状态更新
77+
"""
78+
try:
79+
# 从查询参数获取 clientId(ComfyUI 前端重连时会传递)
80+
from flask import request as flask_request
81+
client_id = flask_request.args.get('clientId', '')
82+
83+
if client_id:
84+
# 复用已有的 client_id(重连场景)
85+
log("INFO", f"WebSocket reconnecting with existing client_id: {client_id}")
86+
else:
87+
# 生成新的 client_id(首次连接)
88+
client_id = f"cpu_client_{int(time.time() * 1000)}"
89+
log("INFO", f"New ComfyUI WebSocket connection with client_id: {client_id}")
90+
91+
# 添加连接到管理器
92+
ws_manager.add_connection(ws)
93+
94+
# 发送初始状态消息(模拟ComfyUI原生行为)
95+
try:
96+
ws.send(json.dumps({
97+
"type": "status",
98+
"data": {
99+
"sid": client_id,
100+
"status": {
101+
"exec_info": {
102+
"queue_remaining": get_task_queue()._get_pending_task_count()
103+
}
104+
}
105+
}
106+
}))
107+
except Exception as e:
108+
log("ERROR", f"Failed to send initial status: {e}")
109+
return
110+
111+
# 设置客户端ID,用于后续关联任务
112+
setattr(ws, '_comfyui_client_id', client_id)
113+
114+
# 将客户端ID与连接关联在WebSocketManager中
115+
ws_manager.associate_client_id_with_connection(ws, client_id)
116+
117+
# 如果是重连,重新订阅该客户端的所有进行中的任务
118+
ws_manager.resubscribe_client_tasks(ws, client_id)
119+
120+
# TODO 可能是多余的
121+
while True:
122+
try:
123+
message = ws.receive()
124+
log("DEBUG", f"Received message from ComfyUI frontend: {message[:100]}...")
125+
126+
except Exception as e:
127+
error_str = str(e)
128+
if "Connection closed" in error_str or "closed" in error_str.lower():
129+
log("INFO", f"Connection closed by client")
130+
break
131+
log("ERROR", f"Error receiving message: {e}\n{traceback.format_exc()}")
132+
break
133+
134+
except Exception as e:
135+
log("ERROR", f"Connection error: {e}\n{traceback.format_exc()}")
136+
finally:
137+
try:
138+
ws_manager.remove_connection(ws)
139+
log("INFO", f"ComfyUI WebSocket connection closed")
140+
except Exception as e:
141+
log("ERROR", f"Error removing connection: {e}")
142+
143+
def _register_queue_handler(self):
144+
@self.bp.route("/queue", methods=["GET", "POST"])
145+
def handle_queue():
146+
is_valid, error_response = self._check_backend_status()
147+
if not is_valid:
148+
return error_response
149+
150+
try:
151+
gateway_service = CpuGatewayService()
152+
153+
if request.method == "GET":
154+
return gateway_service.handle_queue_get_request()
155+
elif request.method == "POST":
156+
return gateway_service.handle_queue_post_request()
157+
else:
158+
return jsonify({
159+
"error": {
160+
"type": "method_not_allowed",
161+
"message": f"Method {request.method} not allowed"
162+
}
163+
}), 405
164+
165+
except Exception as e:
166+
error_msg = f"Failed to handle queue request: {str(e)}"
167+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
168+
169+
return jsonify({
170+
"error": {
171+
"type": "queue_operation_error",
172+
"message": error_msg
173+
}
174+
}), 500
175+
176+
def _register_prompt_handler(self):
177+
@self.bp.route("/prompt", methods=["POST"])
178+
def handle_prompt():
179+
is_valid, error_response = self._check_backend_status()
180+
if not is_valid:
181+
return error_response
182+
183+
try:
184+
gateway_service = CpuGatewayService()
185+
return gateway_service.handle_prompt_request_async()
186+
except Exception as e:
187+
error_msg = f"Failed to handle prompt request: {str(e)}"
188+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
189+
190+
return jsonify({
191+
"error": {
192+
"type": "prompt_operation_error",
193+
"message": error_msg
194+
}
195+
}), 500
196+
197+
def _register_serverless_run_handler(self):
198+
@self.bp.route("/serverless/run", methods=["POST"])
199+
def handle_serverless_run():
200+
"""
201+
处理 /api/serverless/run 请求,支持同步和异步两种模式
202+
203+
调用方式:
204+
- 默认: 异步调用(与 /api/prompt 处理一致)
205+
- Header X-Art-Invocation-Type: Sync 时: 同步调用,等待GPU返回结果
206+
207+
异步模式:
208+
- 将请求转发到GPU函数(异步调用)
209+
- 返回任务ID,前端通过任务ID轮询获取结果
210+
- 使用任务队列跟踪任务状态
211+
212+
同步模式:
213+
- 将请求转发到GPU函数(同步调用)
214+
- 等待GPU处理完成并返回结果
215+
- 直接返回结果给客户端
216+
"""
217+
is_valid, error_response = self._check_backend_status()
218+
if not is_valid:
219+
return error_response
220+
221+
try:
222+
gateway_service = CpuGatewayService()
223+
224+
# 检查调用类型:Header X-Art-Invocation-Type: Sync 表示同步调用
225+
invocation_type = request.headers.get("X-Art-Invocation-Type", "").strip()
226+
is_sync = invocation_type.lower() == "sync"
227+
228+
if is_sync:
229+
log("DEBUG", f"Processing /serverless/run in SYNC mode (X-Art-Invocation-Type: Sync)")
230+
return gateway_service.handle_serverless_run_sync()
231+
else:
232+
log("DEBUG", f"Processing /serverless/run in ASYNC mode (default)")
233+
return gateway_service.handle_serverless_run_async()
234+
235+
except Exception as e:
236+
error_msg = f"Failed to handle serverless run request: {str(e)}"
237+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
238+
239+
return jsonify({
240+
"error": {
241+
"type": "serverless_run_error",
242+
"message": error_msg
243+
}
244+
}), 500
245+
246+
def _register_history_handler(self):
247+
@self.bp.route("/history", methods=["GET", "POST", "DELETE"])
248+
@self.bp.route("/history/<path:subpath>", methods=["GET", "POST", "DELETE"])
249+
def handle_history(subpath=""):
250+
is_valid, error_response = self._check_backend_status()
251+
if not is_valid:
252+
return error_response
253+
254+
try:
255+
history_gateway = HistoryGatewayService()
256+
path = f"api/history/{subpath}" if subpath else "api/history"
257+
return history_gateway.handle_history_request(path)
258+
except Exception as e:
259+
error_msg = f"Failed to handle history request: {str(e)}"
260+
log("ERROR", f"{error_msg}\nStacktrace:\n{traceback.format_exc()}")
261+
262+
return jsonify({
263+
"error": {
264+
"type": "history_operation_error",
265+
"message": error_msg
266+
}
267+
}), 500
268+
269+
def _register_userdata_handler(self):
270+
"""在 prod 模式下,阻止保存 userdata 文件"""
271+
@self.bp.route("/userdata/<path:file>", methods=["POST"])
272+
def block_userdata_save(file):
273+
log("WARN", f"Attempt to save userdata blocked in prod mode: {file}")
274+
return jsonify({
275+
"error": {
276+
"type": "forbidden",
277+
"message": "Saving workflow is disabled in prod mode"
278+
}
279+
}), 403

0 commit comments

Comments
 (0)