Skip to content

Commit 3581363

Browse files
committed
add client side rollout status monitor
1 parent b825b21 commit 3581363

7 files changed

Lines changed: 107 additions & 18 deletions

File tree

ajet/task_rollout/native_parallel_worker.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
from ajet.schema.trajectory import Sample
2121
from ajet.task_rollout.single_worker import BaseRolloutManager
2222
from ajet.context_tracker.basic_tracker import BaseContextTracker
23-
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
24-
25-
DEBUG = True
23+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import (
24+
http_change_engine_status,
25+
http_update_rollout_pool_information,
26+
CurrentBatchRolloutPoolInformation,
27+
DEBUG,
28+
)
2629

2730

2831
def spawn_thread_shared_observation_window(n_threads) -> Dict[str, List[int | bool | str]]:
@@ -302,10 +305,14 @@ def stop_all_threads_hard():
302305

303306
def update_rollout_result_array_preview(observation_window, completed_task_id_map_ct: Dict[str, List[BaseContextTracker]]):
304307
buffer = ""
308+
completed_tasks_details = {}
305309
for task_id, tracker_arr in completed_task_id_map_ct.items():
306310
buffer += f"Task {task_id} (completed {len(tracker_arr)} episodes):\n"
311+
episode_uuids = []
307312
for ct in tracker_arr:
308313
buffer += f"\tEpisode: {ct.episode_uuid}\tTimelines: {len(ct.saved_timelines)}\tLLM_Calls: {ct.llm_call_cnt}\tReward: {ct.reward_structure.performance_reward}\n"
314+
episode_uuids.append(ct.episode_uuid)
315+
completed_tasks_details[task_id] = episode_uuids
309316
buffer += f"\n"
310317
buffer += f"\n"
311318
counts = count_tasks(completed_task_id_map_ct)
@@ -314,6 +321,20 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
314321
buffer += f"Total completed non-dummy tasks: {counts['total_completed_non_dummy_tasks']} (target {n_batch_task})\n"
315322
buffer += f"Current stop condition: {self.config.ajet.swarm_mode_sample_collection_method}\n"
316323
observation_window["info"][-1] = buffer
324+
325+
# Update rollout pool information via API
326+
pool_info = CurrentBatchRolloutPoolInformation(
327+
completed_episodes=counts['total_completed_episodes'],
328+
completed_episode_target=n_batch_task * rollout_n,
329+
completed_tasks=counts['total_completed_tasks'],
330+
completed_task_target=n_batch_task,
331+
completed_non_dummy_tasks=counts['total_completed_non_dummy_tasks'],
332+
completed_non_dummy_task_target=n_batch_task,
333+
task_expected_num_repeat=rollout_n,
334+
completed_tasks_details=completed_tasks_details,
335+
)
336+
http_update_rollout_pool_information(self.config, pool_info)
337+
317338
return
318339

319340
# loop and wait until stop condition is met, then stop threads and collect results
@@ -360,9 +381,6 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
360381
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
361382
self._write_swarm_rollout_dynamic_log(observation_window)
362383

363-
time.sleep(10)
364-
raise RuntimeError("DEBUG")
365-
# return all trackers
366384
return tracker_array
367385

368386

ajet/task_runner/swarm_runner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
from ajet.task_runner.base_runner import BaseAgentRunner
1313
from ajet.utils.retry import SwarmReceiveAbortException
1414
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket, is_episode_claimed
15+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import DEBUG
1516
from loguru import logger
1617
from ajet import Workflow
1718
from typing import Callable
1819

19-
DEBUG = True
20-
# DEBUG = False
2120

2221
context = zmq.Context()
2322
atexit.register(context.term)

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX
1616
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
1717
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket, is_episode_claimed
18+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import DEBUG
1819

1920
context = zmq.Context()
2021
atexit.register(context.term)
2122

2223
if TYPE_CHECKING:
2324
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
2425

25-
# DEBUG = False
26-
DEBUG = True
2726

2827
def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address):
2928
"""

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from ajet.utils.networking import find_free_port, get_host_ip
3737
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus
38+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import DEBUG
3839

3940
API_KEY_PREFIX = "sk-ajet-"
4041

@@ -54,8 +55,6 @@ class HealthCheckRequest(BaseModel):
5455

5556
# Create FastAPI app
5657
SERVER_SHUTDOWN_EVENT = threading.Event()
57-
# DEBUG = False
58-
DEBUG = True
5958

6059
context = zmq.Context()
6160
atexit.register(context.term)

ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
EndEpisodeResponse,
1818
EpisodeStatus,
1919
EpisodeBufferResponse,
20+
CurrentBatchRolloutPoolInformation,
2021
)
2122

2223

@@ -379,3 +380,20 @@ def stop_engine(self):
379380
self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE")
380381
except Exception as e:
381382
logger.error(f"Error stopping engine: {e}")
383+
384+
def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation:
385+
"""
386+
Get the current batch rollout pool information from the Swarm server.
387+
Returns statistics about completed episodes, tasks, and progress.
388+
"""
389+
try:
390+
resp = httpx.get(
391+
f"{self.server_url}/get_current_batch_rollout_pool_information",
392+
timeout=10
393+
)
394+
resp.raise_for_status()
395+
data = CurrentBatchRolloutPoolInformation.model_validate(resp.json())
396+
return data
397+
except Exception as e:
398+
logger.error(f"Error getting rollout statistics: {e}")
399+
return CurrentBatchRolloutPoolInformation()

ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi import FastAPI, HTTPException
1111
from multiprocessing.managers import DictProxy
1212
from typing import Coroutine, Optional, Tuple, List
13+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import DEBUG
1314
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import (
1415
SyncTrainConfigRequest,
1516
ClaimEpisodeRequest,
@@ -23,11 +24,10 @@
2324
BoolResponse,
2425
RegisterEpisodeRequest,
2526
UpdateEngineStatusRequest,
27+
CurrentBatchRolloutPoolInformation,
2628
VALID_STATUSES,
2729
)
2830

29-
DEBUG = True
30-
# DEBUG = False
3131
RCVTIMEO = 2 * 1000
3232
RCVTIMEO_OUT = 300 * 1000
3333
RCVTIMEO_WAIT_N = RCVTIMEO_OUT // RCVTIMEO
@@ -55,6 +55,9 @@ def register_enable_swarm_mode_routes(
5555
if 'unclaimed_episodes' not in shared_mem_dict:
5656
shared_mem_dict['unclaimed_episodes'] = []
5757

58+
if 'current_batch_rollout_pool_information' not in shared_mem_dict:
59+
shared_mem_dict['current_batch_rollout_pool_information'] = CurrentBatchRolloutPoolInformation()
60+
5861
# ------------------------------------------------------------------------------------------------
5962
# ------ Recycle claimed episodes that client failed to complete in (promised) time --------------
6063
# --------------------------------- claimed -> unclaimed ----------------------------------------
@@ -613,6 +616,28 @@ async def get_episode_buffer():
613616
return EpisodeBufferResponse(buffer=result)
614617

615618

619+
@app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse)
620+
async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation):
621+
"""Update the current batch rollout pool information."""
622+
try:
623+
with shared_mem_dict_lock:
624+
shared_mem_dict['current_batch_rollout_pool_information'] = req
625+
return BoolResponse(success=True)
626+
except Exception as e:
627+
logger.error(f"Error updating current batch rollout pool information: {e}")
628+
return BoolResponse(success=False, failure_reason=str(e))
629+
630+
631+
@app.get("/get_current_batch_rollout_pool_information", response_model=CurrentBatchRolloutPoolInformation)
632+
async def get_current_batch_rollout_pool_information():
633+
"""Get the current batch rollout pool information."""
634+
try:
635+
return shared_mem_dict.get('current_batch_rollout_pool_information', CurrentBatchRolloutPoolInformation())
636+
except Exception as e:
637+
logger.error(f"Error getting current batch rollout pool information: {e}")
638+
return CurrentBatchRolloutPoolInformation()
639+
640+
616641

617642

618643
# --------------------------------------------------------------------
@@ -751,4 +776,4 @@ def kill_process_tree(shared_mem_dict_lock=None, shared_mem_dict=None):
751776
except:
752777
pass
753778

754-
return {"success": False, "error": str(e)}
779+
return {"success": False, "error": str(e)}

ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import time
33
import httpx
4-
from typing import List
4+
from typing import List, Dict
55
from pydantic import BaseModel
66
from loguru import logger
77
from ajet.schema.task import WorkflowOutput
@@ -81,8 +81,19 @@ class UpdateEngineStatusRequest(BaseModel):
8181
engine_status: str = ""
8282

8383

84-
# DEBUG = False
85-
DEBUG = True
84+
class CurrentBatchRolloutPoolInformation(BaseModel):
85+
completed_episodes: int = 0
86+
completed_episode_target: int = 0
87+
completed_tasks: int = 0
88+
completed_task_target: int = 0
89+
completed_non_dummy_tasks: int = 0
90+
completed_non_dummy_task_target: int = 0
91+
task_expected_num_repeat: int = 0
92+
completed_tasks_details: Dict[str, List[str]] = {} # task_id -> list of episode_uuids
93+
94+
95+
DEBUG = False
96+
# DEBUG = True
8697

8798
def get_interchange_server_url(config):
8899
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
@@ -155,6 +166,26 @@ def http_register_episode(config,
155166
return True
156167

157168

169+
def http_update_rollout_pool_information(config, pool_info: CurrentBatchRolloutPoolInformation):
170+
"""
171+
Update the rollout pool information on the interchange server.
172+
173+
Args:
174+
config: The configuration object
175+
pool_info: CurrentBatchRolloutPoolInformation object with rollout statistics
176+
"""
177+
try:
178+
resp = httpx.post(
179+
f"{get_interchange_server_url(config)}/update_current_batch_rollout_pool_information",
180+
json=pool_info.model_dump(),
181+
timeout=5
182+
)
183+
resp.raise_for_status()
184+
except Exception as e:
185+
if DEBUG:
186+
logger.warning(f"Failed to update rollout pool information: {e}")
187+
188+
158189
def get_zmq_socket(config, episode_uuid: str, tag: str = ""):
159190
interchange_method = config.ajet.interchange_server.interchange_method
160191
if interchange_method == 'tcp':

0 commit comments

Comments
 (0)