Skip to content

Commit 855cb2f

Browse files
committed
redis
1 parent c8efc5e commit 855cb2f

File tree

5 files changed

+264
-431
lines changed

5 files changed

+264
-431
lines changed

ajet/task_rollout/native_parallel_worker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,11 @@ def rollout_dynamic( # noqa: C901
450450

451451
logger.info(print_buffer)
452452

453-
for tracker in tracker_array:
454-
# average of gourp success rate
455-
tracker.current_batch_success_rate = np.mean(task_success_rate)
456-
# average of gourp average reward
457-
tracker.current_batch_reward = np.mean(task_group_reward)
453+
# for tracker in tracker_array:
454+
# # average of group success rate
455+
# tracker.current_batch_success_rate = np.mean(task_success_rate)
456+
# # average of group average reward
457+
# tracker.current_batch_reward = np.mean(task_group_reward)
458458

459459
return tracker_array
460460

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 121 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
import json
44
import threading
55
import os
6+
import redis
67
import time
78
from loguru import logger
89
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
910
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse
1011
from openai.types.chat.chat_completion import ChatCompletion
12+
from redis.exceptions import TimeoutError
13+
14+
from functools import cache
1115

1216
import pickle
1317
import httpx
18+
import zmq
1419
import logging
1520
logging.getLogger("httpx").setLevel(logging.WARNING)
1621

@@ -51,6 +56,24 @@ def generate_auth_token(agent_name, target_tag, episode_uuid):
5156
return auth_token
5257

5358

@cache
def get_redis_connection_pool():
    """Return the process-wide Redis connection pool (lazily built once).

    ``functools.cache`` makes this a singleton, so every client in the process
    shares one ``BlockingConnectionPool``.  A *blocking* pool makes callers wait
    for a free connection when all 256 are checked out, instead of raising.

    NOTE(review): host/port are hard-coded to localhost:6379 — confirm Redis is
    always node-local in this deployment.
    """
    return redis.BlockingConnectionPool(
        host='localhost',
        port=6379,
        max_connections=256,
        socket_timeout=30,
        socket_connect_timeout=30,
        retry_on_timeout=True,
    )
70+
71+
def get_redis_client():
    """Build a Redis client backed by the shared blocking connection pool.

    NOTE(review): when ``connection_pool`` is supplied, redis-py takes encoding
    behaviour from the pool's connections, so ``decode_responses``/``encoding``
    here are presumably ignored — confirm against redis-py docs.
    """
    shared_pool = get_redis_connection_pool()
    return redis.Redis(connection_pool=shared_pool, decode_responses=False, encoding='utf-8')
5477
class InterchangeClient:
5578

5679
def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config):
@@ -101,94 +124,107 @@ def begin_service(self):
101124
"""
102125
Starts the SSE service loop.
103126
"""
104-
t = threading.Thread(target=lambda: asyncio.run(self._ensure_service_loop()), daemon=True)
127+
t = threading.Thread(target=self._begin_service_threading, daemon=True)
105128
t.start()
106129

107-
async def _ensure_service_loop(self):
108-
while not self.should_terminate:
109-
try:
110-
await self._service_loop()
111-
except Exception as e:
112-
logger.warning(f"InterchangeClient service loop error: {e}. Restarting...")
113-
await asyncio.sleep(4) # brief pause before reconnecting
114-
115-
async def _service_loop(self):
116-
"""
117-
In fact this is not a service,
118-
it is a client that pretends to be a service, by interacting with a local interchange server via SSE.
119130

120-
This design is for efficiency
131+
def _handle_service_request(self, msg: bytes, sem: threading.Semaphore):
    """Handle a single service request in its own thread.

    The payload is a pickled JSON document published on this episode's Redis
    topic; it is either a health-check probe or an LLM completion request.
    The reply is published on a per-timeline response topic.

    Args:
        msg: raw pubsub payload — ``pickle.dumps`` of a JSON string.
        sem: concurrency gate owned by the dispatch loop; always released
            in the ``finally`` block, even on failure.
    """
    from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest

    redis_client = get_redis_client()
    data_as_json = ""
    topic = ""
    try:
        # SECURITY NOTE: pickle.loads on data received via Redis — acceptable
        # only because both ends of the interchange are trusted local
        # processes; never expose this topic to untrusted publishers.
        data_as_json = json.loads(pickle.loads(msg))
        timeline_uuid = data_as_json["timeline_uuid"]
        topic = f"timeline_uuid:{timeline_uuid}/episode_uuid:{self.episode_uuid}"

        if data_as_json.get("health_check"):
            # Cheap liveness probe — reply immediately, no model call.
            result = '{"health_check_ok": "True"}'
        else:
            parsed_msg = InterchangeCompletionRequest(**data_as_json)
            # Drive the async inference to completion on this worker thread.
            result = asyncio.run(self.llm_infer(
                req=parsed_msg.completion_request,
                timeline_uuid=parsed_msg.timeline_uuid,
                agent_name=parsed_msg.agent_name,
                target_tag=parsed_msg.target_tag,
                episode_uuid=parsed_msg.episode_uuid,
            )).model_dump_json()

        # Publish the JSON-string result back to the requester.
        redis_client.publish(topic, pickle.dumps(result))

    except Exception as e:
        err = f"[ERR]: Error when processing data: {data_as_json} Error: {e}"
        logger.error(err)
        # Best effort: report the failure to the requester, if parsing got far
        # enough for us to know where to reply.
        if topic:
            redis_client.publish(topic, pickle.dumps(err))

    finally:
        # Always free the concurrency slot and return the connection to the pool.
        sem.release()
        redis_client.close()
181+
182+
183+
184+
def _begin_service_threading(self):
    """Begin listening for service requests in a threading model.

    Dispatch loop, run on a daemon thread started by ``begin_service``:
    subscribe to this episode's Redis topic and hand each message to
    ``_handle_service_request`` on its own thread until
    ``self.should_terminate`` is set.  In-flight work is capped by a
    semaphore that each worker releases when finished.
    """
    begin_time = time.time()
    redis_client = get_redis_client()
    redis_sub = redis_client.pubsub()
    episode_topic = f"episode_uuid:{self.episode_uuid}"
    redis_sub.subscribe(episode_topic)
    sem = threading.Semaphore(8)  # at most 8 requests in flight at once
    logger.info(f"[client] {self.episode_uuid} | Subscribed to topic {episode_topic}, waiting for messages...")
    is_init = True
    try:
        while not self.should_terminate:
            # Block up to 10s for the next pubsub event, then re-check the
            # termination flag.
            response = redis_sub.get_message(timeout=10)  # type: ignore
            timepassed = time.time() - begin_time

            if response is None:
                if is_init and timepassed > 30:
                    logger.warning(f"[client] Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
                continue

            # Skip subscribe/unsubscribe confirmations and other meta events.
            if response['type'] not in ['message', 'pmessage']:
                continue

            is_init = False
            msg: bytes = response['data']  # type: ignore
            # Wait here until a concurrency slot frees up; the worker thread
            # releases it in its finally block.
            sem.acquire()
            threading.Thread(target=self._handle_service_request, args=(msg, sem), daemon=True).start()

    except KeyboardInterrupt:
        return

    finally:
        # Unsubscribe and return the pubsub connection to the pool.
        redis_sub.close()
        redis_client.close()
194230

0 commit comments

Comments
 (0)