1010from openai .types .chat .chat_completion import ChatCompletion
1111
1212import pickle
13- import websockets
13+ import httpx
14+ import logging
15+ logging .getLogger ("httpx" ).setLevel (logging .WARNING )
16+
1417import base64
1518import json
1619
@@ -96,7 +99,7 @@ def should_terminate(self) -> bool:
9699
97100 def begin_service (self ):
98101 """
99- Starts the websocket service loop.
102+ Starts the SSE service loop.
100103 """
101104 t = threading .Thread (target = lambda : asyncio .run (self ._ensure_service_loop ()), daemon = True )
102105 t .start ()
@@ -112,7 +115,7 @@ async def _ensure_service_loop(self):
112115 async def _service_loop (self ):
113116 """
114117 In fact this is not a service,
115- it is a client that pretends to be a service, by interacting with a local websocket interchange server.
118+ it is a client that pretends to be a service, by interacting with a local interchange server via SSE.
116119
117120 This design is for efficiency.
118121 """
@@ -121,43 +124,66 @@ async def _service_loop(self):
121124
122125 port = os .getenv ("AJET_DAT_INTERCHANGE_PORT" )
123126 assert port is not None , "AJET_DAT_INTERCHANGE_PORT env var must be set"
124- uri = f"ws://127.0.0.1:{ port } /hook/context_tracker_client_listen"
125127
126- async with websockets .connect (uri , ping_timeout = 3600 , open_timeout = 16 ) as websocket :
128+ base_url = f"http://127.0.0.1:{ port } "
129+ listen_url = f"{ base_url } /hook/context_tracker_client_listen"
130+ response_url = f"{ base_url } /hook/context_tracker_client_response"
131+ key = f"episode_uuid:{ self .episode_uuid } "
132+
133+ async with httpx .AsyncClient (timeout = None ) as client :
134+ try :
135+ async with client .stream ("GET" , listen_url , params = {"episode_uuid" : self .episode_uuid }, timeout = None ) as response :
136+ async for line in response .aiter_lines ():
137+ if self .should_terminate :
138+ break
139+
140+ if not line .strip ():
141+ continue
142+
143+ if line .startswith (":" ): # keepalive
144+ continue
145+
146+ if line .startswith ("data: " ):
147+ data = line [6 :].strip ()
148+ if not data :
149+ continue
150+
151+ try :
152+ parsed_msg = InterchangeCompletionRequest .model_validate_json (data )
153+
154+ result = await self .llm_infer (
155+ req = parsed_msg .completion_request ,
156+ timeline_uuid = parsed_msg .timeline_uuid ,
157+ agent_name = parsed_msg .agent_name ,
158+ target_tag = parsed_msg .target_tag ,
159+ episode_uuid = parsed_msg .episode_uuid ,
160+ )
161+
162+ # Send response back
163+ await client .post (
164+ response_url ,
165+ params = {"key" : key },
166+ content = pickle .dumps (result ),
167+ headers = {"Content-Type" : "application/octet-stream" }
168+ )
169+
170+ except Exception as e :
171+ logger .error (f"Error processing SSE event: { e } " )
172+ continue
173+
174+ except httpx .RequestError as e :
175+ logger .warning (f"SSE connection error: { e } " )
176+ raise # Let ensure_service_loop handle restart
177+
178+ # Send terminate signal if we are exiting cleanly
127179 try :
128- # Send initialization parameters
129- # Sending as a list [agent_name, target_tag, episode_uuid] to match "input (a,b,c)" structure
130- await websocket .send (pickle .dumps (f"episode_uuid:{ self .episode_uuid } " ))
131-
132- while not self .should_terminate :
133-
134- try :
135- # wait message from ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py
136- parsed_msg_str : str = pickle .loads (
137- await asyncio .wait_for (websocket .recv (decode = False ), timeout = 0.25 )
138- )
139- parsed_msg :InterchangeCompletionRequest = InterchangeCompletionRequest (** json .loads (parsed_msg_str ))
140-
141- response = await self .llm_infer (
142- req = parsed_msg .completion_request ,
143- timeline_uuid = parsed_msg .timeline_uuid ,
144- agent_name = parsed_msg .agent_name ,
145- target_tag = parsed_msg .target_tag ,
146- episode_uuid = parsed_msg .episode_uuid ,
147- )
148- await websocket .send (pickle .dumps (response ))
149-
150- except asyncio .TimeoutError :
151- # 0.25s timeout, loop back to check should_terminate
152- continue
153- except websockets .exceptions .ConnectionClosed :
154- logger .warning ("Websocket connection closed by server" )
155- return # Exit inner loop to reconnect or finish
156-
157- await websocket .send (pickle .dumps ("terminate" ))
158-
159- except (OSError , IOError ) as e :
160- logger .warning (f"Websocket connection error: { e } " )
180+ await client .post (
181+ response_url ,
182+ params = {"key" : key },
183+ content = pickle .dumps ("terminate" ),
184+ headers = {"Content-Type" : "application/octet-stream" }
185+ )
186+ except :
161187 pass
162188
163189
0 commit comments