diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index 6d06a0b5f7db..b95010df96df 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -2401,6 +2401,32 @@ fn replay_message_meta(message: &Message) -> Meta { meta } +fn replay_audience_annotations(audience: &[Role]) -> Annotations { + Annotations::new().audience( + audience + .iter() + .map(|role| match role { + Role::Assistant => agent_client_protocol::schema::Role::Assistant, + Role::User => agent_client_protocol::schema::Role::User, + }) + .collect::>(), + ) +} + +fn send_replay_content_chunk( + cx: &ConnectionTo, + session_id: &SessionId, + message: &Message, + content: ContentBlock, +) -> std::result::Result<(), agent_client_protocol::Error> { + let chunk = ContentChunk::new(content).meta(replay_message_meta(message)); + let update = match message.role { + Role::User => SessionUpdate::UserMessageChunk(chunk), + Role::Assistant => SessionUpdate::AgentMessageChunk(chunk), + }; + cx.send_notification(SessionNotification::new(session_id.clone(), update)) +} + fn replay_message_goose_meta(message: &Message) -> serde_json::Map { let mut goose = serde_json::Map::new(); goose.insert("created".to_string(), serde_json::json!(message.created)); @@ -2824,30 +2850,28 @@ impl GooseAcpAgent { MessageContent::Text(text) => { let mut tc = TextContent::new(text.text.clone()); if let Some(audience) = text.audience() { - tc = tc.annotations( - Annotations::new().audience( - audience - .iter() - .map(|r| match r { - Role::Assistant => { - agent_client_protocol::schema::Role::Assistant - } - Role::User => agent_client_protocol::schema::Role::User, - }) - .collect::>(), - ), - ); + tc = tc.annotations(replay_audience_annotations(audience)); } - let chunk = ContentChunk::new(ContentBlock::Text(tc)) - .meta(replay_message_meta(message)); - let update = match message.role { - Role::User => SessionUpdate::UserMessageChunk(chunk), - Role::Assistant => SessionUpdate::AgentMessageChunk(chunk), - }; - cx.send_notification(SessionNotification::new( - args.session_id.clone(), - update, - ))?; + send_replay_content_chunk( + cx, + &args.session_id, + message, + ContentBlock::Text(tc), + )?; + } + MessageContent::Image(image) => { + let mut image_content = + ImageContent::new(image.data.clone(), image.mime_type.clone()); + if let Some(audience) = image.audience() { + image_content = + image_content.annotations(replay_audience_annotations(audience)); + } + send_replay_content_chunk( + cx, + &args.session_id, + message, + ContentBlock::Image(image_content), + )?; } MessageContent::ToolRequest(tool_request) => { // Replay-only: emit the ToolCall notification and diff --git a/crates/goose/tests/acp_common_tests/mod.rs b/crates/goose/tests/acp_common_tests/mod.rs index 90c225d834dd..18976ab7511e 100644 --- a/crates/goose/tests/acp_common_tests/mod.rs +++ b/crates/goose/tests/acp_common_tests/mod.rs @@ -5,8 +5,8 @@ #[path = "../acp_fixtures/mod.rs"] pub mod fixtures; use agent_client_protocol::schema::{ - ListSessionsResponse, McpServer, McpServerHttp, ModelId, SessionInfo, SessionModeId, - ToolCallStatus, ToolKind, + ContentBlock, ListSessionsResponse, McpServer, McpServerHttp, ModelId, SessionInfo, + SessionModeId, SessionUpdate, ToolCallStatus, ToolKind, }; use fixtures::{ assert_notifications, Connection, FsFixture, Notification, OpenAiFixture, PermissionDecision, @@ -622,6 +622,57 @@ pub async fn run_load_session_mcp() { assert_eq!(output.text, FAKE_CODE, "tool call failed in loaded session"); } +pub async fn run_load_session_replays_image_attachment() { + let expected_session_id = C::expected_session_id(); + let openai = OpenAiFixture::new( + vec![( + r#""type":"image_url""#.into(), + include_str!("../acp_test_data/openai_image_attachment.txt"), + )], + expected_session_id.clone(), + ) + .await; + + let mut conn = C::new(TestConnectionConfig::default(), openai).await; + let SessionData { mut session, .. } = conn.new_session().await.unwrap(); + expected_session_id.set(&session.session_id().0); + let session_id = session.session_id().0.to_string(); + + let output = session + .prompt_with_image( + "Describe what you see in this image", + TEST_IMAGE_B64, + "image/png", + PermissionDecision::Cancel, + ) + .await + .unwrap(); + assert!(output.text.contains("Hello Goose!")); + session.session_updates(); + + let SessionData { session, .. } = conn.load_session(&session_id, vec![]).await.unwrap(); + let replayed_images = session + .session_updates() + .into_iter() + .filter_map(|update| match update { + SessionUpdate::UserMessageChunk(chunk) => match chunk.content { + ContentBlock::Image(image) => Some(image), + _ => None, + }, + _ => None, + }) + .collect::>(); + + assert_eq!( + replayed_images.len(), + 1, + "expected load_session to replay the user image attachment exactly once" + ); + let replayed_image = &replayed_images[0]; + assert_eq!(replayed_image.data, TEST_IMAGE_B64); + assert_eq!(replayed_image.mime_type, "image/png"); +} + pub async fn run_load_session_error() { let openai = OpenAiFixture::new(vec![], C::expected_session_id()).await; let mut conn = C::new(TestConnectionConfig::default(), openai).await; diff --git a/crates/goose/tests/acp_fixtures/mod.rs b/crates/goose/tests/acp_fixtures/mod.rs index 17a307d27d43..6ba0b05e1ad7 100644 --- a/crates/goose/tests/acp_fixtures/mod.rs +++ b/crates/goose/tests/acp_fixtures/mod.rs @@ -562,6 +562,9 @@ pub trait Connection: Sized { pub trait Session: std::fmt::Debug { fn session_id(&self) -> &agent_client_protocol::schema::SessionId; fn work_dir(&self) -> std::path::PathBuf; + /// Drains and returns raw session updates collected by the fixture. + fn session_updates(&self) -> Vec; + /// Drains and returns simplified notifications collected by the fixture. fn notifications(&self) -> Vec; async fn prompt( &mut self, diff --git a/crates/goose/tests/acp_fixtures/provider.rs b/crates/goose/tests/acp_fixtures/provider.rs index ed4f9a3ab90b..4767125b27ca 100644 --- a/crates/goose/tests/acp_fixtures/provider.rs +++ b/crates/goose/tests/acp_fixtures/provider.rs @@ -325,9 +325,12 @@ impl Session for AcpProviderSession { self.work_dir.clone() } + fn session_updates(&self) -> Vec { + self.notification_sink.lock().unwrap().drain(..).collect() + } + fn notifications(&self) -> Vec { - let updates: Vec<_> = self.notification_sink.lock().unwrap().drain(..).collect(); - super::to_notifications(&updates) + super::to_notifications(&self.session_updates()) } async fn prompt( diff --git a/crates/goose/tests/acp_fixtures/server.rs b/crates/goose/tests/acp_fixtures/server.rs index 5050dafbeee8..f7a5eec96ba8 100644 --- a/crates/goose/tests/acp_fixtures/server.rs +++ b/crates/goose/tests/acp_fixtures/server.rs @@ -53,6 +53,15 @@ impl std::fmt::Debug for AcpServerSession { } impl AcpServerSession { + pub fn session_updates(&self) -> Vec { + self.updates + .lock() + .unwrap() + .drain(..) + .map(|n| n.update) + .collect() + } + async fn send_prompt( &mut self, content: Vec, @@ -464,15 +473,12 @@ impl Session for AcpServerSession { self._work_dir.path().to_path_buf() } + fn session_updates(&self) -> Vec { + AcpServerSession::session_updates(self) + } + fn notifications(&self) -> Vec { - let updates: Vec<_> = self - .updates - .lock() - .unwrap() - .drain(..) - .map(|n| n.update) - .collect(); - super::to_notifications(&updates) + super::to_notifications(&self.session_updates()) } async fn prompt( diff --git a/crates/goose/tests/acp_server_test.rs b/crates/goose/tests/acp_server_test.rs index b540b32605a8..6faca600bf8f 100644 --- a/crates/goose/tests/acp_server_test.rs +++ b/crates/goose/tests/acp_server_test.rs @@ -11,12 +11,12 @@ use common_tests::{ run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set, run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false, run_fs_write_text_file_true, run_initialize_doesnt_hit_provider, run_list_sessions, - run_load_mode, run_load_model, run_load_session_error, run_load_session_mcp, run_mode_set, - run_model_list, run_model_set, run_model_set_error_session_not_found, - run_new_session_returns_initial_config, run_permission_persistence, run_prompt_basic, - run_prompt_error, run_prompt_image, run_prompt_image_attachment, run_prompt_mcp, - run_prompt_model_mismatch, run_prompt_skill, run_session_name_update_notification, - run_shell_terminal_false, run_shell_terminal_true, + run_load_mode, run_load_model, run_load_session_error, run_load_session_mcp, + run_load_session_replays_image_attachment, run_mode_set, run_model_list, run_model_set, + run_model_set_error_session_not_found, run_new_session_returns_initial_config, + run_permission_persistence, run_prompt_basic, run_prompt_error, run_prompt_image, + run_prompt_image_attachment, run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill, + run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true, }; use goose::config::GooseMode; use goose::conversation::message::Message; @@ -220,6 +220,11 @@ fn test_load_session_mcp() { run_test(async { run_load_session_mcp::().await }); } +#[test] +fn test_load_session_replays_image_attachment() { + run_test(async { run_load_session_replays_image_attachment::().await }); +} + #[test] fn test_mode_set() { run_test(async { run_mode_set::().await });