Skip to content

Commit 8eae43c

Browse files
committed
optimize further
1 parent 2c05d11 commit 8eae43c

File tree

6 files changed

+28
-18
lines changed

6 files changed

+28
-18
lines changed

ajet/context_tracker/multiagent_tracking.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def extract_text_content_from_content_dict(self, msg):
107107
should_skip_message = False
108108
return str_content, should_skip_message
109109

110+
110111
def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]:
111112
"""Spawn a timeline from messages.
112113

ajet/task_rollout/async_llm_bridge.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,11 @@ async def run_infer(
535535

536536
# run llm inference ✨
537537
if self.config.ajet.llm_infer_submit_method == "sync":
538-
llm_output = await asyncio.wait_for(
539-
asyncio.to_thread(
540-
self.llm_inference_fn, converted_message, custom_sampling_params, tools
541-
),
542-
timeout=1800,
538+
llm_output = await asyncio.to_thread(
539+
self.llm_inference_fn, converted_message, custom_sampling_params, tools
543540
)
544541
else:
545-
llm_output = await asyncio.wait_for(self.llm_inference_fn(converted_message, custom_sampling_params, tools), timeout=1800)
542+
llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools)
546543

547544

548545
# begin context tracking

ajet/task_runner/base_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,16 @@ def generated_token_callback_fn(token_array):
6868
async def wrapper_type_asyncio(self, workflow_cls: Type[Workflow], workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
6969
user_workflow: Workflow = workflow_cls(name="ajet-workflow")
7070
result = await user_workflow.execute(workflow_task, tuner)
71+
72+
# release the workflow object so the garbage collector can reclaim it
7173
del user_workflow
72-
with gc_lock:
73-
gc.collect() # force garbage collection
74+
75+
# run gc.collect() only if no other thread holds the lock (skip rather than block)
76+
if gc_lock.acquire(blocking=False):
77+
try:
78+
gc.collect()
79+
finally:
80+
gc_lock.release()
7481
return result
7582

7683

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest
1414
from redis.exceptions import TimeoutError
1515
from ajet.utils.free_port import find_free_port
16-
from ajet.utils.sington import ThreadExecutorLlmInferSingleton, ThreadExecutorSingleton
16+
from ajet.utils.sington import ThreadExecutorContextTrackerSingleton, ThreadExecutorSingleton
1717
from functools import cache
1818

1919
import pickle
@@ -141,14 +141,19 @@ def begin_service(self):
141141
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
142142
self.socket = context.socket(zmq.REP)
143143
self.socket.bind(f"{self.episode_contect_address}")
144-
self.socket.setsockopt(zmq.RCVTIMEO, 2*1000) # 60 秒超时
144+
self.socket.setsockopt(zmq.RCVTIMEO, 3*1000)  # 3-second receive timeout (value is in milliseconds)
145145

146146
self.executor = ThreadExecutorSingleton().get_executor()
147147
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...")
148148
future = self.executor.submit(self._begin_service_threading)
149-
time.sleep(1)
149+
150+
# wait until the service begins running
151+
time.sleep(0.5)
152+
w_time = 1
150153
while future._state == 'PENDING':
151-
time.sleep(1)
154+
time.sleep(min(w_time * 2, 10))
155+
w_time += 1
156+
152157
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...")
153158

154159
# t = threading.Thread(target=self._begin_service_threading, daemon=True)
@@ -189,9 +194,9 @@ def _begin_service_threading(self):
189194
loop = asyncio.get_running_loop()
190195
except:
191196
loop = asyncio.new_event_loop()
192-
executor = ThreadExecutorLlmInferSingleton().get_executor()
197+
context_tracker_executor = ThreadExecutorContextTrackerSingleton().get_executor()
193198
future = loop.run_in_executor(
194-
executor, # executor
199+
context_tracker_executor,
195200
asyncio.run,
196201
self.llm_infer(
197202
req=parsed_msg.completion_request,

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def serve_with_monitor():
186186
host="0.0.0.0",
187187
port=self.port,
188188
log_level="error",
189-
# workers=4
189+
workers=2
190190
)
191191
server = uvicorn.Server(config)
192192
await server.serve()

ajet/utils/sington.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
@singleton
55
class ThreadExecutorSingleton:
66
def __init__(self):
7-
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)
7+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=64)
88

99
def get_executor(self) -> concurrent.futures.ThreadPoolExecutor:
1010
return self.executor
1111

1212

1313
@singleton
14-
class ThreadExecutorLlmInferSingleton:
14+
class ThreadExecutorContextTrackerSingleton:
1515
def __init__(self):
16-
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)
16+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=64)
1717

1818
def get_executor(self) -> concurrent.futures.ThreadPoolExecutor:
1919
return self.executor

0 commit comments

Comments
 (0)