Skip to content

Commit 4658ce0

Browse files
committed
Add reward history tracking and visualization for swarm overwatch
- Add RewardHistoryEntry and RewardHistoryResponse models for reward data
- Implement reward collection and history finalization in swarm_server
- Add /get_reward_history API endpoint for visualization
- Add ASCII reward curve display in swarm_overwatch UI
- Fix typo in config_utils parameter name (convertion_json_fg -> convertion_json_fp)
1 parent 8ec210a commit 4658ce0

File tree

4 files changed

+255
-6
lines changed

4 files changed

+255
-6
lines changed

ajet/tuner_lib/experimental/swarm_overwatch_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22
from pydantic import BaseModel
33

44

5+
class RewardHistoryEntry(BaseModel):
6+
"""A single entry in the reward history."""
7+
global_step: int
8+
mean_reward: float
9+
std_reward: float
10+
timestamp: float # Unix timestamp when this entry was recorded
11+
12+
13+
class RewardHistoryResponse(BaseModel):
14+
"""Response containing the reward history for visualization."""
15+
history: List[RewardHistoryEntry] = []
16+
17+
518
class CurrentBatchRolloutPoolInformation(BaseModel):
619
sample_collection_method: str = ""
720
completed_episodes: int = 0

ajet/tuner_lib/experimental/swarm_server.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from multiprocessing.managers import DictProxy
1212
from typing import Coroutine, Optional, Tuple, List
1313
from ajet.utils.process_killer import kill_process_tree
14-
from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
14+
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
15+
CurrentBatchRolloutPoolInformation,
16+
RewardHistoryEntry,
17+
RewardHistoryResponse,
18+
)
1519
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE
1620
from ajet.tuner_lib.experimental.interchange_utils import (
1721
SyncTrainConfigRequest,
@@ -63,6 +67,14 @@ def register_enable_swarm_mode_routes(
6367
if "current_batch_rollout_pool_information" not in shared_mem_dict:
6468
shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation()
6569

70+
# Initialize reward history storage for visualization
71+
if "reward_history" not in shared_mem_dict:
72+
shared_mem_dict["reward_history"] = [] # List of RewardHistoryEntry dicts
73+
74+
# Initialize reward accumulator for collecting rewards of current global step
75+
if "current_rewards" not in shared_mem_dict:
76+
shared_mem_dict["current_rewards"] = [] # [rewards...]
77+
6678
# ------------------------------------------------------------------------------------------------
6779
# ------ Recycle claimed episodes that client failed to complete in (promised) time --------------
6880
# --------------------------------- claimed -> unclaimed ----------------------------------------
@@ -166,6 +178,35 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l
166178
if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
167179
shared_mem_dict["unclaimed_episodes"].remove(episode_uuid)
168180

181+
# --------------------------------------------------------------------------------------
182+
# -------------------------- reward history management ---------------------------------
183+
# --------------------------------------------------------------------------------------
184+
185+
def _finalize_reward_history_for_step(global_step, shared_mem_dict, shared_mem_dict_lock):
186+
"""Finalize reward statistics for a given global step and add to reward_history."""
187+
import numpy as np
188+
189+
rewards = shared_mem_dict.get("current_rewards", [])
190+
if rewards:
191+
rewards = list(rewards) # Convert proxy to list if needed
192+
mean_reward = float(np.mean(rewards))
193+
std_reward = float(np.std(rewards))
194+
195+
history = shared_mem_dict.get("reward_history", [])
196+
history = list(history) # Convert proxy to list if needed
197+
198+
entry = RewardHistoryEntry(
199+
global_step=global_step,
200+
mean_reward=mean_reward,
201+
std_reward=std_reward,
202+
timestamp=time.time(),
203+
)
204+
history.append(entry.model_dump())
205+
shared_mem_dict["reward_history"] = history
206+
207+
# Clear current rewards for next step
208+
shared_mem_dict["current_rewards"] = []
209+
169210
# --------------------------------------------------------------------------------------
170211
# -------------------------- return workflow output ------------------------------------
171212
# --------------------------------------------------------------------------------------
@@ -272,6 +313,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict):
272313
shared_mem_dict["unclaimed_episodes"] = []
273314
logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes")
274315

316+
# clear reward tracking
317+
shared_mem_dict["current_rewards"] = []
318+
shared_mem_dict["reward_history"] = []
319+
275320
# --------------------------------------------------------------------------------------
276321
# -------------------------- fastapi routes --------------------------------------------
277322
# --------------------------------------------------------------------------------------
@@ -446,7 +491,12 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
446491
engine_status_detail = req.engine_status_detail
447492
global_step = req.global_step
448493
if global_step is not None:
494+
previous_global_step = shared_mem_dict.get("global_step", None)
449495
shared_mem_dict["global_step"] = global_step
496+
# When global_step changes, finalize reward statistics for the previous step
497+
if previous_global_step is not None and previous_global_step != global_step:
498+
_finalize_reward_history_for_step(previous_global_step, shared_mem_dict, shared_mem_dict_lock)
499+
450500
if engine_status_detail is not None:
451501
shared_mem_dict["engine_status_detail"] = engine_status_detail
452502
logger.info(f"[update_engine_status] Engine status set to {req.engine_status}")
@@ -636,6 +686,21 @@ async def end_episode(req: EndEpisodeRequest):
636686
shared_mem_dict_lock,
637687
)
638688

689+
# Record reward to current_rewards
690+
if workflow_output.reward is not None:
691+
reward_value = workflow_output.reward
692+
# Handle both single reward and list of rewards
693+
if isinstance(reward_value, list):
694+
rewards_to_record = reward_value
695+
else:
696+
rewards_to_record = [reward_value]
697+
698+
with shared_mem_dict_lock:
699+
current_rewards = shared_mem_dict.get("current_rewards", [])
700+
current_rewards = list(current_rewards) # Convert proxy to list if needed
701+
current_rewards.extend(rewards_to_record)
702+
shared_mem_dict["current_rewards"] = current_rewards
703+
639704
elif episode_type == "eval":
640705
if engine_status in ["ENGINE.ROLLING"]:
641706
await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
@@ -779,6 +844,20 @@ async def get_current_batch_rollout_pool_information():
779844
logger.error(f"Error getting current batch rollout pool information: {e}")
780845
return CurrentBatchRolloutPoolInformation()
781846

847+
# --------------------------------------------------------------------
848+
# ------------ get reward history for visualization ------------------
849+
# --------------------------------------------------------------------
850+
@app.get("/get_reward_history", response_model=RewardHistoryResponse)
851+
async def get_reward_history():
852+
"""Get the reward history for visualization (reward curves)."""
853+
try:
854+
history = shared_mem_dict.get("reward_history", [])
855+
entries = [RewardHistoryEntry(**entry) for entry in history]
856+
return RewardHistoryResponse(history=entries)
857+
except Exception as e:
858+
logger.error(f"Error getting reward history: {e}")
859+
return RewardHistoryResponse(history=[])
860+
782861
# --------------------------------------------------------------------
783862
# ------------ bring engine back to ENGINE.OFFLINE -------------------
784863
# --------------------------------------------------------------------

ajet/utils/config_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _dive_to_set_value(config, dotted_key, value):
9898
sub_config[keys[-1]] = value
9999

100100

101-
def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone):
101+
def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone):
102102
"""Align configuration values based on a conversion map.
103103
104104
Parameters
@@ -107,7 +107,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
107107
Source YAML path to read values from.
108108
to_config_fp : str
109109
Destination YAML path that is updated in place.
110-
convertion_json_fg : str
110+
convertion_json_fp : str
111111
JSON path mapping dotted keys between configs.
112112
backbone : str
113113
Backbone identifier used for framework-specific alignment.
@@ -121,7 +121,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
121121
# read convertion json
122122
import json
123123

124-
with open(convertion_json_fg, "r", encoding="utf-8") as file:
124+
with open(convertion_json_fp, "r", encoding="utf-8") as file:
125125
convertion_json = json.load(file)
126126

127127
logger.success("----------------------------------------------------")

ajet/utils/swarm_overwatch.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from rich.text import Text
1818
from loguru import logger
1919

20-
from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
20+
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
21+
CurrentBatchRolloutPoolInformation,
22+
RewardHistoryResponse,
23+
)
2124

2225

2326
class SwarmOverwatch:
@@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
5659
# logger.error(f"Failed to fetch pool info: {e}")
5760
return None
5861

62+
def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
63+
"""Fetch reward history from server for visualization"""
64+
try:
65+
response = self._httpx_client.get(
66+
f"{self.server_url}/get_reward_history",
67+
timeout=5.0,
68+
)
69+
response.raise_for_status()
70+
data = RewardHistoryResponse.model_validate(response.json())
71+
return data
72+
except Exception as e:
73+
logger.error(f"Failed to fetch reward history: {e}")
74+
return None
75+
5976
def create_header(
6077
self, info: Optional[CurrentBatchRolloutPoolInformation] = None
6178
) -> Panel:
@@ -450,6 +467,141 @@ def create_dashboard(
450467

451468
return layout
452469

470+
def display_reward_curve(self):
471+
"""Display ASCII reward curve in terminal"""
472+
self.console.clear()
473+
474+
# Fetch reward history
475+
history = self.fetch_reward_history()
476+
if history is None or not history.history:
477+
self.console.print("[bold yellow]No reward history available yet.[/bold yellow]")
478+
self.console.print("[dim]Reward history is recorded when training completes batches with rewards.[/dim]")
479+
self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
480+
input()
481+
return
482+
483+
# Get terminal size
484+
terminal_width = self.console.width or 80
485+
terminal_height = self.console.height or 24
486+
487+
# Reserve space for header, labels, and footer
488+
chart_width = min(terminal_width - 15, 120) # Reserve space for y-axis labels
489+
chart_height = min(terminal_height - 10, 30) # Reserve space for header and x-axis
490+
491+
# Extract data
492+
global_steps = [entry.global_step for entry in history.history]
493+
mean_rewards = [entry.mean_reward for entry in history.history]
494+
495+
# Calculate y-axis range with padding
496+
y_min = min(mean_rewards)
497+
y_max = max(mean_rewards)
498+
y_range = y_max - y_min
499+
if y_range == 0:
500+
y_range = 1.0 # Avoid division by zero
501+
y_min -= 0.5
502+
y_max += 0.5
503+
else:
504+
# Add 10% padding
505+
y_min -= y_range * 0.1
506+
y_max += y_range * 0.1
507+
y_range = y_max - y_min
508+
509+
# Calculate x-axis range
510+
x_min = min(global_steps)
511+
x_max = max(global_steps)
512+
x_range = x_max - x_min
513+
if x_range == 0:
514+
x_range = 1
515+
516+
# Create the chart grid
517+
chart = [[' ' for _ in range(chart_width)] for _ in range(chart_height)]
518+
519+
# Plot the data points
520+
for i, (step, reward) in enumerate(zip(global_steps, mean_rewards)):
521+
# Map to chart coordinates
522+
x = int((step - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
523+
y = int((reward - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
524+
525+
# Invert y because terminal coordinates go top-down
526+
y = chart_height - 1 - y
527+
528+
# Clamp to valid range
529+
x = max(0, min(chart_width - 1, x))
530+
y = max(0, min(chart_height - 1, y))
531+
532+
# Draw point
533+
chart[y][x] = '*'
534+
535+
# Connect points with lines if there are multiple points
536+
if len(global_steps) > 1:
537+
for i in range(len(global_steps) - 1):
538+
step1, reward1 = global_steps[i], mean_rewards[i]
539+
step2, reward2 = global_steps[i + 1], mean_rewards[i + 1]
540+
541+
x1 = int((step1 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
542+
y1 = int((reward1 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
543+
x2 = int((step2 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
544+
y2 = int((reward2 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
545+
546+
y1 = chart_height - 1 - y1
547+
y2 = chart_height - 1 - y2
548+
549+
# Simple line drawing between points
550+
steps_between = max(abs(x2 - x1), abs(y2 - y1))
551+
if steps_between > 0:
552+
for s in range(1, steps_between):
553+
t = s / steps_between
554+
x = int(x1 + t * (x2 - x1))
555+
y = int(y1 + t * (y2 - y1))
556+
x = max(0, min(chart_width - 1, x))
557+
y = max(0, min(chart_height - 1, y))
558+
if chart[y][x] == ' ':
559+
chart[y][x] = '.'
560+
561+
# Build the output
562+
output = Text()
563+
output.append("\n Reward Curve (Mean Reward vs Global Step)\n", style="bold cyan")
564+
output.append(f" Server: {self.server_url}\n", style="dim")
565+
output.append(f" Data points: {len(global_steps)}\n\n", style="dim")
566+
567+
# Draw y-axis labels and chart
568+
y_labels = []
569+
for i in range(chart_height):
570+
y_val = y_max - (i / (chart_height - 1)) * y_range if chart_height > 1 else y_max
571+
y_labels.append(y_val)
572+
573+
for i, row in enumerate(chart):
574+
# Y-axis label (only show a few)
575+
if i == 0 or i == chart_height - 1 or i == chart_height // 2:
576+
label = f"{y_labels[i]:8.3f} |"
577+
else:
578+
label = " |"
579+
output.append(label, style="dim")
580+
output.append(''.join(row), style="green")
581+
output.append("\n")
582+
583+
# X-axis
584+
output.append(" +" + "-" * chart_width + "\n", style="dim")
585+
586+
# X-axis labels
587+
x_label_line = " "
588+
x_label_line += f"{x_min:<{chart_width // 3}}"
589+
mid_step = x_min + x_range // 2
590+
x_label_line += f"{mid_step:^{chart_width // 3}}"
591+
x_label_line += f"{x_max:>{chart_width // 3}}"
592+
output.append(x_label_line[:chart_width + 10] + "\n", style="dim")
593+
output.append(" " + " " * (chart_width // 2 - 5) + "Global Step\n", style="dim cyan")
594+
595+
# Statistics
596+
output.append("\n Statistics:\n", style="bold yellow")
597+
output.append(f" Latest Global Step: {global_steps[-1]}\n", style="green")
598+
output.append(f" Latest Mean Reward: {mean_rewards[-1]:.4f}\n", style="green")
599+
output.append(f" Min Mean Reward: {min(mean_rewards):.4f} (step {global_steps[mean_rewards.index(min(mean_rewards))]})\n", style="cyan")
600+
output.append(f" Max Mean Reward: {max(mean_rewards):.4f} (step {global_steps[mean_rewards.index(max(mean_rewards))]})\n", style="cyan")
601+
602+
self.console.print(output)
603+
self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
604+
input()
453605

454606
def display_latest_llm_call(self):
455607
while True:
@@ -515,6 +667,7 @@ def choose_run(self) -> str:
515667
self.console.print("\n[bold]Choose action:[/bold]")
516668
self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch")
517669
self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call")
670+
self.console.print(" [bold cyan]c[/bold cyan] - Show reward curve")
518671
self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit")
519672
choice = input("\n> ").strip().lower()
520673

@@ -526,8 +679,12 @@ def choose_run(self) -> str:
526679
mode = "replay_latest_llm_call"
527680
self.console.clear()
528681
continue
682+
elif choice == "c":
683+
self.display_reward_curve()
684+
self.console.clear()
685+
continue
529686
else:
530-
self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]")
687+
self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]")
531688

532689
def run(self):
533690
"""Start the monitoring interface"""

0 commit comments

Comments (0)