-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathremote_server.py
More file actions
118 lines (90 loc) · 3.75 KB
/
Copy pathremote_server.py
File metadata and controls
118 lines (90 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import random
import threading
import argparse
import uvicorn
from fastapi import FastAPI
from openai import OpenAI
import logging
from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter
# FastAPI application object. Routes are registered on it with decorators and
# main() hands it to uvicorn to serve.
app = FastAPI()

# Attach the Elasticsearch handler to the root logger, so any record that
# propagates up to the root reaches it. The handler is created without
# arguments here; the /init route calls handler.configure(...) when a request
# carries an Elasticsearch config.
handler = ElasticsearchDirectHttpHandler()
logging.getLogger().addHandler(handler)

# Message for a forced rollout error. Stays None unless main() overwrites it
# from the --force-early-error command-line option; the /init worker reads it.
force_early_error_message = None
@app.post("/init")
def init(req: InitRequest):
    """Start a rollout in a background thread and return without waiting for it.

    The worker sends a single chat-completion request built from ``req`` and
    then reports the outcome by logging a record whose ``extra`` carries a
    ``status`` value. Exactly one status-bearing record is logged per rollout,
    whatever happens inside the worker, so that whoever polls for the status
    is never left waiting.

    Args:
        req: Rollout request. Fields read here: ``elastic_search_config``,
            ``metadata.rollout_id``, ``messages``, ``completion_params``
            (``model`` is required, ``model_kwargs`` is optional), ``tools``
            and ``model_base_url``.

    Returns:
        None. Errors raised inside the worker are logged, never propagated to
        the HTTP caller.
    """
    if req.elastic_search_config:
        handler.configure(req.elastic_search_config)

    rollout_id = req.metadata.rollout_id

    # Per-rollout logger with a RolloutIdFilter attached to it.
    # NOTE(review): no log level is set in this file; the INFO records below
    # are only emitted if the effective level allows them — confirm it is
    # configured elsewhere.
    logger = logging.getLogger(f"{__name__}.{rollout_id}")
    logger.addFilter(RolloutIdFilter(rollout_id))

    def _worker():
        """Run one chat completion with the OpenAI client and log the final status."""
        # Becomes True once a record carrying a ``status`` has been logged.
        # The ``finally`` clause keys off this flag rather than off the
        # force-early-error setting, so a failure that happens *before* the
        # forced-error branch still ends with a status record.
        status_logged = False
        try:
            if not req.messages:
                raise ValueError("messages is required")

            model = req.completion_params.get("model")
            if not model:
                raise ValueError("model is required in completion_params")

            completion_kwargs = {
                "model": model,
                "messages": req.messages,
            }

            # Apply model_kwargs if present. They are merged last, so they can
            # override "model" and "messages" as well as add new parameters.
            model_kwargs = req.completion_params.get("model_kwargs")
            if isinstance(model_kwargs, dict):
                completion_kwargs.update(model_kwargs)

            if req.tools:
                completion_kwargs["tools"] = req.tools

            logger.info(f"Final completion_kwargs: {completion_kwargs}")

            client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

            logger.info(f"Sending completion request to model {model}")
            completion = client.chat.completions.create(**completion_kwargs)
            logger.info(f"Completed response: {completion}")

            # If force_early_error is set via command-line arg, report a
            # rollout error instead of a normal finish. This is checked only
            # after the completion call has returned, so the model request is
            # still made.
            if force_early_error_message:
                logger.error(
                    force_early_error_message,
                    extra={"status": Status.rollout_error(force_early_error_message)},
                )
                status_logged = True
                raise RuntimeError(force_early_error_message)
        except Exception as e:
            # Best-effort: never let the worker die silently. The terminal
            # status is emitted in ``finally`` to unblock polling.
            logger.error(f"❌ Error in rollout {rollout_id}: {e}")
        finally:
            # Mark as done even on error, unless the forced rollout_error
            # status above was already logged.
            if not status_logged:
                logger.info(
                    f"Rollout {rollout_id} completed",
                    extra={"status": Status.rollout_finished()},
                )

    # Daemon thread: it does not keep the process alive at shutdown.
    t = threading.Thread(target=_worker, daemon=True)
    t.start()
def main():
    """Parse command-line options and serve ``app`` with uvicorn.

    Options:
        --host: bind address; defaults to the ``REMOTE_SERVER_HOST``
            environment variable, else ``127.0.0.1``.
        --port: bind port; defaults to the ``REMOTE_SERVER_PORT``
            environment variable, else ``3000``.
        --force-early-error: optional message; when given it is stored in the
            module-level ``force_early_error_message`` read by the ``/init``
            worker.

    An invalid port, whether from the command line or from the environment
    variable, is reported as a normal argparse usage error (exit status 2).
    """
    global force_early_error_message

    parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol")
    parser.add_argument(
        "--host",
        type=str,
        default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"),
        help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        # Keep the default as a string: argparse runs string defaults through
        # ``type`` itself, so a malformed REMOTE_SERVER_PORT produces a usage
        # error instead of a ValueError traceback while building the parser,
        # and is ignored entirely when --port is given explicitly.
        default=os.getenv("REMOTE_SERVER_PORT", "3000"),
        help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)",
    )
    parser.add_argument(
        "--force-early-error",
        type=str,
        default=None,
        help="If set, /init will immediately return after logging a rollout_error with this message",
    )
    args = parser.parse_args()

    # Publish the option at module level before the server starts handling
    # requests.
    force_early_error_message = args.force_early_error

    uvicorn.run(app, host=args.host, port=args.port)
# Start the server only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()