Skip to content

Commit 2beb044

Browse files
committed
display reward in swarm overwatch
1 parent a0a020e commit 2beb044

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

ajet/task_rollout/native_parallel_worker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,17 @@ def stop_condition_callback(completed_task_id_map_ct):
318318
def update_rollout_result_array_preview(observation_window, completed_task_id_map_ct: Dict[str, List[SingleAgentContextTracker]]):
319319
buffer = ""
320320
completed_tasks_details = {}
321+
completed_tasks_rewards = {}
321322
for task_id, tracker_arr in completed_task_id_map_ct.items():
322323
buffer += f"Task {task_id} (completed {len(tracker_arr)} episodes):\n"
323324
episode_uuids = []
325+
rewards = []
324326
for ct in tracker_arr:
325327
buffer += f"\tEpisode: {ct.episode_uuid}\tTimelines: {len(ct.saved_timelines)}\tLLM_Calls: {ct.llm_call_cnt}\tReward: {ct.reward_structure.performance_reward}\n"
326328
episode_uuids.append(ct.episode_uuid)
329+
rewards.append(float(ct.reward_structure.performance_reward))
327330
completed_tasks_details[task_id] = episode_uuids
331+
completed_tasks_rewards[task_id] = rewards
328332
buffer += f"\n"
329333
buffer += f"\n"
330334
counts = count_tasks(completed_task_id_map_ct)
@@ -345,6 +349,7 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
345349
completed_non_dummy_task_target=n_batch_task,
346350
task_expected_num_repeat=rollout_n,
347351
completed_tasks_details=completed_tasks_details,
352+
completed_tasks_rewards=completed_tasks_rewards,
348353
)
349354
http_update_rollout_pool_information(self.config, pool_info)
350355
return

ajet/tuner_lib/experimental/swarm_overwatch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class CurrentBatchRolloutPoolInformation(BaseModel):
1212
completed_non_dummy_task_target: int = 0
1313
task_expected_num_repeat: int = 0
1414
completed_tasks_details: Dict[str, List[str]] = {} # task_id -> list of episode_uuids
15+
completed_tasks_rewards: Dict[str, List[float]] = {} # task_id -> list of rewards (one per episode)
1516
running_episode_details: Dict[str, Dict[str, str]] | None = None # episode_uuid -> { "episode_status": ..., "time_since_last_activity": ..., "discard_episode_timeout": ..., "llm_call_count": ...}
1617
engine_status: str | None = None
1718
global_step: int | None = None

ajet/utils/swarm_overwatch.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Optional
99

1010
import httpx
11+
import numpy as np
1112
from rich.console import Console
1213
from rich.live import Live
1314
from rich.table import Table
@@ -296,10 +297,11 @@ def create_task_details_table(
296297

297298
table.add_column("Task ID", style="cyan", no_wrap=True, overflow="ellipsis")
298299
table.add_column("Episodes", justify="right", style="green")
300+
table.add_column("Reward", justify="right", style="yellow")
299301
table.add_column("Episode UUIDs (first 3)", style="dim", overflow="fold")
300302

301303
if not info.completed_tasks_details:
302-
table.add_row("[dim]No task details available[/dim]", "", "")
304+
table.add_row("[dim]No task details available[/dim]", "", "", "")
303305
return table
304306

305307
# Sort tasks by number of completed episodes (descending)
@@ -315,15 +317,25 @@ def create_task_details_table(
315317
if len(episode_uuids) > 3:
316318
uuid_str += f" (+{len(episode_uuids) - 3} more)"
317319

320+
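# Summarize this task's per-episode rewards as "mean ± std" (population std, NumPy's default ddof=0); show "-" when none are available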
321+
reward_str = "-"
322+
if info.completed_tasks_rewards and task_id in info.completed_tasks_rewards:
323+
rewards = info.completed_tasks_rewards[task_id]
324+
if rewards:
325+
mean_reward = np.mean(rewards)
326+
std_reward = np.std(rewards)
327+
reward_str = f"{mean_reward:.3f} ± {std_reward:.3f}"
328+
318329
table.add_row(
319330
task_id[:40] if len(task_id) > 40 else task_id,
320331
f"{len(episode_uuids):,}",
332+
reward_str,
321333
uuid_str,
322334
)
323335

324336
if len(sorted_tasks) > 30:
325337
table.add_row(
326-
f"[dim]... and {len(sorted_tasks) - 30} more tasks[/dim]", "", ""
338+
f"[dim]... and {len(sorted_tasks) - 30} more tasks[/dim]", "", "", ""
327339
)
328340

329341
return table

0 commit comments

Comments
 (0)