|
1 | 1 | use std::collections::HashMap; |
2 | 2 | use std::future::Future; |
3 | 3 | use std::path::PathBuf; |
4 | | -use std::sync::OnceLock; |
5 | 4 | use std::sync::Arc; |
| 5 | +use std::sync::OnceLock; |
6 | 6 |
|
7 | 7 | use tokio::sync::Mutex; |
8 | 8 |
|
@@ -99,6 +99,104 @@ where |
99 | 99 | Ok(()) |
100 | 100 | } |
101 | 101 |
|
| 102 | +pub(crate) async fn restart_workspace_session_core<F, Fut>( |
| 103 | + workspace_id: String, |
| 104 | + workspaces: &Mutex<HashMap<String, WorkspaceEntry>>, |
| 105 | + sessions: &Mutex<HashMap<String, Arc<WorkspaceSession>>>, |
| 106 | + app_settings: &Mutex<AppSettings>, |
| 107 | + spawn_session: F, |
| 108 | +) -> Result<Vec<String>, String> |
| 109 | +where |
| 110 | + F: Fn(WorkspaceEntry, Option<String>, Option<String>, Option<PathBuf>) -> Fut, |
| 111 | + Fut: Future<Output = Result<Arc<WorkspaceSession>, String>>, |
| 112 | +{ |
| 113 | + let _ = resolve_entry_and_parent(workspaces, &workspace_id).await?; |
| 114 | + let _spawn_guard = workspace_session_spawn_lock().lock().await; |
| 115 | + |
| 116 | + let current_session = { |
| 117 | + let sessions = sessions.lock().await; |
| 118 | + sessions.get(&workspace_id).cloned() |
| 119 | + }; |
| 120 | + |
| 121 | + let (spawn_entry_id, affected_ids, old_session) = if let Some(current_session) = current_session |
| 122 | + { |
| 123 | + let owner_workspace_id = current_session.owner_workspace_id.clone(); |
| 124 | + let affected_ids = { |
| 125 | + let sessions = sessions.lock().await; |
| 126 | + sessions |
| 127 | + .iter() |
| 128 | + .filter_map(|(id, session)| { |
| 129 | + if Arc::ptr_eq(session, ¤t_session) { |
| 130 | + Some(id.clone()) |
| 131 | + } else { |
| 132 | + None |
| 133 | + } |
| 134 | + }) |
| 135 | + .collect::<Vec<_>>() |
| 136 | + }; |
| 137 | + let spawn_entry_id = if affected_ids.iter().any(|id| id == &owner_workspace_id) { |
| 138 | + owner_workspace_id |
| 139 | + } else { |
| 140 | + workspace_id.clone() |
| 141 | + }; |
| 142 | + (spawn_entry_id, affected_ids, Some(current_session)) |
| 143 | + } else { |
| 144 | + (workspace_id.clone(), vec![workspace_id.clone()], None) |
| 145 | + }; |
| 146 | + |
| 147 | + let (spawn_entry, parent_entry, affected_entries) = { |
| 148 | + let workspaces = workspaces.lock().await; |
| 149 | + let spawn_entry = workspaces |
| 150 | + .get(&spawn_entry_id) |
| 151 | + .cloned() |
| 152 | + .or_else(|| workspaces.get(&workspace_id).cloned()) |
| 153 | + .ok_or_else(|| "workspace not found".to_string())?; |
| 154 | + let parent_entry = spawn_entry |
| 155 | + .parent_id |
| 156 | + .as_ref() |
| 157 | + .and_then(|parent_id| workspaces.get(parent_id)) |
| 158 | + .cloned(); |
| 159 | + let affected_entries = affected_ids |
| 160 | + .iter() |
| 161 | + .filter_map(|id| workspaces.get(id).cloned()) |
| 162 | + .collect::<Vec<_>>(); |
| 163 | + (spawn_entry, parent_entry, affected_entries) |
| 164 | + }; |
| 165 | + |
| 166 | + let (default_bin, codex_args) = { |
| 167 | + let settings = app_settings.lock().await; |
| 168 | + ( |
| 169 | + settings.codex_bin.clone(), |
| 170 | + resolve_workspace_codex_args(&spawn_entry, parent_entry.as_ref(), Some(&settings)), |
| 171 | + ) |
| 172 | + }; |
| 173 | + let codex_home = resolve_workspace_codex_home(&spawn_entry, parent_entry.as_ref()); |
| 174 | + let new_session = spawn_session(spawn_entry, default_bin, codex_args, codex_home).await?; |
| 175 | + |
| 176 | + for entry in &affected_entries { |
| 177 | + new_session |
| 178 | + .register_workspace_with_path(&entry.id, Some(&entry.path)) |
| 179 | + .await; |
| 180 | + } |
| 181 | + |
| 182 | + { |
| 183 | + let mut sessions = sessions.lock().await; |
| 184 | + if let Some(old_session) = old_session.as_ref() { |
| 185 | + sessions.retain(|_, session| !Arc::ptr_eq(session, old_session)); |
| 186 | + } |
| 187 | + for entry in &affected_entries { |
| 188 | + sessions.insert(entry.id.clone(), Arc::clone(&new_session)); |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + if let Some(old_session) = old_session { |
| 193 | + let mut child = old_session.child.lock().await; |
| 194 | + kill_child_process_tree(&mut child).await; |
| 195 | + } |
| 196 | + |
| 197 | + Ok(affected_entries.into_iter().map(|entry| entry.id).collect()) |
| 198 | +} |
| 199 | + |
102 | 200 | pub(super) async fn kill_session_by_id( |
103 | 201 | sessions: &Mutex<HashMap<String, Arc<WorkspaceSession>>>, |
104 | 202 | id: &str, |
@@ -149,7 +247,14 @@ mod tests { |
149 | 247 | } |
150 | 248 | } |
151 | 249 |
|
152 | | - fn make_session(_entry: WorkspaceEntry) -> Arc<WorkspaceSession> { |
| 250 | + fn make_session(entry: WorkspaceEntry) -> Arc<WorkspaceSession> { |
| 251 | + make_session_with_owner(entry.clone(), &entry.id) |
| 252 | + } |
| 253 | + |
| 254 | + fn make_session_with_owner( |
| 255 | + _entry: WorkspaceEntry, |
| 256 | + owner_workspace_id: &str, |
| 257 | + ) -> Arc<WorkspaceSession> { |
153 | 258 | let mut cmd = if cfg!(windows) { |
154 | 259 | let mut cmd = Command::new("cmd"); |
155 | 260 | cmd.args(["/C", "more"]); |
@@ -177,8 +282,8 @@ mod tests { |
177 | 282 | hidden_thread_ids: Mutex::new(HashSet::new()), |
178 | 283 | next_id: AtomicU64::new(0), |
179 | 284 | background_thread_callbacks: Mutex::new(HashMap::new()), |
180 | | - owner_workspace_id: "test-owner".to_string(), |
181 | | - workspace_ids: Mutex::new(HashSet::from(["test-owner".to_string()])), |
| 285 | + owner_workspace_id: owner_workspace_id.to_string(), |
| 286 | + workspace_ids: Mutex::new(HashSet::from([owner_workspace_id.to_string()])), |
182 | 287 | workspace_roots: Mutex::new(HashMap::new()), |
183 | 288 | }) |
184 | 289 | } |
@@ -250,4 +355,97 @@ mod tests { |
250 | 355 | kill_session_by_id(&sessions, &entry.id).await; |
251 | 356 | }); |
252 | 357 | } |
| 358 | + |
| 359 | + #[test] |
| 360 | + fn restart_workspace_session_spawns_when_not_connected() { |
| 361 | + tokio::runtime::Runtime::new().unwrap().block_on(async { |
| 362 | + let entry = make_workspace_entry("ws-restart"); |
| 363 | + let workspaces = Mutex::new(HashMap::from([(entry.id.clone(), entry.clone())])); |
| 364 | + let sessions = Mutex::new(HashMap::<String, Arc<WorkspaceSession>>::new()); |
| 365 | + let app_settings = Mutex::new(AppSettings::default()); |
| 366 | + let spawn_calls = Arc::new(AtomicUsize::new(0)); |
| 367 | + let spawn_calls_ref = spawn_calls.clone(); |
| 368 | + let entry_for_spawn = entry.clone(); |
| 369 | + |
| 370 | + let restarted = restart_workspace_session_core( |
| 371 | + entry.id.clone(), |
| 372 | + &workspaces, |
| 373 | + &sessions, |
| 374 | + &app_settings, |
| 375 | + move |_entry, _default_bin, _codex_args, _codex_home| { |
| 376 | + let spawn_calls_ref = spawn_calls_ref.clone(); |
| 377 | + let entry_for_spawn = entry_for_spawn.clone(); |
| 378 | + async move { |
| 379 | + spawn_calls_ref.fetch_add(1, Ordering::SeqCst); |
| 380 | + Ok(make_session(entry_for_spawn)) |
| 381 | + } |
| 382 | + }, |
| 383 | + ) |
| 384 | + .await |
| 385 | + .expect("restart should spawn"); |
| 386 | + |
| 387 | + assert_eq!(restarted, vec![entry.id.clone()]); |
| 388 | + assert_eq!(spawn_calls.load(Ordering::SeqCst), 1); |
| 389 | + assert!(sessions.lock().await.contains_key(&entry.id)); |
| 390 | + kill_session_by_id(&sessions, &entry.id).await; |
| 391 | + }); |
| 392 | + } |
| 393 | + |
| 394 | + #[test] |
| 395 | + fn restart_workspace_session_replaces_only_shared_session_group() { |
| 396 | + tokio::runtime::Runtime::new().unwrap().block_on(async { |
| 397 | + let entry_a = make_workspace_entry("ws-a"); |
| 398 | + let entry_b = make_workspace_entry("ws-b"); |
| 399 | + let entry_c = make_workspace_entry("ws-c"); |
| 400 | + let workspaces = Mutex::new(HashMap::from([ |
| 401 | + (entry_a.id.clone(), entry_a.clone()), |
| 402 | + (entry_b.id.clone(), entry_b.clone()), |
| 403 | + (entry_c.id.clone(), entry_c.clone()), |
| 404 | + ])); |
| 405 | + let shared_session = make_session_with_owner(entry_a.clone(), &entry_a.id); |
| 406 | + shared_session |
| 407 | + .register_workspace_with_path(&entry_b.id, Some(&entry_b.path)) |
| 408 | + .await; |
| 409 | + let unrelated_session = make_session_with_owner(entry_c.clone(), &entry_c.id); |
| 410 | + let sessions = Mutex::new(HashMap::from([ |
| 411 | + (entry_a.id.clone(), Arc::clone(&shared_session)), |
| 412 | + (entry_b.id.clone(), Arc::clone(&shared_session)), |
| 413 | + (entry_c.id.clone(), Arc::clone(&unrelated_session)), |
| 414 | + ])); |
| 415 | + let app_settings = Mutex::new(AppSettings::default()); |
| 416 | + let spawn_calls = Arc::new(AtomicUsize::new(0)); |
| 417 | + let spawn_calls_ref = spawn_calls.clone(); |
| 418 | + let entry_for_spawn = entry_a.clone(); |
| 419 | + |
| 420 | + let restarted = restart_workspace_session_core( |
| 421 | + entry_b.id.clone(), |
| 422 | + &workspaces, |
| 423 | + &sessions, |
| 424 | + &app_settings, |
| 425 | + move |_entry, _default_bin, _codex_args, _codex_home| { |
| 426 | + let spawn_calls_ref = spawn_calls_ref.clone(); |
| 427 | + let entry_for_spawn = entry_for_spawn.clone(); |
| 428 | + async move { |
| 429 | + spawn_calls_ref.fetch_add(1, Ordering::SeqCst); |
| 430 | + Ok(make_session_with_owner(entry_for_spawn, "ws-a")) |
| 431 | + } |
| 432 | + }, |
| 433 | + ) |
| 434 | + .await |
| 435 | + .expect("restart should replace shared session"); |
| 436 | + |
| 437 | + assert_eq!(spawn_calls.load(Ordering::SeqCst), 1); |
| 438 | + assert_eq!(restarted.len(), 2); |
| 439 | + assert!(restarted.contains(&entry_a.id)); |
| 440 | + assert!(restarted.contains(&entry_b.id)); |
| 441 | + |
| 442 | + let sessions = sessions.lock().await; |
| 443 | + let next_a = sessions.get(&entry_a.id).expect("ws-a session"); |
| 444 | + let next_b = sessions.get(&entry_b.id).expect("ws-b session"); |
| 445 | + let next_c = sessions.get(&entry_c.id).expect("ws-c session"); |
| 446 | + assert!(Arc::ptr_eq(next_a, next_b)); |
| 447 | + assert!(!Arc::ptr_eq(next_a, &shared_session)); |
| 448 | + assert!(Arc::ptr_eq(next_c, &unrelated_session)); |
| 449 | + }); |
| 450 | + } |
253 | 451 | } |
0 commit comments