Skip to content

Commit 504de1d

Browse files
committed
refactor: update is_episode_claimed function to include unregister_if_not_claimed argument and adjust related logic
1 parent a6db01d commit 504de1d

4 files changed

Lines changed: 33 additions & 19 deletions

File tree

ajet/task_runner/base_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def should_interrupt_hard_fn() -> bool:
5959
return True
6060
if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # check soft condition
6161
# if soft condition met, check if episode is claimed
62-
has_claimed = is_episode_claimed(self.config, workflow_task.episode_uuid)
62+
has_claimed = is_episode_claimed(self.config, workflow_task.episode_uuid, unregister_if_not_claimed=True)
6363
if not has_claimed:
6464
# if not claimed by now (ENGINE.ROLLING_POST), this episode will never be claimed again, so we can hard stop
6565
return True

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ajet.schema.task import WorkflowOutput, Task
1212
from ajet.copilot.job import AgentJetJob
1313
from ajet.utils.thread_executors import BoundedThreadPoolExecutor
14+
from ajet.utils.cache import cache_with_ttl
1415
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
1516
from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
1617
from ajet.tuner_lib.experimental.interchange_utils import (
@@ -31,6 +32,8 @@
3132
# To prevent stale records from accumulating, do not need to be changed
3233
CLEAN_RECORD_TIMEOUT = 10
3334
START_EPISODE_RETRY_DELAY = 15
35+
TROTTLE_EPISODE_RETRY_DELAY = 2
36+
WAIT_MORE_AVAIL_EPISODE_RETRY_DELAY = 2
3437

3538
def raise_for_status_with_detail(resp):
3639
try:
@@ -237,7 +240,7 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="t
237240
should_throttle = self._should_throttle(throttle_policy, pool_info)
238241
if should_throttle:
239242
self.logger_info(f"Throttle policy is active, delaying episode ...")
240-
retry_delay = START_EPISODE_RETRY_DELAY
243+
retry_delay = TROTTLE_EPISODE_RETRY_DELAY
241244
continue
242245

243246
# connect remote server to claim an episode
@@ -273,18 +276,22 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="t
273276
episode_uuid=episode_uuid
274277
)
275278
else:
276-
need_wait_scenarios =[
279+
need_snap_scenarios =[
277280
"Engine is syncing weights",
278281
"Engine is in post-rolling phase",
282+
]
283+
need_wait_scenarios =[
279284
"No available episodes to claim.",
280-
"SwarmThrottlePolicy",
281285
]
282-
if any(scenario in data.fail_cause for scenario in need_wait_scenarios):
286+
if any(scenario in data.fail_cause for scenario in need_snap_scenarios):
283287
if time.time() - self.previous_warning_time > 60:
284288
self.logger_info(f"{data.fail_cause}. Retrying ...")
285289
self.previous_warning_time = time.time()
286290
retry_delay = START_EPISODE_RETRY_DELAY
287291
continue
292+
elif any(scenario in data.fail_cause for scenario in need_wait_scenarios):
293+
retry_delay = WAIT_MORE_AVAIL_EPISODE_RETRY_DELAY
294+
continue
288295
else:
289296
logger.warning(f"Failed to claim episode: {data.fail_cause}. Retrying ...")
290297
retry_delay = START_EPISODE_RETRY_DELAY
@@ -464,6 +471,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
464471
logger.error(f"Error polling engine status: {e}")
465472
time.sleep(5)
466473

474+
@cache_with_ttl(ttl=0.5)
467475
def get_engine_status(self) -> Tuple[str, dict]:
468476
try:
469477
resp = httpx.get(

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,14 @@ async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, share
127127
if ep_key(episode_uuid) not in shared_mem_dict:
128128
logger.warning(f"Episode record for {episode_uuid} not found in shared memory. It may have been already processed by another thread. Skipping unclaim.")
129129
return
130-
if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed":
131-
if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
132-
pass
133-
else:
134-
shared_mem_dict["unclaimed_episodes"] += [episode_uuid]
135-
return
130+
131+
with shared_mem_dict_lock:
132+
if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed":
133+
if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
134+
pass
135+
else:
136+
shared_mem_dict["unclaimed_episodes"] += [episode_uuid]
137+
return
136138

137139
# reset context tracker
138140
# _context_tracker_reset_blocking(episode_uuid, shared_mem_dict) # must async
@@ -533,6 +535,15 @@ async def claim_episode(req: ClaimEpisodeRequest):
533535
shared_mem_dict["unclaimed_episodes"] = shared_mem_dict["unclaimed_episodes"][1:]
534536

535537
# get episode
538+
if ep_key(episode_uuid) not in shared_mem_dict:
539+
return ClaimEpisodeResponse(
540+
success=False,
541+
client_uuid=req.client_uuid,
542+
episode_uuid="",
543+
openai_base_url="",
544+
openai_api_key="",
545+
fail_cause="No available episodes to claim. Try again (maybe 2 minutes) later.",
546+
)
536547
es: EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)]
537548
es.episode_status = "claimed"
538549
es.episode_type = req.episode_type
@@ -684,12 +695,7 @@ async def is_episode_claimed(req: CheckWhetherEpisodeClaimedRequest):
684695
return BoolResponse(success=True)
685696
else:
686697
if req.unregister_if_not_claimed:
687-
# remove from shared memory to avoid stale records
688-
with shared_mem_dict_lock:
689-
if ep_key(req.episode_uuid) in shared_mem_dict:
690-
del shared_mem_dict[ep_key(req.episode_uuid)]
691-
if req.episode_uuid in shared_mem_dict["unclaimed_episodes"]:
692-
shared_mem_dict["unclaimed_episodes"].remove(req.episode_uuid)
698+
_delete_episode_record(req.episode_uuid, shared_mem_dict, shared_mem_dict_lock)
693699
return BoolResponse(success=False)
694700

695701
@app.post("/get_episode_buffer", response_model=EpisodeBufferResponse)

ajet/tuner_lib/experimental/interchange_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No
132132
logger.success(f"Changed engine status to {new_status}")
133133

134134

135-
def is_episode_claimed(config, episode_uuid: str) -> bool:
135+
def is_episode_claimed(config, episode_uuid: str, unregister_if_not_claimed: bool) -> bool:
136136
resp = httpx.post(
137137
f"{get_interchange_server_url(config)}/is_episode_claimed",
138-
json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": True},
138+
json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": unregister_if_not_claimed},
139139
timeout=5
140140
)
141141
resp.raise_for_status()

0 commit comments

Comments
 (0)