1717from rich .text import Text
1818from loguru import logger
1919
20- from ajet .tuner_lib .experimental .swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
20+ from ajet .tuner_lib .experimental .swarm_overwatch_utils import (
21+ CurrentBatchRolloutPoolInformation ,
22+ RewardHistoryResponse ,
23+ )
2124
2225
2326class SwarmOverwatch :
@@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
5659 # logger.error(f"Failed to fetch pool info: {e}")
5760 return None
5861
62+ def fetch_reward_history (self ) -> Optional [RewardHistoryResponse ]:
63+ """Fetch reward history from server for visualization"""
64+ try :
65+ response = self ._httpx_client .get (
66+ f"{ self .server_url } /get_reward_history" ,
67+ timeout = 5.0 ,
68+ )
69+ response .raise_for_status ()
70+ data = RewardHistoryResponse .model_validate (response .json ())
71+ return data
72+ except Exception as e :
73+ logger .error (f"Failed to fetch reward history: { e } " )
74+ return None
75+
5976 def create_header (
6077 self , info : Optional [CurrentBatchRolloutPoolInformation ] = None
6178 ) -> Panel :
@@ -450,6 +467,141 @@ def create_dashboard(
450467
451468 return layout
452469
470+ def display_reward_curve (self ):
471+ """Display ASCII reward curve in terminal"""
472+ self .console .clear ()
473+
474+ # Fetch reward history
475+ history = self .fetch_reward_history ()
476+ if history is None or not history .history :
477+ self .console .print ("[bold yellow]No reward history available yet.[/bold yellow]" )
478+ self .console .print ("[dim]Reward history is recorded when training completes batches with rewards.[/dim]" )
479+ self .console .print ("\n [dim]Press Enter to return to menu...[/dim]" )
480+ input ()
481+ return
482+
483+ # Get terminal size
484+ terminal_width = self .console .width or 80
485+ terminal_height = self .console .height or 24
486+
487+ # Reserve space for header, labels, and footer
488+ chart_width = min (terminal_width - 15 , 120 ) # Reserve space for y-axis labels
489+ chart_height = min (terminal_height - 10 , 30 ) # Reserve space for header and x-axis
490+
491+ # Extract data
492+ global_steps = [entry .global_step for entry in history .history ]
493+ mean_rewards = [entry .mean_reward for entry in history .history ]
494+
495+ # Calculate y-axis range with padding
496+ y_min = min (mean_rewards )
497+ y_max = max (mean_rewards )
498+ y_range = y_max - y_min
499+ if y_range == 0 :
500+ y_range = 1.0 # Avoid division by zero
501+ y_min -= 0.5
502+ y_max += 0.5
503+ else :
504+ # Add 10% padding
505+ y_min -= y_range * 0.1
506+ y_max += y_range * 0.1
507+ y_range = y_max - y_min
508+
509+ # Calculate x-axis range
510+ x_min = min (global_steps )
511+ x_max = max (global_steps )
512+ x_range = x_max - x_min
513+ if x_range == 0 :
514+ x_range = 1
515+
516+ # Create the chart grid
517+ chart = [[' ' for _ in range (chart_width )] for _ in range (chart_height )]
518+
519+ # Plot the data points
520+ for i , (step , reward ) in enumerate (zip (global_steps , mean_rewards )):
521+ # Map to chart coordinates
522+ x = int ((step - x_min ) / x_range * (chart_width - 1 )) if x_range > 0 else 0
523+ y = int ((reward - y_min ) / y_range * (chart_height - 1 )) if y_range > 0 else 0
524+
525+ # Invert y because terminal coordinates go top-down
526+ y = chart_height - 1 - y
527+
528+ # Clamp to valid range
529+ x = max (0 , min (chart_width - 1 , x ))
530+ y = max (0 , min (chart_height - 1 , y ))
531+
532+ # Draw point
533+ chart [y ][x ] = '*'
534+
535+ # Connect points with lines if there are multiple points
536+ if len (global_steps ) > 1 :
537+ for i in range (len (global_steps ) - 1 ):
538+ step1 , reward1 = global_steps [i ], mean_rewards [i ]
539+ step2 , reward2 = global_steps [i + 1 ], mean_rewards [i + 1 ]
540+
541+ x1 = int ((step1 - x_min ) / x_range * (chart_width - 1 )) if x_range > 0 else 0
542+ y1 = int ((reward1 - y_min ) / y_range * (chart_height - 1 )) if y_range > 0 else 0
543+ x2 = int ((step2 - x_min ) / x_range * (chart_width - 1 )) if x_range > 0 else 0
544+ y2 = int ((reward2 - y_min ) / y_range * (chart_height - 1 )) if y_range > 0 else 0
545+
546+ y1 = chart_height - 1 - y1
547+ y2 = chart_height - 1 - y2
548+
549+ # Simple line drawing between points
550+ steps_between = max (abs (x2 - x1 ), abs (y2 - y1 ))
551+ if steps_between > 0 :
552+ for s in range (1 , steps_between ):
553+ t = s / steps_between
554+ x = int (x1 + t * (x2 - x1 ))
555+ y = int (y1 + t * (y2 - y1 ))
556+ x = max (0 , min (chart_width - 1 , x ))
557+ y = max (0 , min (chart_height - 1 , y ))
558+ if chart [y ][x ] == ' ' :
559+ chart [y ][x ] = '.'
560+
561+ # Build the output
562+ output = Text ()
563+ output .append ("\n Reward Curve (Mean Reward vs Global Step)\n " , style = "bold cyan" )
564+ output .append (f" Server: { self .server_url } \n " , style = "dim" )
565+ output .append (f" Data points: { len (global_steps )} \n \n " , style = "dim" )
566+
567+ # Draw y-axis labels and chart
568+ y_labels = []
569+ for i in range (chart_height ):
570+ y_val = y_max - (i / (chart_height - 1 )) * y_range if chart_height > 1 else y_max
571+ y_labels .append (y_val )
572+
573+ for i , row in enumerate (chart ):
574+ # Y-axis label (only show a few)
575+ if i == 0 or i == chart_height - 1 or i == chart_height // 2 :
576+ label = f"{ y_labels [i ]:8.3f} |"
577+ else :
578+ label = " |"
579+ output .append (label , style = "dim" )
580+ output .append ('' .join (row ), style = "green" )
581+ output .append ("\n " )
582+
583+ # X-axis
584+ output .append (" +" + "-" * chart_width + "\n " , style = "dim" )
585+
586+ # X-axis labels
587+ x_label_line = " "
588+ x_label_line += f"{ x_min :<{chart_width // 3 }} "
589+ mid_step = x_min + x_range // 2
590+ x_label_line += f"{ mid_step :^{chart_width // 3 }} "
591+ x_label_line += f"{ x_max :>{chart_width // 3 }} "
592+ output .append (x_label_line [:chart_width + 10 ] + "\n " , style = "dim" )
593+ output .append (" " + " " * (chart_width // 2 - 5 ) + "Global Step\n " , style = "dim cyan" )
594+
595+ # Statistics
596+ output .append ("\n Statistics:\n " , style = "bold yellow" )
597+ output .append (f" Latest Global Step: { global_steps [- 1 ]} \n " , style = "green" )
598+ output .append (f" Latest Mean Reward: { mean_rewards [- 1 ]:.4f} \n " , style = "green" )
599+ output .append (f" Min Mean Reward: { min (mean_rewards ):.4f} (step { global_steps [mean_rewards .index (min (mean_rewards ))]} )\n " , style = "cyan" )
600+ output .append (f" Max Mean Reward: { max (mean_rewards ):.4f} (step { global_steps [mean_rewards .index (max (mean_rewards ))]} )\n " , style = "cyan" )
601+
602+ self .console .print (output )
603+ self .console .print ("\n [dim]Press Enter to return to menu...[/dim]" )
604+ input ()
453605
454606 def display_latest_llm_call (self ):
455607 while True :
@@ -515,6 +667,7 @@ def choose_run(self) -> str:
515667 self .console .print ("\n [bold]Choose action:[/bold]" )
516668 self .console .print (" [bold cyan]o[/bold cyan] - Return to overwatch" )
517669 self .console .print (" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call" )
670+ self .console .print (" [bold cyan]c[/bold cyan] - Show reward curve" )
518671 self .console .print (" [bold cyan]ctrl+c[/bold cyan] - Exit" )
519672 choice = input ("\n > " ).strip ().lower ()
520673
@@ -526,8 +679,12 @@ def choose_run(self) -> str:
526679 mode = "replay_latest_llm_call"
527680 self .console .clear ()
528681 continue
682+ elif choice == "c" :
683+ self .display_reward_curve ()
684+ self .console .clear ()
685+ continue
529686 else :
530- self .console .print ("[yellow]Invalid choice. Please enter 'o' or 't '.[/yellow]" )
687+ self .console .print ("[yellow]Invalid choice. Please enter 'o', 't', or 'c '.[/yellow]" )
531688
532689 def run (self ):
533690 """Start the monitoring interface"""
0 commit comments