-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
137 lines (101 loc) · 4.38 KB
/
app.py
File metadata and controls
137 lines (101 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import asyncio
import os
from typing import Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query
from pydantic import BaseModel
app = FastAPI(title="RunPod Load Balancing Worker")
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class GenerationRequest(BaseModel):
    """Request payload for POST /generate."""
    # Text prompt to generate from (required; no default).
    prompt: str
    # Maximum number of tokens to generate.
    max_tokens: int = 100
    # Sampling temperature.
    temperature: float = 0.7
class GenerationResponse(BaseModel):
    """Response payload for POST /generate."""
    # The (currently mocked) generated text.
    generated_text: str
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
# Total /generate requests served by this worker process. In-memory only:
# resets on restart and is not shared across worker replicas.
request_count = 0
# WebSocket connections currently open; used for the /stats gauge and
# maintained by the /ws handler (append on accept, remove on teardown).
active_ws_connections: list[WebSocket] = []
# ---------------------------------------------------------------------------
# HTTP endpoints
# ---------------------------------------------------------------------------
@app.get("/ping")
async def health_check():
    """Health probe required by RunPod to monitor worker health.

    RunPod's contract: return 204 while initialising and 200 once ready;
    it measures cold-start time as the gap between /ping first returning
    204 and first returning 200.

    This worker has no model-loading phase, so it reports ready (200)
    unconditionally.  TODO(review): when real model loading is added,
    return a 204 response until the model is loaded, per the contract
    described above.
    """
    return {"status": "healthy"}
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    """Mock text-generation endpoint.

    Bumps the per-process request counter and echoes the request parameters
    back as the "generated" text.
    """
    global request_count
    request_count += 1
    # TODO: replace this echo with actual model inference.
    pieces = [
        f"Response to: '{request.prompt}' ",
        f"(tokens={request.max_tokens}, temp={request.temperature}, ",
        f"request #{request_count})",
    ]
    return {"generated_text": "".join(pieces)}
@app.get("/stats")
async def stats():
    """Expose this worker's lightweight in-memory counters."""
    payload = {
        "total_requests": request_count,
        "active_websocket_connections": len(active_ws_connections),
    }
    return payload
# ---------------------------------------------------------------------------
# WebSocket endpoint
# ---------------------------------------------------------------------------
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: Optional[str] = Query(default=None, description="Ephemeral session token passed in the URL"),
):
    """Streaming WebSocket endpoint.

    Clients send JSON messages of the form:

        {"prompt": "...", "max_tokens": 100, "temperature": 0.7}

    The server streams token-by-token responses back (simulated here) and
    sends a final {"done": true} frame when generation is complete.
    Malformed messages produce an {"error": ...} frame and the connection
    stays open, rather than tearing the session down.

    RunPod load-balancing endpoints support WebSocket connections on the
    same PORT as the HTTP server, so no extra configuration is needed.
    """
    # Validate ephemeral token when present.
    # Your reverse proxy (Traefik) strips/validates the token before it reaches
    # this worker, so this check is a defence-in-depth fallback.
    expected_token = os.getenv("WS_TOKEN")
    if expected_token and token != expected_token:
        await websocket.close(code=1008, reason="Unauthorized")
        return
    await websocket.accept()
    active_ws_connections.append(websocket)
    try:
        while True:
            # A non-JSON frame used to raise out of the loop and kill the
            # whole connection; report it to the client instead.
            try:
                data = await websocket.receive_json()
            except ValueError:
                await websocket.send_json({"error": "message must be valid JSON"})
                continue
            # A JSON array/scalar would raise AttributeError on .get below.
            if not isinstance(data, dict):
                await websocket.send_json({"error": "message must be a JSON object"})
                continue
            prompt: str = data.get("prompt", "")
            try:
                max_tokens: int = int(data.get("max_tokens", 50))
                temperature: float = float(data.get("temperature", 0.7))
            except (TypeError, ValueError):
                await websocket.send_json(
                    {"error": "max_tokens and temperature must be numeric"}
                )
                continue
            # Negative max_tokens would slice from the end and report a
            # negative total below; treat it as "no tokens".
            max_tokens = max(max_tokens, 0)
            if not prompt:
                await websocket.send_json({"error": "prompt is required"})
                continue
            # Simulate streaming generation token by token.
            # NOTE: temperature is accepted but unused until a real model
            # replaces this mock.
            words = f"Streaming response for: '{prompt}'".split()
            for i, word in enumerate(words[:max_tokens]):
                await websocket.send_json({"token": word, "index": i})
                await asyncio.sleep(0.05)  # simulate inference latency
            await websocket.send_json({"done": True, "total_tokens": min(len(words), max_tokens)})
    except WebSocketDisconnect:
        pass
    finally:
        active_ws_connections.remove(websocket)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # PORT comes from the environment (RunPod sets it); default to 80.
    listen_port = int(os.getenv("PORT", 80))
    print(f"Starting server on port {listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)