Skip to content

Commit 9fba5d5

Browse files
committed
Switch from WebSocket (ws) to Server-Sent Events (SSE) for stability
1 parent 01bec90 commit 9fba5d5

File tree

2 files changed

+176
-153
lines changed

2 files changed

+176
-153
lines changed

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from openai.types.chat.chat_completion import ChatCompletion
1111

1212
import pickle
13-
import websockets
13+
import httpx
14+
import logging
15+
logging.getLogger("httpx").setLevel(logging.WARNING)
16+
1417
import base64
1518
import json
1619

@@ -96,7 +99,7 @@ def should_terminate(self) -> bool:
9699

97100
def begin_service(self):
98101
"""
99-
Starts the websocket service loop.
102+
Starts the SSE service loop in a background daemon thread.
100103
"""
101104
t = threading.Thread(target=lambda: asyncio.run(self._ensure_service_loop()), daemon=True)
102105
t.start()
@@ -112,7 +115,7 @@ async def _ensure_service_loop(self):
112115
async def _service_loop(self):
113116
"""
114117
In fact this is not a service,
115-
it is a client that pretends to be a service, by interacting with a local websocket interchange server.
118+
it is a client that pretends to be a service by interacting with a local interchange server via SSE (Server-Sent Events).
116119
117120
This design is for efficiency
118121
"""
@@ -121,43 +124,66 @@ async def _service_loop(self):
121124

122125
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
123126
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
124-
uri = f"ws://127.0.0.1:{port}/hook/context_tracker_client_listen"
125127

126-
async with websockets.connect(uri, ping_timeout=3600, open_timeout=16) as websocket:
128+
base_url = f"http://127.0.0.1:{port}"
129+
listen_url = f"{base_url}/hook/context_tracker_client_listen"
130+
response_url = f"{base_url}/hook/context_tracker_client_response"
131+
key = f"episode_uuid:{self.episode_uuid}"
132+
133+
async with httpx.AsyncClient(timeout=None) as client:
134+
try:
135+
async with client.stream("GET", listen_url, params={"episode_uuid": self.episode_uuid}, timeout=None) as response:
136+
async for line in response.aiter_lines():
137+
if self.should_terminate:
138+
break
139+
140+
if not line.strip():
141+
continue
142+
143+
if line.startswith(":"): # keepalive
144+
continue
145+
146+
if line.startswith("data: "):
147+
data = line[6:].strip()
148+
if not data:
149+
continue
150+
151+
try:
152+
parsed_msg = InterchangeCompletionRequest.model_validate_json(data)
153+
154+
result = await self.llm_infer(
155+
req=parsed_msg.completion_request,
156+
timeline_uuid=parsed_msg.timeline_uuid,
157+
agent_name=parsed_msg.agent_name,
158+
target_tag=parsed_msg.target_tag,
159+
episode_uuid=parsed_msg.episode_uuid,
160+
)
161+
162+
# Send response back
163+
await client.post(
164+
response_url,
165+
params={"key": key},
166+
content=pickle.dumps(result),
167+
headers={"Content-Type": "application/octet-stream"}
168+
)
169+
170+
except Exception as e:
171+
logger.error(f"Error processing SSE event: {e}")
172+
continue
173+
174+
except httpx.RequestError as e:
175+
logger.warning(f"SSE connection error: {e}")
176+
raise # Let _ensure_service_loop handle the restart
177+
178+
# Send terminate signal if we are exiting cleanly
127179
try:
128-
# Send initialization parameters
129-
# Sending as a list [agent_name, target_tag, episode_uuid] to match "input (a,b,c)" structure
130-
await websocket.send(pickle.dumps(f"episode_uuid:{self.episode_uuid}"))
131-
132-
while not self.should_terminate:
133-
134-
try:
135-
# wait message from ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py
136-
parsed_msg_str: str = pickle.loads(
137-
await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
138-
)
139-
parsed_msg:InterchangeCompletionRequest = InterchangeCompletionRequest(**json.loads(parsed_msg_str))
140-
141-
response = await self.llm_infer(
142-
req=parsed_msg.completion_request,
143-
timeline_uuid=parsed_msg.timeline_uuid,
144-
agent_name=parsed_msg.agent_name,
145-
target_tag=parsed_msg.target_tag,
146-
episode_uuid=parsed_msg.episode_uuid,
147-
)
148-
await websocket.send(pickle.dumps(response))
149-
150-
except asyncio.TimeoutError:
151-
# 0.25s timeout, loop back to check should_terminate
152-
continue
153-
except websockets.exceptions.ConnectionClosed:
154-
logger.warning("Websocket connection closed by server")
155-
return # Exit inner loop to reconnect or finish
156-
157-
await websocket.send(pickle.dumps("terminate"))
158-
159-
except (OSError, IOError) as e:
160-
logger.warning(f"Websocket connection error: {e}")
180+
await client.post(
181+
response_url,
182+
params={"key": key},
183+
content=pickle.dumps("terminate"),
184+
headers={"Content-Type": "application/octet-stream"}
185+
)
186+
except:
161187
pass
162188

163189

0 commit comments

Comments
 (0)