Skip to content

Commit 91bc653

Browse files
author
jack
committed
feat: Add real-time model display updates for lead-worker mode
## Summary Implements real-time model display updates in the GUI when using lead-worker mode, showing which model (lead or worker) is currently active. ## Changes ### Backend - Add ModelChange event type to AgentEvent enum for tracking model switches - Emit ModelChange events when the lead-worker provider switches between models - Add /config/current-model API endpoint to fetch the currently active model - Update global current model store when providers complete requests - Fix runtime panic by avoiding block_on in async context (use global store instead) ### Frontend - Create LeadWorkerSettings UI component for configuring lead-worker mode - Add fields for lead model, worker model, lead turns, failure threshold, and fallback turns - Update useMessageStream hook to capture ModelChange events from SSE stream - Use React Context to share current model info instead of window.appConfig (which is read-only) - Update ModelsBottomBar to display active model with mode indicator (lead/worker) - Prioritize active model display when lead-worker mode is enabled ### Event Handling - Handle ModelChange events in all AgentEvent match statements: - CLI session handler - Web command handler - FFI library - Scheduler - Scheduler executor - Agent tests ## Testing - Fixed agent tests to handle the new ModelChange event - All tests pass - Clippy checks pass - Frontend linting and type checking pass
1 parent 3bec469 commit 91bc653

16 files changed

Lines changed: 475 additions & 9 deletions

File tree

crates/goose-cli/src/commands/web.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,10 @@ async fn process_message_streaming(
589589
// For now, we'll just log them
590590
tracing::info!("Received MCP notification in web interface");
591591
}
592+
Ok(AgentEvent::ModelChange { model, mode }) => {
593+
// Log model change
594+
tracing::info!("Model changed to {} in {} mode", model, mode);
595+
}
592596
Err(e) => {
593597
error!("Error in message stream: {}", e);
594598
let mut sender = sender.lock().await;

crates/goose-cli/src/session/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,12 @@ impl Session {
928928
}
929929
}
930930
}
931+
Some(Ok(AgentEvent::ModelChange { model, mode })) => {
932+
// Log model change if in debug mode
933+
if self.debug {
934+
eprintln!("Model changed to {} in {} mode", model, mode);
935+
}
936+
}
931937
Some(Err(e)) => {
932938
eprintln!("Error: {}", e);
933939
drop(stream);

crates/goose-ffi/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ pub unsafe extern "C" fn goose_agent_send_message(
266266
Ok(AgentEvent::McpNotification(_)) => {
267267
// TODO: Handle MCP notifications.
268268
}
269+
Ok(AgentEvent::ModelChange { .. }) => {
270+
// Model change events are informational, just continue
271+
}
269272
Err(e) => {
270273
full_response.push_str(&format!("\nError in message stream: {}", e));
271274
}

crates/goose-scheduler-executor/src/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ async fn execute_recipe(job_id: &str, recipe_path: &str) -> Result<String> {
165165
Ok(AgentEvent::McpNotification(_)) => {
166166
// Handle notifications if needed
167167
}
168+
Ok(AgentEvent::ModelChange { .. }) => {
169+
// Model change events are informational, just continue
170+
}
168171
Err(e) => {
169172
return Err(anyhow!("Error receiving message from agent: {}", e));
170173
}

crates/goose-server/src/routes/config_management.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,26 @@ pub async fn backup_config(
441441
}
442442
}
443443

444+
#[utoipa::path(
445+
get,
446+
path = "/config/current-model",
447+
responses(
448+
(status = 200, description = "Current model retrieved successfully", body = String),
449+
)
450+
)]
451+
pub async fn get_current_model(
452+
State(state): State<Arc<AppState>>,
453+
headers: HeaderMap,
454+
) -> Result<Json<Value>, StatusCode> {
455+
verify_secret_key(&headers, &state)?;
456+
457+
let current_model = goose::providers::base::get_current_model();
458+
459+
Ok(Json(serde_json::json!({
460+
"model": current_model
461+
})))
462+
}
463+
444464
pub fn routes(state: Arc<AppState>) -> Router {
445465
Router::new()
446466
.route("/config", get(read_all_config))
@@ -454,6 +474,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
454474
.route("/config/init", post(init_config))
455475
.route("/config/backup", post(backup_config))
456476
.route("/config/permissions", post(upsert_permissions))
477+
.route("/config/current-model", get(get_current_model))
457478
.with_state(state)
458479
}
459480

crates/goose-server/src/routes/reply.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ enum MessageEvent {
8888
Finish {
8989
reason: String,
9090
},
91+
ModelChange {
92+
model: String,
93+
mode: String,
94+
},
9195
Notification {
9296
request_id: String,
9397
message: JsonRpcMessage,
@@ -233,6 +237,17 @@ async fn handler(
233237
}
234238
});
235239
}
240+
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
241+
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
242+
tracing::error!("Error sending model change through channel: {}", e);
243+
let _ = stream_event(
244+
MessageEvent::Error {
245+
error: e.to_string(),
246+
},
247+
&tx,
248+
).await;
249+
}
250+
}
236251
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
237252
if let Err(e) = stream_event(MessageEvent::Notification{
238253
request_id: request_id.clone(),
@@ -352,6 +367,10 @@ async fn ask_handler(
352367
}
353368
}
354369
}
370+
Ok(AgentEvent::ModelChange { model, mode }) => {
371+
// Log model change for non-streaming
372+
tracing::info!("Model changed to {} in {} mode", model, mode);
373+
}
355374
Ok(AgentEvent::McpNotification(n)) => {
356375
// Handle notifications if needed
357376
tracing::info!("Received notification: {:?}", n);

crates/goose/src/agents/agent.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub struct Agent {
6565
pub enum AgentEvent {
6666
Message(Message),
6767
McpNotification((String, JsonRpcMessage)),
68+
ModelChange { model: String, mode: String },
6869
}
6970

7071
impl Agent {
@@ -582,6 +583,26 @@ impl Agent {
582583
&toolshim_tools,
583584
).await {
584585
Ok((response, usage)) => {
586+
// Emit model change event if provider is lead-worker
587+
let provider = self.provider().await?;
588+
if let Some(lead_worker) = provider.as_lead_worker() {
589+
// The actual model used is in the usage
590+
let active_model = usage.model.clone();
591+
let (lead_model, worker_model) = lead_worker.get_model_info();
592+
let mode = if active_model == lead_model {
593+
"lead"
594+
} else if active_model == worker_model {
595+
"worker"
596+
} else {
597+
"unknown"
598+
};
599+
600+
yield AgentEvent::ModelChange {
601+
model: active_model,
602+
mode: mode.to_string(),
603+
};
604+
}
605+
585606
// record usage for the session in the session file
586607
if let Some(session_config) = session.clone() {
587608
Self::update_session_metrics(session_config, &usage, messages.len()).await?;

crates/goose/src/providers/base.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ use async_trait::async_trait;
152152
pub trait LeadWorkerProviderTrait {
153153
/// Get information about the lead and worker models for logging
154154
fn get_model_info(&self) -> (String, String);
155+
156+
/// Get the currently active model name
157+
fn get_active_model(&self) -> String;
155158
}
156159

157160
/// Base trait for AI providers (OpenAI, Anthropic, etc)
@@ -207,6 +210,17 @@ pub trait Provider: Send + Sync {
207210
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
208211
None
209212
}
213+
214+
/// Get the currently active model name
215+
/// For regular providers, this returns the configured model
216+
/// For LeadWorkerProvider, this returns the currently active model (lead or worker)
217+
fn get_active_model_name(&self) -> String {
218+
if let Some(lead_worker) = self.as_lead_worker() {
219+
lead_worker.get_active_model()
220+
} else {
221+
self.get_model_config().model_name
222+
}
223+
}
210224
}
211225

212226
#[cfg(test)]

crates/goose/src/providers/lead_worker.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,16 @@ impl LeadWorkerProviderTrait for LeadWorkerProvider {
291291
let worker_model = self.worker_provider.get_model_config().model_name;
292292
(lead_model, worker_model)
293293
}
294+
295+
/// Get the currently active model name
296+
fn get_active_model(&self) -> String {
297+
// Read from the global store which was set during complete()
298+
use super::base::get_current_model;
299+
get_current_model().unwrap_or_else(|| {
300+
// Fallback to lead model if no current model is set
301+
self.lead_provider.get_model_config().model_name
302+
})
303+
}
294304
}
295305

296306
#[async_trait]
@@ -336,19 +346,31 @@ impl Provider for LeadWorkerProvider {
336346
"worker"
337347
};
338348

349+
// Get the active model name and update the global store
350+
let active_model_name = if turn_count < self.lead_turns || in_fallback {
351+
self.lead_provider.get_model_config().model_name.clone()
352+
} else {
353+
self.worker_provider.get_model_config().model_name.clone()
354+
};
355+
356+
// Update the global current model store
357+
super::base::set_current_model(&active_model_name);
358+
339359
if in_fallback {
340360
tracing::info!(
341-
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
361+
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining) - Model: {}",
342362
provider_type,
343363
turn_count + 1,
344-
fallback_remaining
364+
fallback_remaining,
365+
active_model_name
345366
);
346367
} else {
347368
tracing::info!(
348-
"Using {} provider for turn {} (lead_turns: {})",
369+
"Using {} provider for turn {} (lead_turns: {}) - Model: {}",
349370
provider_type,
350371
turn_count + 1,
351-
self.lead_turns
372+
self.lead_turns,
373+
active_model_name
352374
);
353375
}
354376

crates/goose/src/scheduler.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,9 @@ async fn run_scheduled_job_internal(
11141114
Ok(AgentEvent::McpNotification(_)) => {
11151115
// Handle notifications if needed
11161116
}
1117+
Ok(AgentEvent::ModelChange { .. }) => {
1118+
// Model change events are informational, just continue
1119+
}
11171120
Err(e) => {
11181121
tracing::error!(
11191122
"[Job {}] Error receiving message from agent: {}",

0 commit comments

Comments
 (0)