Skip to content

Commit f74d1dd

Browse files
committed
to fully async
1 parent ca1cf82 commit f74d1dd

File tree

2 files changed

+78
-75
lines changed

2 files changed

+78
-75
lines changed

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ async def llm_infer(
8383
tool_choice="auto",
8484
)
8585

86+
# this is an important id assignment
87+
response.id = timeline_uuid
8688
assert isinstance(response, ChatCompletion)
8789
return response
8890

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 76 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,79 @@ async def health_check():
4747
return {"status": "ok"}
4848

4949

50+
51+
async def coro_task_1_lookup_dict_received__send_loop(key, websocket: WebSocket, stop_event: asyncio.Event):
52+
# Monitor for new requests
53+
try:
54+
while not stop_event.is_set():
55+
# Check for new requests in ajet_remote_handler_received
56+
if (key in ajet_remote_handler_received) and len(ajet_remote_handler_received[key]) > 0:
57+
58+
timeline_uuid = list(ajet_remote_handler_received[key].keys())[0]
59+
60+
# Get the next request
61+
new_req: TypeCompletionRequest = ajet_remote_handler_received[key].pop(timeline_uuid)
62+
63+
assert timeline_uuid == new_req.timeline_uuid
64+
65+
# Move to in_progress
66+
ajet_remote_handler_in_progress[key][timeline_uuid] = new_req
67+
68+
# will be received by:
69+
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
70+
# await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
71+
await websocket.send_bytes(pickle.dumps(new_req))
72+
73+
except WebSocketDisconnect:
74+
stop_event.set()
75+
return
76+
77+
except Exception as e:
78+
stop_event.set()
79+
print(f"Error in websocket handler: {e}")
80+
return
81+
82+
83+
async def coro_task_2_lookup_dict_received__receive_loop(key, websocket: WebSocket, stop_event: asyncio.Event):
84+
try:
85+
while not stop_event.is_set():
86+
# Wait for client response:
87+
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
88+
# await websocket.send(pickle.dumps(response))
89+
response_data = pickle.loads(await websocket.receive_bytes())
90+
91+
if not isinstance(response_data, ChatCompletion):
92+
stop_event.set()
93+
assert response_data == "terminate", "Invalid terminate signal from client"
94+
await websocket.close()
95+
return
96+
97+
# Process the response
98+
openai_response: ChatCompletion = response_data
99+
100+
# see `ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py::response.id = timeline_uuid`
101+
timeline_uuid = openai_response.id
102+
103+
# Remove from in_progress
104+
if timeline_uuid in ajet_remote_handler_in_progress[key]:
105+
ajet_remote_handler_in_progress[key].pop(timeline_uuid)
106+
107+
# Add to completed if not discarded
108+
if (key not in ajet_remote_handler_discarded) or (timeline_uuid not in ajet_remote_handler_discarded[key]):
109+
# openai_response should already be a ChatCompletion object if client sent pickle
110+
ajet_remote_handler_completed[key][timeline_uuid] = openai_response
111+
112+
except WebSocketDisconnect:
113+
stop_event.set()
114+
return
115+
116+
except Exception as e:
117+
stop_event.set()
118+
print(f"Error in websocket handler: {e}")
119+
return
120+
121+
122+
50123
@app.websocket("/hook/context_tracker_client_listen")
51124
async def context_tracker_client_listen(websocket: WebSocket):
52125
"""
@@ -72,76 +145,9 @@ async def context_tracker_client_listen(websocket: WebSocket):
72145
key = f"episode_uuid:{episode_uuid}"
73146
active_websockets[key] = websocket
74147

75-
# Send acknowledgment (still JSON for compatibility or Pickle?)
76-
# Let's use pickle for consistency on this socket
77-
await websocket.send_bytes(pickle.dumps({"status": "connected", "key": key}))
78-
79-
# Monitor for new requests
80-
while True:
81-
try:
82-
# Check for new requests in ajet_remote_handler_received
83-
if (key in ajet_remote_handler_received) and len(ajet_remote_handler_received[key]) > 0:
84-
85-
timeline_uuid = list(ajet_remote_handler_received[key].keys())[0]
86-
87-
# Get the next request
88-
new_req: TypeCompletionRequest = ajet_remote_handler_received[key].pop(timeline_uuid)
89-
90-
assert timeline_uuid == new_req.timeline_uuid
91-
92-
# Move to in_progress
93-
ajet_remote_handler_in_progress[key][timeline_uuid] = new_req
94-
95-
# Send request to client
96-
episode_uuid = new_req.episode_uuid
97-
98-
# will be received by:
99-
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
100-
# await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
101-
await websocket.send_bytes(pickle.dumps(new_req))
102-
103-
# Wait for client response:
104-
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
105-
# await websocket.send(pickle.dumps(response))
106-
response_data = pickle.loads(await websocket.receive_bytes())
107-
108-
if not isinstance(response_data, ChatCompletion):
109-
assert response_data == "terminate", "Invalid terminate signal from client"
110-
break
111-
112-
# Process the response
113-
openai_response: ChatCompletion = response_data
114-
115-
# Remove from in_progress
116-
if timeline_uuid in ajet_remote_handler_in_progress[key]:
117-
ajet_remote_handler_in_progress[key].pop(timeline_uuid)
118-
119-
# Add to completed if not discarded
120-
if (key not in ajet_remote_handler_discarded) or (timeline_uuid not in ajet_remote_handler_discarded[key]):
121-
# openai_response should already be a ChatCompletion object if client sent pickle
122-
ajet_remote_handler_completed[key][timeline_uuid] = openai_response
123-
124-
else:
125-
# nothing to do yet, sleep a bit
126-
await asyncio.sleep(0.25)
127-
128-
# try:
129-
# # let's see if the client is still there
130-
# response_data = pickle.loads(await websocket.receive_bytes())
131-
132-
# # Check if it's a terminate signal
133-
# if not isinstance(response_data, ChatCompletion):
134-
# assert response_data == "terminate", "Invalid terminate signal from client"
135-
# break
136-
137-
# except asyncio.TimeoutError:
138-
# pass # No message, continue monitoring
139-
140-
except WebSocketDisconnect:
141-
break
142-
except Exception as e:
143-
print(f"Error in websocket handler: {e}")
144-
break
148+
stop_event = asyncio.Event()
149+
asyncio.create_task(coro_task_1_lookup_dict_received__send_loop(key, websocket, stop_event))
150+
asyncio.create_task(coro_task_2_lookup_dict_received__receive_loop(key, websocket, stop_event))
145151

146152
finally:
147153

@@ -289,11 +295,6 @@ def run_server():
289295
# Start server in a new thread
290296
self.server_thread = threading.Thread(target=run_server, daemon=True)
291297
self.server_thread.start()
292-
293-
# Give the server a moment to start
294-
import time
295-
time.sleep(1)
296-
297298
return self.port
298299

299300
def stop(self):

0 commit comments

Comments
 (0)