Skip to content

Commit b91da74

Browse files
committed
refactor: enhance HTTP client management with automatic refresh on connection errors
1 parent f6a0f86 commit b91da74

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,25 @@ def logger_info(self, message):
9494

9595
return
9696

97+
def _refresh_http_client(self):
    """Replace the HTTP client with a fresh one and close the old one.

    The new client is installed *before* the old one is closed, so
    ``self._http_client`` never refers to an already-closed client and the
    previous client is left untouched if constructing the replacement raises.

    Closing the old client is best-effort: it is normally being discarded
    because its connection is already broken, so a failure to close it is
    logged at debug level rather than propagated.
    """
    stale_client = getattr(self, "_http_client", None)
    self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT)

    if stale_client is not None:
        try:
            stale_client.close()
        except Exception as close_error:
            # Deliberately non-fatal: the replacement client is already in
            # place, and the stale one is being thrown away regardless.
            logger.debug(f"Ignoring error while closing stale HTTP client: {close_error}")

    logger.info("HTTP client refreshed due to connection error")
105+
106+
def _should_refresh_client_on_error(self, error: Exception) -> bool:
    """Check if an error suggests the HTTP client should be refreshed.

    An error qualifies when it, or any exception it was raised from
    (``__cause__`` / ``__context__``), either

    * is one of the built-in dead-connection exception types
      (``ConnectionResetError``, ``ConnectionAbortedError``,
      ``BrokenPipeError``), or
    * has a message containing one of the known connection-failure phrases
      (case-insensitive).

    Args:
        error: The exception caught around an HTTP call.

    Returns:
        True if the error looks like a broken/stale connection, else False.
    """
    dead_connection_types = (
        ConnectionResetError,
        ConnectionAbortedError,
        BrokenPipeError,
    )
    keywords = (
        "disconnected",
        "connection reset",
        "connection closed",
        "broken pipe",
        "connection aborted",
    )

    # Walk the exception chain: client libraries commonly wrap the low-level
    # socket error, so the outermost message alone may not mention it.
    visited = set()
    current = error
    while current is not None and id(current) not in visited:
        visited.add(id(current))  # guard against cyclic chains

        # Type check first: these exceptions may carry an empty message.
        if isinstance(current, dead_connection_types):
            return True

        message = str(current).lower()
        if any(keyword in message for keyword in keywords):
            return True

        current = getattr(current, "__cause__", None) or getattr(current, "__context__", None)

    return False
97116

98117
def _clean_up_expired_records(self):
99118
# remove records that have expired and expired at least CLEAN_RECORD_TIMEOUT seconds ago
@@ -498,6 +517,8 @@ def get_engine_status(self) -> Tuple[str, dict]:
498517
logger.warning("get_engine_status: " + str(resp_json))
499518
return result, resp_json
500519
except Exception as e:
520+
if self._should_refresh_client_on_error(e):
521+
self._refresh_http_client()
501522
logger.error(f"Error getting engine status: {e}")
502523
return "ENGINE.CANNOT_CONNECT", {}
503524

tutorial/example_werewolves_swarm/agent_roll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def rollout(task):
4747
return
4848

4949

50-
executor = PeriodicDrainThreadPoolExecutor(workers=1, max_parallel=64, auto_retry=True, block_first_run=True)
50+
executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, max_parallel=64, auto_retry=True)
5151
for _ in range(NUM_EPOCH):
5252
for _, task in enumerate(dataset.generate_training_tasks()):
5353
for _ in range(GRPO_N):

0 commit comments

Comments
 (0)