|
3 | 3 | import json |
4 | 4 | import threading |
5 | 5 | import os |
| 6 | +import redis |
6 | 7 | import time |
7 | 8 | from loguru import logger |
8 | 9 | from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING |
9 | 10 | from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse |
10 | 11 | from openai.types.chat.chat_completion import ChatCompletion |
| 12 | +from redis.exceptions import TimeoutError |
| 13 | + |
| 14 | +from functools import cache |
11 | 15 |
|
12 | 16 | import pickle |
13 | 17 | import httpx |
| 18 | +import zmq |
14 | 19 | import logging |
15 | 20 | logging.getLogger("httpx").setLevel(logging.WARNING) |
16 | 21 |
|
@@ -51,6 +56,24 @@ def generate_auth_token(agent_name, target_tag, episode_uuid): |
51 | 56 | return auth_token |
52 | 57 |
|
53 | 58 |
|
@cache
def get_redis_connection_pool():
    """Return the process-wide Redis connection pool (created once via @cache).

    A BlockingConnectionPool makes callers wait for a free connection when
    ``max_connections`` is exhausted instead of raising immediately.
    """
    return redis.BlockingConnectionPool(
        host='localhost',
        port=6379,
        max_connections=256,
        socket_timeout=30,
        socket_connect_timeout=30,
        retry_on_timeout=True,
    )
| 70 | + |
| 71 | + |
def get_redis_client():
    """Return a Redis client backed by the shared blocking connection pool.

    NOTE(fix): redis-py ignores per-client connection kwargs (e.g.
    ``decode_responses``, ``encoding``) whenever ``connection_pool`` is
    supplied — connection behavior is governed by the pool. The previous
    ``decode_responses=False, encoding='utf-8'`` arguments were silent no-ops
    that happened to match the pool defaults; they are removed so the code no
    longer implies configuration that never takes effect. Behavior is
    unchanged: responses are returned as raw bytes.
    """
    return redis.Redis(connection_pool=get_redis_connection_pool())
| 76 | + |
54 | 77 | class InterchangeClient: |
55 | 78 |
|
56 | 79 | def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config): |
@@ -101,94 +124,107 @@ def begin_service(self): |
101 | 124 | """ |
102 | 125 | Starts the SSE service loop. |
103 | 126 | """ |
104 | | - t = threading.Thread(target=lambda: asyncio.run(self._ensure_service_loop()), daemon=True) |
| 127 | + t = threading.Thread(target=self._begin_service_threading, daemon=True) |
105 | 128 | t.start() |
106 | 129 |
|
107 | | - async def _ensure_service_loop(self): |
108 | | - while not self.should_terminate: |
109 | | - try: |
110 | | - await self._service_loop() |
111 | | - except Exception as e: |
112 | | - logger.warning(f"InterchangeClient service loop error: {e}. Restarting...") |
113 | | - await asyncio.sleep(4) # brief pause before reconnecting |
114 | | - |
115 | | - async def _service_loop(self): |
116 | | - """ |
117 | | - In fact this is not a service, |
118 | | - it is a client that pretends to be a service, by interacting with a local interchange server via SSE. |
119 | 130 |
|
120 | | - This design is for efficiency |
| 131 | + def _handle_service_request(self, msg: bytes, sem: threading.Semaphore): |
| 132 | + """handle a single service request in its own thread |
121 | 133 | """ |
122 | | - |
123 | 134 | from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest |
124 | | - |
125 | | - port = os.getenv("AJET_DAT_INTERCHANGE_PORT") |
126 | | - assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" |
127 | | - |
128 | | - base_url = f"http://127.0.0.1:{port}" |
129 | | - listen_url = f"{base_url}/hook/context_tracker_client_listen" |
130 | | - response_url = f"{base_url}/hook/context_tracker_client_response" |
131 | | - key = f"episode_uuid:{self.episode_uuid}" |
132 | | - |
133 | | - async with httpx.AsyncClient(timeout=None) as client: |
134 | | - try: |
135 | | - async with client.stream("GET", listen_url, params={"episode_uuid": self.episode_uuid}, timeout=None) as response: |
136 | | - async for line in response.aiter_lines(): |
137 | | - if self.should_terminate: |
138 | | - break |
139 | | - |
140 | | - if not line.strip(): |
141 | | - continue |
142 | | - |
143 | | - if line.startswith(":"): # keepalive |
144 | | - continue |
145 | | - |
146 | | - if line.startswith("data: "): |
147 | | - data = line[6:].strip() |
148 | | - if not data: |
149 | | - continue |
150 | | - |
151 | | - try: |
152 | | - try: |
153 | | - parsed_msg = InterchangeCompletionRequest(**json.loads(data)) |
154 | | - except Exception as e: |
155 | | - logger.error(f"Failed to parse SSE event data: {e}" + data) |
156 | | - continue |
157 | | - |
158 | | - result = await self.llm_infer( |
159 | | - req=parsed_msg.completion_request, |
160 | | - timeline_uuid=parsed_msg.timeline_uuid, |
161 | | - agent_name=parsed_msg.agent_name, |
162 | | - target_tag=parsed_msg.target_tag, |
163 | | - episode_uuid=parsed_msg.episode_uuid, |
164 | | - ) |
165 | | - |
166 | | - # Send response back |
167 | | - await client.post( |
168 | | - response_url, |
169 | | - params={"key": key}, |
170 | | - content=pickle.dumps(result), |
171 | | - headers={"Content-Type": "application/octet-stream"} |
172 | | - ) |
173 | | - |
174 | | - except Exception as e: |
175 | | - logger.error(f"Error processing SSE event: {e}") |
176 | | - continue |
177 | | - |
178 | | - except httpx.RequestError as e: |
179 | | - logger.warning(f"SSE connection error: {e}") |
180 | | - raise # Let ensure_service_loop handle restart |
181 | | - |
182 | | - # Send terminate signal if we are exiting cleanly |
183 | | - try: |
184 | | - await client.post( |
185 | | - response_url, |
186 | | - params={"key": key}, |
187 | | - content=pickle.dumps("terminate"), |
188 | | - headers={"Content-Type": "application/octet-stream"} |
189 | | - ) |
190 | | - except: |
191 | | - pass |
192 | | - |
193 | | - |
| 135 | + logger.info(f"[client] {self.episode_uuid} | inside _handle_service_request") |
| 136 | + redis_client = get_redis_client() |
| 137 | + logger.info(f"[client] {self.episode_uuid} | get_redis_client") |
| 138 | + data_as_json = "" |
| 139 | + topic = "" |
| 140 | + try: |
| 141 | + data_as_json = json.loads(pickle.loads(msg)) |
| 142 | + timeline_uuid = data_as_json["timeline_uuid"] |
| 143 | + topic = f"timeline_uuid:{timeline_uuid}/episode_uuid:{self.episode_uuid}" |
| 144 | + logger.info(f"[client] {self.episode_uuid} | json.loads(pickle.loads(msg))") |
| 145 | + |
| 146 | + |
| 147 | + if "health_check" in data_as_json and data_as_json["health_check"]: |
| 148 | + # logger.info(f"Received health check for timeline_uuid: {timeline_uuid}") |
| 149 | + result = '{"health_check_ok": "True"}' |
| 150 | + # logger.success(f"Health check OK for timeline_uuid: {timeline_uuid}") |
| 151 | + else: |
| 152 | + parsed_msg = InterchangeCompletionRequest(**data_as_json) |
| 153 | + # start llm request |
| 154 | + result = asyncio.run(self.llm_infer( |
| 155 | + req=parsed_msg.completion_request, |
| 156 | + timeline_uuid=parsed_msg.timeline_uuid, |
| 157 | + agent_name=parsed_msg.agent_name, |
| 158 | + target_tag=parsed_msg.target_tag, |
| 159 | + episode_uuid=parsed_msg.episode_uuid, |
| 160 | + )).model_dump_json() |
| 161 | + # logger.success(f"LLM inference completed for timeline_uuid: {timeline_uuid}") |
| 162 | + logger.info(f"[client] {self.episode_uuid} | result = asyncio.run(self.llm_infer") |
| 163 | + # send result back |
| 164 | + bytes_arr = pickle.dumps(result) |
| 165 | + logger.info(f"[client] {self.episode_uuid} | bytes_arr = pickle.dumps(result)") |
| 166 | + redis_client.publish(topic, bytes_arr) |
| 167 | + logger.info(f"[client] {self.episode_uuid} | redis_client.publish(topic, pickle.dumps(result))") |
| 168 | + |
| 169 | + except Exception as e: |
| 170 | + err = f"[ERR]: Error when processing data: {data_as_json} Error: {e}" |
| 171 | + result = err |
| 172 | + logger.error(err) |
| 173 | + if topic: |
| 174 | + redis_client.publish(topic, pickle.dumps(result)) |
| 175 | + |
| 176 | + finally: |
| 177 | + # release semaphore when done |
| 178 | + sem.release() |
| 179 | + redis_client.close() |
| 180 | + |
| 181 | + |
| 182 | + |
| 183 | + |
| 184 | + def _begin_service_threading(self): |
| 185 | + """begin listening for service requests in a threading model |
| 186 | + """ |
| 187 | + # logger.success(f"InterchangeClient starting for episode_uuid:{self.episode_uuid}") |
| 188 | + debug_logs = [] |
| 189 | + begin_time = time.time() |
| 190 | + logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") |
| 191 | + redis_client = get_redis_client() |
| 192 | + redis_sub = redis_client.pubsub() |
| 193 | + episode_topic = f"episode_uuid:{self.episode_uuid}" |
| 194 | + redis_sub.subscribe(episode_topic) |
| 195 | + sem = threading.Semaphore(8) # 4 concurrent requests max |
| 196 | + logger.info(f"[client] {self.episode_uuid} | Subscribed to topic {episode_topic}, waiting for messages...") |
| 197 | + is_init = True |
| 198 | + try: |
| 199 | + while not self.should_terminate: |
| 200 | + # wait for a new message |
| 201 | + logger.info(f"[client] {self.episode_uuid} | Waiting for new message on topic {episode_topic}...") |
| 202 | + response = redis_sub.get_message(timeout=10) # type: ignore |
| 203 | + timepassed = time.time() - begin_time |
| 204 | + |
| 205 | + if response is None: |
| 206 | + if is_init and timepassed > 30: |
| 207 | + logger.warning(f"[client] Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") |
| 208 | + continue |
| 209 | + |
| 210 | + if response['type'] not in ['message', 'pmessage']: |
| 211 | + continue |
| 212 | + |
| 213 | + is_init = False |
| 214 | + logger.info(f"[client] {self.episode_uuid} | get message...") |
| 215 | + # got a message |
| 216 | + msg: bytes = response['data'] # type: ignore |
| 217 | + # are we free to spawn a new thread? |
| 218 | + sem.acquire() |
| 219 | + logger.info(f"[client] {self.episode_uuid} | sem acquire...") |
| 220 | + # begin a new thread to handle this request |
| 221 | + threading.Thread(target=self._handle_service_request, args=(msg, sem), daemon=True).start() |
| 222 | + |
| 223 | + |
| 224 | + except KeyboardInterrupt: |
| 225 | + return |
| 226 | + |
| 227 | + finally: |
| 228 | + redis_sub.close() |
| 229 | + redis_client.close() |
194 | 230 |
|
0 commit comments