Skip to content

Commit f532854

Browse files
committed
sharing httpx client
1 parent 642eddc commit f532854

8 files changed

Lines changed: 60 additions & 30 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,3 +171,4 @@ prompts
171171
swarmexp
172172
swarmlog
173173
werewolves_swarm
174+
.claude

ajet/task_rollout/native_parallel_worker.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -233,7 +233,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
233233
f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] "
234234
f"Deleting cached episodes to release memory..."
235235
)
236-
completed_task_id_map_ct = {}
236+
completed_task_id_map_ct.clear()
237237
return (total_completed_tasks >= n_batch_task)
238238

239239
def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
@@ -258,7 +258,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
258258
f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] "
259259
f"Deleting cached episodes to release memory..."
260260
)
261-
completed_task_id_map_ct = {}
261+
completed_task_id_map_ct.clear()
262262
return (total_completed_non_dummy_tasks >= n_batch_task)
263263

264264
# select stop condition function based on config
@@ -387,6 +387,7 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
387387
logger.info('Collecting results...')
388388
for ct_list in completed_task_id_map_ct.values():
389389
tracker_array.extend(ct_list)
390+
completed_task_id_map_ct.clear()
390391

391392
# TODO: support multi-step reward
392393
task_success_rate = np.mean(
@@ -402,7 +403,6 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
402403

403404
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
404405
self._write_swarm_rollout_dynamic_log(observation_window)
405-
406406
return tracker_array
407407

408408

ajet/tuner_lib/experimental/as_oai_model_server.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -428,6 +428,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
428428

429429
# polling for server ready
430430
start_time = time.time()
431+
_httpx_client = httpx.Client(timeout=0.5)
431432
while True:
432433
if interchange_server and interchange_server.exitcode is not None:
433434
logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}")
@@ -437,7 +438,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
437438
logger.error(msg)
438439
raise RuntimeError(msg)
439440
try:
440-
if httpx.get(health_url, timeout=0.5).status_code == 200:
441+
if _httpx_client.get(health_url).status_code == 200:
441442
break
442443
except Exception:
443444
# keep waiting
@@ -462,7 +463,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int:
462463
interchange_server.join()
463464
except KeyboardInterrupt:
464465
logger.info("Shutting down interchange server...")
465-
try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code
466+
try: _httpx_client.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code
466467
except Exception: pass
467468

468469
if interchange_server:

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,8 @@ def __init__(self, server_url: str):
7171
self._agent_jet_job = None
7272
# throttle
7373
self._recent_seen_tasks = []
74+
# reuse httpx client to avoid creating SSL context repeatedly
75+
self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT)
7476

7577
def logger_info(self, message):
7678
# logger with de-duplication within 1 second to prevent log flooding
@@ -252,10 +254,9 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="t
252254
discard_episode_timeout=discard_episode_timeout,
253255
throttle_policy=throttle_policy
254256
)
255-
resp = httpx.post(
257+
resp = self._http_client.post(
256258
f"{self.server_url}/claim_episode",
257-
json=req_obj.model_dump(),
258-
timeout=GENERAL_TIMEOUT
259+
json=req_obj.model_dump()
259260
)
260261
raise_for_status_with_detail(resp)
261262
data = ClaimEpisodeResponse.model_validate(resp.json())
@@ -337,10 +338,9 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
337338
task_id=task_id
338339
)
339340

340-
resp = httpx.post(
341+
resp = self._http_client.post(
341342
f"{self.server_url}/end_episode",
342-
json=req_obj.model_dump(),
343-
timeout=GENERAL_TIMEOUT
343+
json=req_obj.model_dump()
344344
)
345345
raise_for_status_with_detail(resp)
346346
data = EndEpisodeResponse.model_validate(resp.json())
@@ -366,10 +366,9 @@ def abort_episode(self, episode_uuid: str):
366366
task_id=""
367367
)
368368

369-
resp = httpx.post(
369+
resp = self._http_client.post(
370370
f"{self.server_url}/abort_episode",
371-
json=req_obj.model_dump(),
372-
timeout=GENERAL_TIMEOUT
371+
json=req_obj.model_dump()
373372
)
374373
raise_for_status_with_detail(resp)
375374
data = EndEpisodeResponse.model_validate(resp.json())
@@ -399,10 +398,9 @@ def sync_train_config(self, agent_jet_job: AgentJetJob):
399398

400399
req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str)
401400

402-
resp = httpx.post(
401+
resp = self._http_client.post(
403402
f"{self.server_url}/sync_train_config",
404-
json=req_obj.model_dump(),
405-
timeout=GENERAL_TIMEOUT
403+
json=req_obj.model_dump()
406404
)
407405
raise_for_status_with_detail(resp)
408406
self.logger_info("Synced train config to Swarm server")
@@ -422,7 +420,7 @@ def start_engine(self):
422420
raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. (current status: {current_status})")
423421

424422
# Send start engine request
425-
resp = httpx.post(
423+
resp = self._http_client.post(
426424
f"{self.server_url}/start_engine",
427425
json={},
428426
timeout=600
@@ -487,7 +485,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
487485
@cache_with_ttl(ttl=0.5)
488486
def get_engine_status(self) -> Tuple[str, dict]:
489487
try:
490-
resp = httpx.get(
488+
resp = self._http_client.get(
491489
f"{self.server_url}/get_engine_status",
492490
timeout=10
493491
)
@@ -512,7 +510,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool:
512510
client_uuid=self.client_uuid,
513511
episode_uuid=episode_uuid
514512
)
515-
resp = httpx.post(
513+
resp = self._http_client.post(
516514
f"{self.server_url}/can_continue_episode",
517515
json=req_obj.model_dump(),
518516
timeout=10
@@ -526,7 +524,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool:
526524

527525
def get_episode_buffer(self) -> List[EpisodeStatus]:
528526
try:
529-
resp = httpx.post(
527+
resp = self._http_client.post(
530528
f"{self.server_url}/get_episode_buffer",
531529
json={},
532530
timeout=10
@@ -585,7 +583,7 @@ def stop_engine(self):
585583
self.logger_info("Engine is already OFFLINE. No action needed.")
586584
return
587585

588-
resp = httpx.post(
586+
resp = self._http_client.post(
589587
f"{self.server_url}/stop_engine",
590588
json={},
591589
timeout=600
@@ -605,7 +603,7 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation:
605603
Returns statistics about completed episodes, tasks, and progress.
606604
"""
607605
try:
608-
resp = httpx.get(
606+
resp = self._http_client.get(
609607
f"{self.server_url}/get_current_batch_rollout_pool_information",
610608
timeout=10
611609
)

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -338,7 +338,8 @@ async def start_engine():
338338
experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR)
339339
if experiment_dir == "auto":
340340
exp_base_dir = DEFAULT_DIR
341-
exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir))
341+
else:
342+
exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir))
342343

343344
# Save YAML to temporary file
344345
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file:

ajet/tuner_lib/experimental/interchange_utils.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,8 @@ class UpdateEngineStatusRequest(BaseModel):
109109

110110
VERBOSE = True
111111

112+
shared_http_client = httpx.Client(timeout=10.0)
113+
112114
def get_interchange_server_url(config):
113115
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
114116
if isinstance(config, dict):
@@ -127,7 +129,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No
127129
if new_status not in VALID_STATUSES:
128130
raise ValueError(f"Invalid engine status: {new_status}")
129131

130-
resp = httpx.post(
132+
resp = shared_http_client.post(
131133
f"{get_interchange_server_url(config)}/update_engine_status",
132134
json={"engine_status": new_status, "engine_status_detail": new_status_detail, "global_step": global_step},
133135
timeout=10
@@ -137,7 +139,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No
137139

138140

139141
def is_episode_claimed(config, episode_uuid: str, unregister_if_not_claimed: bool) -> bool:
140-
resp = httpx.post(
142+
resp = shared_http_client.post(
141143
f"{get_interchange_server_url(config)}/is_episode_claimed",
142144
json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": unregister_if_not_claimed},
143145
timeout=5
@@ -168,7 +170,7 @@ def http_register_episode(config,
168170
zmq_listen_result_addr=zmq_listen_result_addr,
169171
)
170172
# send http request to swarm server to register episode
171-
response = httpx.post(
173+
response = shared_http_client.post(
172174
f"{interchange_http_addr}/register_episode",
173175
json=rer.model_dump(), # or rer.model_dump() if using Pydantic v2
174176
timeout=2

ajet/utils/swarm_overwatch.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,12 @@ def __init__(self, server_url: str, refresh_interval: float = 2.0):
3737
self.last_update_time = None
3838
self.error_count = 0
3939
self.total_requests = 0
40+
self._httpx_client = httpx.Client(timeout=5.0)
4041

4142
def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
4243
"""Fetch current batch rollout pool information from server"""
4344
try:
44-
response = httpx.get(
45+
response = self._httpx_client.get(
4546
f"{self.server_url}/get_current_batch_rollout_pool_information",
4647
timeout=5.0,
4748
)

ajet/utils/tokenizer.py

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,9 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]:
1919
pass
2020
return messages_copied
2121

22+
# Cache storage
23+
_cache = {}
24+
2225

2326
def ajet_apply_chat_template(
2427
tokenizer,
@@ -28,16 +31,39 @@ def ajet_apply_chat_template(
2831
tokenize: bool = True,
2932
):
3033
conversation = cleanup_messages(conversation)
34+
35+
# Create cache key by hashing all inputs
36+
cache_key = (
37+
id(tokenizer),
38+
hash(json.dumps(conversation, sort_keys=True)),
39+
hash(json.dumps(tools, sort_keys=True)) if tools else 0,
40+
add_generation_prompt,
41+
tokenize,
42+
)
43+
44+
# Check cache
45+
if cache_key in _cache:
46+
return _cache[cache_key]
47+
48+
# Compute result
3149
if tools:
32-
return tokenizer.apply_chat_template(
50+
result = tokenizer.apply_chat_template(
3351
conversation,
3452
tools,
3553
add_generation_prompt=add_generation_prompt,
3654
tokenize=tokenize,
3755
)
3856
else:
39-
return tokenizer.apply_chat_template(
57+
result = tokenizer.apply_chat_template(
4058
conversation,
4159
tokenize=tokenize,
4260
add_generation_prompt=add_generation_prompt,
4361
)
62+
63+
# Store in cache (implement LRU eviction if cache gets too large)
64+
if len(_cache) >= 1024:
65+
# Remove oldest item (first inserted)
66+
_cache.pop(next(iter(_cache)))
67+
68+
_cache[cache_key] = result
69+
return result

0 commit comments

Comments (0)