forked from kning/modal-comfy-worker
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathworkflow.py
More file actions
162 lines (136 loc) · 4.46 KB
/
workflow.py
File metadata and controls
162 lines (136 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from modal import (
Secret,
enter,
App,
Volume,
method,
exception,
functions,
asgi_app,
web_server,
)
from comfy.server import ComfyServer, ComfyConfig
from comfy.models import ExecutionCallbacks, ExecutionData
from lib.image import get_comfy_image
from lib.logger import logger
from lib.utils import get_time_ms
from prompt_constructor import WorkflowInput, construct_workflow_prompt
import os
from fastapi import FastAPI, HTTPException
# Names used for the Modal app and for the volume that holds the models.
APP_NAME = "comfy-worker"
VOLUME_NAME = f"{APP_NAME}-volume"

# snapshot.json and prompt.json are looked up next to this module and handed
# to the image builder.
_module_dir = os.path.dirname(__file__)
local_snapshot_path = os.path.join(_module_dir, "snapshot.json")
local_prompt_path = os.path.join(_module_dir, "prompt.json")

github_secret = Secret.from_name("github-secret")

# Container image shared by every function and class declared in this file.
image = get_comfy_image(
    local_snapshot_path=local_snapshot_path,
    local_prompt_path=local_prompt_path,
    github_secret=github_secret,
)

app = App(APP_NAME)
volume = Volume.from_name(VOLUME_NAME, create_if_missing=True)
@app.cls(
    image=image,
    # Add in your secrets
    secrets=[],
    # Add in your volumes
    volumes={"/root/ComfyUI/models": volume},
    gpu="l4",
    # allow_concurrent_inputs=3,
    # concurrency_limit=10,
    # timeout=38,
    container_idle_timeout=60,
    # keep_warm=1,
    retries=1,
)
class ComfyWorkflow:
    """Modal class that executes ComfyUI workflow prompts.

    A ``ComfyServer`` is started in the ``@enter`` hook and stored on the
    instance; ``infer`` then sends prompts to that server.
    """

    @enter()
    def run_this_on_container_startup(self):
        """Start the ComfyUI server and wait until it is ready."""
        # NOTE(review): this FastAPI instance is not used anywhere in the
        # visible code; kept in case something elsewhere reads `self.web_app`.
        self.web_app = FastAPI()
        self.server = ComfyServer()
        self.server.start()
        self.server.wait_until_ready()

    @method()
    async def infer(self, payload: WorkflowInput):
        """Build a prompt from ``payload`` and execute it on the server.

        Args:
            payload: Workflow input passed to ``construct_workflow_prompt``.

        Returns:
            Whatever ``self.server.execute`` returns for the prompt.

        Raises:
            Exception: Any error raised while setting up callbacks or
                executing the prompt is logged and re-raised unchanged.
        """
        job_start_time = get_time_ms()
        prompt = construct_workflow_prompt(payload)
        try:
            # Callbacks for execution monitoring. Each lambda wraps its calls
            # in a tuple so that several side effects fit in one expression.
            # NOTE(review): the `type` parameter shadows the builtin; it is
            # left as-is because the caller's calling convention (positional
            # vs keyword) is not visible from this file.
            callbacks = ExecutionCallbacks(
                on_error=lambda error_data: (logger.error(error_data),),
                on_done=lambda msg: (
                    logger.info("Job Completed. Sending Completion Event."),
                ),
                on_ws_message=lambda type, msg: (
                    logger.info(f"Received message: {type} - {msg}"),
                    print(f"Received message: {type} - {msg}"),
                ),
                on_start=lambda msg: (
                    logger.info(
                        f"Execution start took: {get_time_ms() - job_start_time} ms"
                    ),
                ),
            )
            # NOTE(review): process_id is hard-coded, so every job is submitted
            # with the same id - confirm whether it should be unique per call.
            execution_result = await self.server.execute(
                data=ExecutionData(prompt=prompt, process_id="123"),
                callbacks=callbacks,
            )
            print("execution result", execution_result)
            return execution_result
        except Exception as e:
            logger.error(f"Error in execution: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise
# FastAPI application that the HTTP routes below are registered on; it is
# returned by the `asgi_app` function further down in this module.
web_app = FastAPI()
@web_app.post("/infer_sync")
async def infer(payload: WorkflowInput):
    """Run a workflow and return its result in the same HTTP response.

    Args:
        payload: Workflow input forwarded to ``ComfyWorkflow.infer``.

    Returns:
        The value returned by the remote ``infer`` call.

    Raises:
        HTTPException: Status 500 carrying the error text when the remote
            call fails.
    """
    try:
        # Await the async variant of the remote call: the blocking form would
        # hold this server's event loop for the whole duration of the job.
        return await ComfyWorkflow().infer.remote.aio(payload)
    except Exception as e:
        print("Error in infer", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@web_app.post("/infer_async")
async def infer_async(payload: WorkflowInput):
    """Submit a workflow without waiting for it and return its call id.

    Args:
        payload: Workflow input forwarded to ``ComfyWorkflow.infer``.

    Returns:
        ``{"call_id": <id>}`` where the id is the spawned call's
        ``object_id``; the ``/status`` and ``/cancel`` routes take it.

    Raises:
        HTTPException: Status 500 carrying the error text when the call
            cannot be spawned.
    """
    try:
        # Use the awaitable variant so submitting the job does not block the
        # event loop.
        call = await ComfyWorkflow().infer.spawn.aio(payload)
        return {"call_id": call.object_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@web_app.get("/status/{call_id}")
async def status(call_id: str):
    """Report the outcome of a previously spawned call.

    Args:
        call_id: Id of the function call to look up.

    Returns:
        ``{"result": value}``. ``value`` is the call's output once it is
        available; otherwise it is ``{"result": None, "status": "expired"}``
        or ``{"result": None, "status": "pending"}``.
    """
    function_call = functions.FunctionCall.from_id(call_id)
    try:
        # Wait up to 5 seconds for the output. Awaiting the async variant
        # keeps other requests flowing while this one is polling.
        result = await function_call.get.aio(timeout=5)
    except exception.OutputExpiredError:
        result = {"result": None, "status": "expired"}
    except TimeoutError:
        # No output within the timeout: the call has not finished yet.
        result = {"result": None, "status": "pending"}
    return {"result": result}
@web_app.post("/cancel/{call_id}")
async def cancel(call_id: str):
    """Cancel a previously spawned call.

    Args:
        call_id: Id of the function call to cancel.

    Returns:
        ``{"call_id": call_id}`` echoing the id that was cancelled.
    """
    function_call = functions.FunctionCall.from_id(call_id)
    # Awaitable variant, so the request does not block the event loop.
    await function_call.cancel.aio()
    return {"call_id": call_id}
@app.function(image=image)
@asgi_app()
def asgi_app():
    """Return the module-level FastAPI app (`web_app`)."""
    # NOTE(review): this def rebinds the name `asgi_app`, shadowing the
    # decorator imported from modal for any code placed after it. Renaming the
    # function presumably changes the deployed function name - confirm before
    # renaming.
    return web_app
@app.function(
    allow_concurrent_inputs=10,
    concurrency_limit=1,
    image=image,
    volumes={"/root/ComfyUI/models": volume},
    container_idle_timeout=30,
    timeout=1800,
    gpu="l4",
)
@web_server(8188, startup_timeout=120)
def ui():
    """Start a ComfyUI server bound to 0.0.0.0:8188.

    The port passed to ``ComfyConfig`` is the same one given to the
    ``web_server`` decorator.
    """
    logger.info("Starting UI")
    ui_config = ComfyConfig(SERVER_HOST="0.0.0.0", SERVER_PORT=8188)
    ComfyServer(config=ui_config).start()