33from copy import deepcopy
44import hashlib
55import math
6+ from typing import Callable
67
78from app .core .config import settings
89from app .core .schema import (
@@ -62,6 +63,13 @@ def __init__(
6263 "linear_preference" : LinearPreferenceUpdater (),
6364 }
6465
66+ @staticmethod
67+ def _report_progress (progress_callback : Callable [[int , str ], None ] | None , progress : int , message : str ) -> None :
68+ """Emit a phase-level progress update when a callback is available."""
69+
70+ if progress_callback is not None :
71+ progress_callback (progress , message )
72+
6573 def create_experiment (self , request : ExperimentCreate ) -> Experiment :
6674 """Create and persist a reusable experiment definition."""
6775
@@ -102,6 +110,7 @@ def create_session(self, request: SessionCreate) -> Session:
102110 negative_prompt = request .negative_prompt ,
103111 model_name = config .model_name ,
104112 config = config ,
113+ current_z = [0.0 for _ in range (config .steering_dimension )],
105114 status = SessionStatus .ready ,
106115 )
107116 logger .info ("Created session %s for experiment %s" , session .id , session .experiment_id )
@@ -118,17 +127,28 @@ def get_session(self, session_id: str) -> Session | None:
118127
119128 return self .repository .get_session (session_id )
120129
130+ def list_sessions (self ) -> list [Session ]:
131+ """Return all stored sessions ordered by recent activity."""
132+
133+ return self .repository .list_sessions ()
134+
121135 def get_session_rounds (self , session_id : str ) -> list [Round ]:
122136 """Return ordered rounds for a given session."""
123137
124138 return self .repository .list_rounds_for_session (session_id )
125139
126- def generate_round (self , session_id : str ) -> RoundResponse :
140+ def generate_round (
141+ self ,
142+ session_id : str ,
143+ progress_callback : Callable [[int , str ], None ] | None = None ,
144+ ) -> RoundResponse :
127145 """Propose, render, persist, and return the next round of candidates."""
128146
147+ self ._report_progress (progress_callback , 14 , "Checking session readiness" )
129148 session = self ._assert_round_generation_allowed (session_id )
130149 sampler = self .samplers [session .config .sampler ]
131150 round_index = session .current_round + 1
151+ self ._report_progress (progress_callback , 24 , "Preparing round state" )
132152 round_obj = Round (
133153 session_id = session .id ,
134154 round_index = round_index ,
@@ -144,6 +164,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
144164 carried_forward = self ._build_carried_forward_candidate (session )
145165 baseline_candidate = self ._build_baseline_prompt_candidate (session )
146166 sampler_seed = self ._seed_token (session .id , round_index , "sampler" )
167+ self ._report_progress (progress_callback , 36 , f"Sampling { session .config .candidate_count } candidate directions" )
147168 proposed_candidates = sampler .propose (session , sampler_seed )
148169 proposed_candidates = self ._widen_first_round_candidates (session , proposed_candidates )
149170 candidates = self ._compose_round_candidates (
@@ -152,6 +173,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
152173 candidate_count = session .config .candidate_count ,
153174 )
154175 self ._assign_candidate_seeds (session , round_index , candidates )
176+ self ._report_progress (progress_callback , 52 , "Rendering candidate images on the model backend" )
155177 # Render each candidate independently so future versions can tolerate
156178 # partial round failures without changing the orchestration contract.
157179 for candidate in candidates :
@@ -164,6 +186,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
164186 round_obj .candidates = candidates
165187 round_obj .render_status = RenderStatus .succeeded
166188 round_obj .latency_ms = 15 * len (candidates )
189+ self ._report_progress (progress_callback , 76 , "Saving rendered candidates and round state" )
167190 session .current_round = round_index
168191 session .status = SessionStatus .awaiting_feedback
169192 session .updated_at = utc_now ()
@@ -179,7 +202,9 @@ def generate_round(self, session_id: str) -> RoundResponse:
179202 "candidates" : [self ._candidate_trace_payload (candidate ) for candidate in round_obj .candidates ],
180203 },
181204 )
205+ self ._report_progress (progress_callback , 90 , "Refreshing trace report and replay data" )
182206 self .generate_trace_report (session .id )
207+ self ._report_progress (progress_callback , 98 , "Round ready for review" )
183208 return RoundResponse (
184209 round_id = round_obj .id ,
185210 candidate_metadata = round_obj .candidates ,
@@ -191,20 +216,29 @@ def generate_round(self, session_id: str) -> RoundResponse:
191216 },
192217 )
193218
194- def submit_feedback (self , round_id : str , request : FeedbackRequest ) -> FeedbackResponse :
219+ def submit_feedback (
220+ self ,
221+ round_id : str ,
222+ request : FeedbackRequest ,
223+ progress_callback : Callable [[int , str ], None ] | None = None ,
224+ ) -> FeedbackResponse :
195225 """Normalize feedback, update state, and persist the new incumbent."""
196226
227+ self ._report_progress (progress_callback , 14 , "Checking round readiness for feedback" )
197228 round_obj , session = self ._assert_feedback_submission_allowed (round_id , request )
229+ self ._report_progress (progress_callback , 30 , "Normalizing and validating user preferences" )
198230 feedback = normalize_feedback (round_id , request )
199231 self ._validate_feedback_against_round (round_obj , feedback )
200232 updater = self .updaters [session .config .updater ]
233+ self ._report_progress (progress_callback , 52 , "Updating the steering model from your feedback" )
201234 next_z , update_summary = updater .update (session , round_obj .candidates , feedback )
202235 round_obj .feedback_events .append (feedback )
203236 round_obj .update_summary = update_summary
204237 session .current_z = next_z
205238 session .incumbent_candidate_id = update_summary ["winner_candidate_id" ]
206239 session .status = SessionStatus .ready
207240 session .updated_at = utc_now ()
241+ self ._report_progress (progress_callback , 72 , "Saving updated session state" )
208242 self .repository .save_round (round_obj )
209243 self .repository .save_session (session )
210244 logger .info ("Applied feedback to round %s for session %s" , round_obj .id , session .id )
@@ -221,7 +255,9 @@ def submit_feedback(self, round_id: str, request: FeedbackRequest) -> FeedbackRe
221255 "next_incumbent_state" : next_z ,
222256 },
223257 )
258+ self ._report_progress (progress_callback , 90 , "Refreshing trace report with the new preference outcome" )
224259 self .generate_trace_report (session .id )
260+ self ._report_progress (progress_callback , 98 , "Feedback applied and next round unlocked" )
225261 return FeedbackResponse (update_summary = update_summary , next_incumbent_state = next_z )
226262
227263 def export_replay (self , session_id : str ) -> dict :
0 commit comments