Skip to content

Commit 9523d45

Browse files
committed
Add session resume flow and configurable steering dimension
1 parent 77b3692 commit 9523d45

20 files changed

Lines changed: 530 additions & 134 deletions

app/core/config_yaml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# updater: winner_average | winner_copy | linear_preference
1919
# feedback_mode: scalar_rating | pairwise | top_k | winner_only | approve_reject
2020
# seed_policy: fixed-per-round | fixed-per-candidate | fixed-per-candidate-role
21+
# steering_dimension: low-dimensional steering vector size, for example 3 or 5
2122
# image_size: WIDTHxHEIGHT, for example 512x512
2223
# guidance_scale: classifier-free guidance strength, for example 7.5
2324
# num_inference_steps: diffusion denoising steps, for example 15 or 30

app/core/jobs.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from concurrent.futures import ThreadPoolExecutor
44
from datetime import datetime
55
from enum import Enum
6+
from inspect import signature
67
from threading import Lock
78
from typing import Any, Callable
89
from uuid import uuid4
@@ -62,9 +63,16 @@ async def get(self, job_id: str) -> JobRecord | None:
6263
def _run_sync(self, job_id: str, fn: Callable[[], Any]) -> None:
6364
"""Execute one submitted job and persist its status transitions."""
6465

65-
self._update(job_id, state=JobState.running, progress=15, status_message="Running")
66+
self._update(job_id, state=JobState.running, progress=12, status_message="Starting work")
6667
try:
67-
result = fn()
68+
try:
69+
arity = len(signature(fn).parameters)
70+
except (TypeError, ValueError):
71+
arity = 0
72+
if arity >= 1:
73+
result = fn(lambda progress, message: self.update_progress(job_id, progress, message))
74+
else:
75+
result = fn()
6876
except Exception as exc:
6977
self._update(
7078
job_id,
@@ -80,11 +88,16 @@ def _run_sync(self, job_id: str, fn: Callable[[], Any]) -> None:
8088
job_id,
8189
state=JobState.succeeded,
8290
progress=100,
83-
status_message="Completed",
91+
status_message="Completed successfully",
8492
result=payload,
8593
error=None,
8694
)
8795

96+
def update_progress(self, job_id: str, progress: int, status_message: str) -> None:
97+
"""Expose safe phase-level progress updates to long-running jobs."""
98+
99+
self._update(job_id, state=JobState.running, progress=progress, status_message=status_message)
100+
88101
def _update(self, job_id: str, **changes: Any) -> None:
89102
"""Mutate one job record in place under the manager lock."""
90103

app/core/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class StrategyConfig(BaseModel):
8585
feedback_mode: FeedbackType = FeedbackType.scalar_rating
8686
seed_policy: SeedPolicy = SeedPolicy.fixed_per_round
8787
steering_mode: SteeringMode = SteeringMode.low_dimensional
88+
steering_dimension: int = Field(default=3, ge=1, le=16)
8889
candidate_count: int = Field(default=5, ge=1, le=12)
8990
image_size: str = "512x512"
9091
trust_radius: float = Field(default=0.35, gt=0.0, le=1.0)
@@ -189,7 +190,7 @@ class Session(BaseModel):
189190
status: SessionStatus = SessionStatus.created
190191
basis_type: str = "random_orthonormal"
191192
current_round: int = 0
192-
current_z: list[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
193+
current_z: list[float] = Field(default_factory=list)
193194
incumbent_candidate_id: str | None = None
194195
final_selected_candidate: str | None = None
195196
base_embedding_cache_key: str = Field(default_factory=lambda: new_id("emb"))

app/engine/orchestrator.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from copy import deepcopy
44
import hashlib
55
import math
6+
from typing import Callable
67

78
from app.core.config import settings
89
from app.core.schema import (
@@ -62,6 +63,13 @@ def __init__(
6263
"linear_preference": LinearPreferenceUpdater(),
6364
}
6465

66+
@staticmethod
67+
def _report_progress(progress_callback: Callable[[int, str], None] | None, progress: int, message: str) -> None:
68+
"""Emit a phase-level progress update when a callback is available."""
69+
70+
if progress_callback is not None:
71+
progress_callback(progress, message)
72+
6573
def create_experiment(self, request: ExperimentCreate) -> Experiment:
6674
"""Create and persist a reusable experiment definition."""
6775

@@ -102,6 +110,7 @@ def create_session(self, request: SessionCreate) -> Session:
102110
negative_prompt=request.negative_prompt,
103111
model_name=config.model_name,
104112
config=config,
113+
current_z=[0.0 for _ in range(config.steering_dimension)],
105114
status=SessionStatus.ready,
106115
)
107116
logger.info("Created session %s for experiment %s", session.id, session.experiment_id)
@@ -118,17 +127,28 @@ def get_session(self, session_id: str) -> Session | None:
118127

119128
return self.repository.get_session(session_id)
120129

130+
def list_sessions(self) -> list[Session]:
131+
"""Return all stored sessions ordered by recent activity."""
132+
133+
return self.repository.list_sessions()
134+
121135
def get_session_rounds(self, session_id: str) -> list[Round]:
122136
"""Return ordered rounds for a given session."""
123137

124138
return self.repository.list_rounds_for_session(session_id)
125139

126-
def generate_round(self, session_id: str) -> RoundResponse:
140+
def generate_round(
141+
self,
142+
session_id: str,
143+
progress_callback: Callable[[int, str], None] | None = None,
144+
) -> RoundResponse:
127145
"""Propose, render, persist, and return the next round of candidates."""
128146

147+
self._report_progress(progress_callback, 14, "Checking session readiness")
129148
session = self._assert_round_generation_allowed(session_id)
130149
sampler = self.samplers[session.config.sampler]
131150
round_index = session.current_round + 1
151+
self._report_progress(progress_callback, 24, "Preparing round state")
132152
round_obj = Round(
133153
session_id=session.id,
134154
round_index=round_index,
@@ -144,6 +164,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
144164
carried_forward = self._build_carried_forward_candidate(session)
145165
baseline_candidate = self._build_baseline_prompt_candidate(session)
146166
sampler_seed = self._seed_token(session.id, round_index, "sampler")
167+
self._report_progress(progress_callback, 36, f"Sampling {session.config.candidate_count} candidate directions")
147168
proposed_candidates = sampler.propose(session, sampler_seed)
148169
proposed_candidates = self._widen_first_round_candidates(session, proposed_candidates)
149170
candidates = self._compose_round_candidates(
@@ -152,6 +173,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
152173
candidate_count=session.config.candidate_count,
153174
)
154175
self._assign_candidate_seeds(session, round_index, candidates)
176+
self._report_progress(progress_callback, 52, "Rendering candidate images on the model backend")
155177
# Render each candidate independently so future versions can tolerate
156178
# partial round failures without changing the orchestration contract.
157179
for candidate in candidates:
@@ -164,6 +186,7 @@ def generate_round(self, session_id: str) -> RoundResponse:
164186
round_obj.candidates = candidates
165187
round_obj.render_status = RenderStatus.succeeded
166188
round_obj.latency_ms = 15 * len(candidates)
189+
self._report_progress(progress_callback, 76, "Saving rendered candidates and round state")
167190
session.current_round = round_index
168191
session.status = SessionStatus.awaiting_feedback
169192
session.updated_at = utc_now()
@@ -179,7 +202,9 @@ def generate_round(self, session_id: str) -> RoundResponse:
179202
"candidates": [self._candidate_trace_payload(candidate) for candidate in round_obj.candidates],
180203
},
181204
)
205+
self._report_progress(progress_callback, 90, "Refreshing trace report and replay data")
182206
self.generate_trace_report(session.id)
207+
self._report_progress(progress_callback, 98, "Round ready for review")
183208
return RoundResponse(
184209
round_id=round_obj.id,
185210
candidate_metadata=round_obj.candidates,
@@ -191,20 +216,29 @@ def generate_round(self, session_id: str) -> RoundResponse:
191216
},
192217
)
193218

194-
def submit_feedback(self, round_id: str, request: FeedbackRequest) -> FeedbackResponse:
219+
def submit_feedback(
220+
self,
221+
round_id: str,
222+
request: FeedbackRequest,
223+
progress_callback: Callable[[int, str], None] | None = None,
224+
) -> FeedbackResponse:
195225
"""Normalize feedback, update state, and persist the new incumbent."""
196226

227+
self._report_progress(progress_callback, 14, "Checking round readiness for feedback")
197228
round_obj, session = self._assert_feedback_submission_allowed(round_id, request)
229+
self._report_progress(progress_callback, 30, "Normalizing and validating user preferences")
198230
feedback = normalize_feedback(round_id, request)
199231
self._validate_feedback_against_round(round_obj, feedback)
200232
updater = self.updaters[session.config.updater]
233+
self._report_progress(progress_callback, 52, "Updating the steering model from your feedback")
201234
next_z, update_summary = updater.update(session, round_obj.candidates, feedback)
202235
round_obj.feedback_events.append(feedback)
203236
round_obj.update_summary = update_summary
204237
session.current_z = next_z
205238
session.incumbent_candidate_id = update_summary["winner_candidate_id"]
206239
session.status = SessionStatus.ready
207240
session.updated_at = utc_now()
241+
self._report_progress(progress_callback, 72, "Saving updated session state")
208242
self.repository.save_round(round_obj)
209243
self.repository.save_session(session)
210244
logger.info("Applied feedback to round %s for session %s", round_obj.id, session.id)
@@ -221,7 +255,9 @@ def submit_feedback(self, round_id: str, request: FeedbackRequest) -> FeedbackRe
221255
"next_incumbent_state": next_z,
222256
},
223257
)
258+
self._report_progress(progress_callback, 90, "Refreshing trace report with the new preference outcome")
224259
self.generate_trace_report(session.id)
260+
self._report_progress(progress_callback, 98, "Feedback applied and next round unlocked")
225261
return FeedbackResponse(update_summary=update_summary, next_incumbent_state=next_z)
226262

227263
def export_replay(self, session_id: str) -> dict:

app/frontend/static/styles.css

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ h1, h2, h3, p { margin-top: 0; }
4444
.card { background: var(--paper); border: 1px solid var(--line); border-radius: 20px; padding: 22px; margin-bottom: 20px; box-shadow: 0 10px 30px rgba(31,26,23,0.05); }
4545
.section-head { display: flex; align-items: baseline; justify-content: space-between; gap: 12px; margin-bottom: 16px; }
4646
.actions { display: flex; gap: 12px; flex-wrap: wrap; margin-top: 16px; }
47+
.inline-actions {
48+
margin-top: 0;
49+
gap: 8px;
50+
}
4751
.setup-steps {
4852
display: grid;
4953
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
@@ -218,6 +222,38 @@ textarea { min-height: 100px; resize: vertical; }
218222
color: var(--accent);
219223
margin-bottom: 10px;
220224
}
225+
.round-block {
226+
margin-top: 14px;
227+
border: 1px solid var(--line);
228+
border-radius: 18px;
229+
background: white;
230+
overflow: hidden;
231+
}
232+
.round-summary {
233+
display: flex;
234+
justify-content: space-between;
235+
gap: 12px;
236+
padding: 14px 16px;
237+
cursor: pointer;
238+
list-style: none;
239+
font-weight: 600;
240+
}
241+
.round-summary::-webkit-details-marker {
242+
display: none;
243+
}
244+
.round-summary-title {
245+
color: var(--ink);
246+
}
247+
.round-summary-meta {
248+
color: var(--muted);
249+
font-weight: 400;
250+
}
251+
.round-body {
252+
padding: 0 16px 16px;
253+
}
254+
.round-footer-actions {
255+
margin-top: 18px;
256+
}
221257
code { background: #f2ebdf; padding: 2px 6px; border-radius: 8px; }
222258
.trace-card { background: linear-gradient(180deg, #fffaf2, #f4efe7); }
223259
.trace-log {

app/frontend/templates/index.html

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,41 @@ <h1>Interactive prompt-embedding steering research prototype</h1>
1717
<a class="button secondary" href="/diagnostics/view">Runtime diagnostics</a>
1818
</div>
1919
</section>
20+
<section class="card">
21+
<div class="section-head">
22+
<h2>Resume sessions</h2>
23+
<span>{{ sessions|length }} recent</span>
24+
</div>
25+
{% if sessions %}
26+
<table class="table">
27+
<thead>
28+
<tr>
29+
<th>Prompt</th>
30+
<th>Status</th>
31+
<th>Round</th>
32+
<th>Continue</th>
33+
</tr>
34+
</thead>
35+
<tbody>
36+
{% for session in sessions %}
37+
<tr>
38+
<td>{{ session.prompt }}</td>
39+
<td>{{ session.status }}</td>
40+
<td>{{ session.current_round }}</td>
41+
<td>
42+
<div class="actions inline-actions">
43+
<a class="button secondary" href="/sessions/{{ session.id }}/view">Resume session</a>
44+
<a class="button secondary" href="/sessions/{{ session.id }}/replay-view">Replay</a>
45+
</div>
46+
</td>
47+
</tr>
48+
{% endfor %}
49+
</tbody>
50+
</table>
51+
{% else %}
52+
<p>No sessions yet. Create your first session to make it resumable from this dashboard.</p>
53+
{% endif %}
54+
</section>
2055
<section class="card">
2156
<div class="section-head">
2257
<h2>Experiments</h2>

0 commit comments

Comments
 (0)