@@ -47,6 +47,79 @@ async def health_check():
4747 return {"status" : "ok" }
4848
4949
50+
51+ async def coro_task_1_lookup_dict_received__send_loop (key , websocket : WebSocket , stop_event : asyncio .Event ):
52+ # Monitor for new requests
53+ try :
54+ while not stop_event .is_set ():
55+ # Check for new requests in ajet_remote_handler_received
56+ if (key in ajet_remote_handler_received ) and len (ajet_remote_handler_received [key ]) > 0 :
57+
58+ timeline_uuid = list (ajet_remote_handler_received [key ].keys ())[0 ]
59+
60+ # Get the next request
61+ new_req : TypeCompletionRequest = ajet_remote_handler_received [key ].pop (timeline_uuid )
62+
63+ assert timeline_uuid == new_req .timeline_uuid
64+
65+ # Move to in_progress
66+ ajet_remote_handler_in_progress [key ][timeline_uuid ] = new_req
67+
68+ # will be received by:
69+ # ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
70+ # await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
71+ await websocket .send_bytes (pickle .dumps (new_req ))
72+
73+ except WebSocketDisconnect :
74+ stop_event .set ()
75+ return
76+
77+ except Exception as e :
78+ stop_event .set ()
79+ print (f"Error in websocket handler: { e } " )
80+ return
81+
82+
83+ async def coro_task_2_lookup_dict_received__receive_loop (key , websocket : WebSocket , stop_event : asyncio .Event ):
84+ try :
85+ while not stop_event .is_set ():
86+ # Wait for client response:
87+ # ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
88+ # await websocket.send(pickle.dumps(response))
89+ response_data = pickle .loads (await websocket .receive_bytes ())
90+
91+ if not isinstance (response_data , ChatCompletion ):
92+ stop_event .set ()
93+ assert response_data == "terminate" , "Invalid terminate signal from client"
94+ await websocket .close ()
95+ return
96+
97+ # Process the response
98+ openai_response : ChatCompletion = response_data
99+
100+ # see `ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py::response.id = timeline_uuid`
101+ timeline_uuid = openai_response .id
102+
103+ # Remove from in_progress
104+ if timeline_uuid in ajet_remote_handler_in_progress [key ]:
105+ ajet_remote_handler_in_progress [key ].pop (timeline_uuid )
106+
107+ # Add to completed if not discarded
108+ if (key not in ajet_remote_handler_discarded ) or (timeline_uuid not in ajet_remote_handler_discarded [key ]):
109+ # openai_response should already be a ChatCompletion object if client sent pickle
110+ ajet_remote_handler_completed [key ][timeline_uuid ] = openai_response
111+
112+ except WebSocketDisconnect :
113+ stop_event .set ()
114+ return
115+
116+ except Exception as e :
117+ stop_event .set ()
118+ print (f"Error in websocket handler: { e } " )
119+ return
120+
121+
122+
50123@app .websocket ("/hook/context_tracker_client_listen" )
51124async def context_tracker_client_listen (websocket : WebSocket ):
52125 """
@@ -72,76 +145,9 @@ async def context_tracker_client_listen(websocket: WebSocket):
72145 key = f"episode_uuid:{ episode_uuid } "
73146 active_websockets [key ] = websocket
74147
75- # Send acknowledgment (still JSON for compatibility or Pickle?)
76- # Let's use pickle for consistency on this socket
77- await websocket .send_bytes (pickle .dumps ({"status" : "connected" , "key" : key }))
78-
79- # Monitor for new requests
80- while True :
81- try :
82- # Check for new requests in ajet_remote_handler_received
83- if (key in ajet_remote_handler_received ) and len (ajet_remote_handler_received [key ]) > 0 :
84-
85- timeline_uuid = list (ajet_remote_handler_received [key ].keys ())[0 ]
86-
87- # Get the next request
88- new_req : TypeCompletionRequest = ajet_remote_handler_received [key ].pop (timeline_uuid )
89-
90- assert timeline_uuid == new_req .timeline_uuid
91-
92- # Move to in_progress
93- ajet_remote_handler_in_progress [key ][timeline_uuid ] = new_req
94-
95- # Send request to client
96- episode_uuid = new_req .episode_uuid
97-
98- # will be received by:
99- # ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
100- # await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
101- await websocket .send_bytes (pickle .dumps (new_req ))
102-
103- # Wait for client response:
104- # ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
105- # await websocket.send(pickle.dumps(response))
106- response_data = pickle .loads (await websocket .receive_bytes ())
107-
108- if not isinstance (response_data , ChatCompletion ):
109- assert response_data == "terminate" , "Invalid terminate signal from client"
110- break
111-
112- # Process the response
113- openai_response : ChatCompletion = response_data
114-
115- # Remove from in_progress
116- if timeline_uuid in ajet_remote_handler_in_progress [key ]:
117- ajet_remote_handler_in_progress [key ].pop (timeline_uuid )
118-
119- # Add to completed if not discarded
120- if (key not in ajet_remote_handler_discarded ) or (timeline_uuid not in ajet_remote_handler_discarded [key ]):
121- # openai_response should already be a ChatCompletion object if client sent pickle
122- ajet_remote_handler_completed [key ][timeline_uuid ] = openai_response
123-
124- else :
125- # nothing to do yet, sleep a bit
126- await asyncio .sleep (0.25 )
127-
128- # try:
129- # # let's see if the client is still there
130- # response_data = pickle.loads(await websocket.receive_bytes())
131-
132- # # Check if it's a terminate signal
133- # if not isinstance(response_data, ChatCompletion):
134- # assert response_data == "terminate", "Invalid terminate signal from client"
135- # break
136-
137- # except asyncio.TimeoutError:
138- # pass # No message, continue monitoring
139-
140- except WebSocketDisconnect :
141- break
142- except Exception as e :
143- print (f"Error in websocket handler: { e } " )
144- break
148+ stop_event = asyncio .Event ()
149+ asyncio .create_task (coro_task_1_lookup_dict_received__send_loop (key , websocket , stop_event ))
150+ asyncio .create_task (coro_task_2_lookup_dict_received__receive_loop (key , websocket , stop_event ))
145151
146152 finally :
147153
@@ -289,11 +295,6 @@ def run_server():
289295 # Start server in a new thread
290296 self .server_thread = threading .Thread (target = run_server , daemon = True )
291297 self .server_thread .start ()
292-
293- # Give the server a moment to start
294- import time
295- time .sleep (1 )
296-
297298 return self .port
298299
299300 def stop (self ):
0 commit comments