diff --git a/AGENTS.md b/AGENTS.md index 38ecb09f..04f2dd03 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,7 +19,6 @@ All paths in the protocol should be absolute - Add constants for the method names - Add variants to {Agent|Client}{Request|Response} enums -- Handle the new method in the `Side::decode_request`/`Side::decode_notification` implementation - Add the method to markdown_generator.rs SideDocs functions - Run `npm run generate` and fix any issues that appear - Run `npm run check` diff --git a/src/bin/generate.rs b/src/bin/generate.rs index 3bf8b5d9..777cd87e 100644 --- a/src/bin/generate.rs +++ b/src/bin/generate.rs @@ -1,6 +1,7 @@ use agent_client_protocol_schema::{ - AGENT_METHOD_NAMES, AgentSide, CLIENT_METHOD_NAMES, ClientSide, JsonRpcMessage, - OutgoingMessage, ProtocolVersion, + AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES, + ClientNotification, ClientRequest, ClientResponse, JsonRpcMessage, Notification, + ProtocolVersion, Request, Response, }; #[cfg(feature = "unstable_cancel_request")] use agent_client_protocol_schema::{PROTOCOL_LEVEL_METHOD_NAMES, ProtocolLevelNotification}; @@ -9,19 +10,32 @@ use schemars::{ generate::SchemaSettings, transform::{RemoveRefSiblings, ReplaceBoolSchemas}, }; +use serde::{Deserialize, Serialize}; use std::{env, fs, path::Path}; use markdown_generator::MarkdownGenerator; -#[expect(dead_code)] -#[derive(JsonSchema)] +/// All messages that an agent can send to a client. +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(untagged)] #[schemars(inline)] -struct AgentOutgoingMessage(JsonRpcMessage>); +#[allow(clippy::large_enum_variant)] +enum AgentOutgoingMessage { + Request(Request), + Response(Response), + Notification(Notification), +} -#[expect(dead_code)] -#[derive(JsonSchema)] +/// All messages that a client can send to an agent. +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(untagged)] #[schemars(inline)] -struct ClientOutgoingMessage(JsonRpcMessage>); +#[allow(clippy::large_enum_variant)] +enum ClientOutgoingMessage { + Request(Request), + Response(Response), + Notification(Notification), +} #[expect(dead_code)] #[derive(JsonSchema)] @@ -29,8 +43,8 @@ struct ClientOutgoingMessage(JsonRpcMessage), + Client(JsonRpcMessage), #[cfg(feature = "unstable_cancel_request")] ProtocolLevel(ProtocolLevelNotification), } diff --git a/src/rpc.rs b/src/rpc.rs index c7552689..95c27220 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -2,14 +2,10 @@ use std::sync::Arc; use derive_more::{Display, From}; use schemars::JsonSchema; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use serde_json::value::RawValue; +use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; -use crate::{ - AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES, - ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result, -}; +use crate::{Error, Result}; /// JSON RPC Request Id /// @@ -100,19 +96,6 @@ pub struct Notification { pub params: Option, } -#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)] -#[serde(untagged)] -#[schemars(inline)] -#[allow( - clippy::exhaustive_enums, - reason = "This comes from the JSON-RPC specification itself" -)] -pub enum OutgoingMessage { - Request(Request), - Response(Response), - Notification(Notification), -} - #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[schemars(inline)] enum JsonRpcVersion { @@ -133,8 +116,7 @@ pub struct JsonRpcMessage { } impl JsonRpcMessage { - /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned - /// [`JsonRpcMessage`]. + /// Wraps the provided message into a versioned [`JsonRpcMessage`]. #[must_use] pub fn wrap(message: M) -> Self { Self { @@ -144,283 +126,15 @@ impl JsonRpcMessage { } } -pub trait Side: Clone { - type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static; - type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static; - type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static; - - /// Decode a request for a given method. This will encapsulate the knowledge of mapping which - /// serialization struct to use for each method. - /// - /// # Errors - /// - /// This function will return an error if the method is not recognized or if the parameters - /// cannot be deserialized into the expected type. - fn decode_request(method: &str, params: Option<&RawValue>) -> Result; - - /// Decode a notification for a given method. This will encapsulate the knowledge of mapping which - /// serialization struct to use for each method. - /// - /// # Errors - /// - /// This function will return an error if the method is not recognized or if the parameters - /// cannot be deserialized into the expected type. - fn decode_notification(method: &str, params: Option<&RawValue>) - -> Result; -} - -/// Marker type representing the client side of an ACP connection. -/// -/// This type is used by the RPC layer to determine which messages -/// are incoming vs outgoing from the client's perspective. -/// -/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model) -#[derive(Clone, Default, Debug, JsonSchema)] -#[non_exhaustive] -pub struct ClientSide; - -impl Side for ClientSide { - type InRequest = AgentRequest; - type InNotification = AgentNotification; - type OutResponse = ClientResponse; - - fn decode_request(method: &str, params: Option<&RawValue>) -> Result { - let params = params.ok_or_else(Error::invalid_params)?; - - match method { - m if m == CLIENT_METHOD_NAMES.session_request_permission => { - serde_json::from_str(params.get()) - .map(AgentRequest::RequestPermissionRequest) - .map_err(Into::into) - } - m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get()) - .map(AgentRequest::WriteTextFileRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get()) - .map(AgentRequest::ReadTextFileRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get()) - .map(AgentRequest::CreateTerminalRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get()) - .map(AgentRequest::TerminalOutputRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get()) - .map(AgentRequest::KillTerminalRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get()) - .map(AgentRequest::ReleaseTerminalRequest) - .map_err(Into::into), - m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => { - serde_json::from_str(params.get()) - .map(AgentRequest::WaitForTerminalExitRequest) - .map_err(Into::into) - } - #[cfg(feature = "unstable_elicitation")] - m if m == CLIENT_METHOD_NAMES.elicitation_create => serde_json::from_str(params.get()) - .map(AgentRequest::CreateElicitationRequest) - .map_err(Into::into), - _ => { - if is_valid_ext_method(method) { - Ok(AgentRequest::ExtMethodRequest(ExtRequest::new( - method, - params.to_owned().into(), - ))) - } else { - Err(Error::method_not_found()) - } - } - } - } - - fn decode_notification(method: &str, params: Option<&RawValue>) -> Result { - let params = params.ok_or_else(Error::invalid_params)?; - - match method { - m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get()) - .map(AgentNotification::SessionNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_elicitation")] - m if m == CLIENT_METHOD_NAMES.elicitation_complete => { - serde_json::from_str(params.get()) - .map(AgentNotification::CompleteElicitationNotification) - .map_err(Into::into) - } - _ => { - if is_valid_ext_method(method) { - Ok(AgentNotification::ExtNotification(ExtNotification::new( - method, - params.to_owned().into(), - ))) - } else { - Err(Error::method_not_found()) - } - } - } - } -} - -/// Marker type representing the agent side of an ACP connection. -/// -/// This type is used by the RPC layer to determine which messages -/// are incoming vs outgoing from the agent's perspective. -/// -/// See protocol docs: [Communication Model](https://agentclientprotocol.com/protocol/overview#communication-model) -#[derive(Clone, Default, Debug, JsonSchema)] -#[non_exhaustive] -pub struct AgentSide; - -impl Side for AgentSide { - type InRequest = ClientRequest; - type InNotification = ClientNotification; - type OutResponse = AgentResponse; - - fn decode_request(method: &str, params: Option<&RawValue>) -> Result { - let params = params.ok_or_else(Error::invalid_params)?; - - match method { - m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get()) - .map(ClientRequest::InitializeRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get()) - .map(ClientRequest::AuthenticateRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_llm_providers")] - m if m == AGENT_METHOD_NAMES.providers_list => serde_json::from_str(params.get()) - .map(ClientRequest::ListProvidersRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_llm_providers")] - m if m == AGENT_METHOD_NAMES.providers_set => serde_json::from_str(params.get()) - .map(ClientRequest::SetProvidersRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_llm_providers")] - m if m == AGENT_METHOD_NAMES.providers_disable => serde_json::from_str(params.get()) - .map(ClientRequest::DisableProvidersRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_logout")] - m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get()) - .map(ClientRequest::LogoutRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get()) - .map(ClientRequest::NewSessionRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get()) - .map(ClientRequest::LoadSessionRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get()) - .map(ClientRequest::ListSessionsRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_session_fork")] - m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get()) - .map(ClientRequest::ForkSessionRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_session_resume")] - m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get()) - .map(ClientRequest::ResumeSessionRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_session_close")] - m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get()) - .map(ClientRequest::CloseSessionRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get()) - .map(ClientRequest::SetSessionModeRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_set_config_option => { - serde_json::from_str(params.get()) - .map(ClientRequest::SetSessionConfigOptionRequest) - .map_err(Into::into) - } - #[cfg(feature = "unstable_session_model")] - m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get()) - .map(ClientRequest::SetSessionModelRequest) - .map_err(Into::into), - m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get()) - .map(ClientRequest::PromptRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.nes_start => serde_json::from_str(params.get()) - .map(ClientRequest::StartNesRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.nes_suggest => serde_json::from_str(params.get()) - .map(ClientRequest::SuggestNesRequest) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.nes_close => serde_json::from_str(params.get()) - .map(ClientRequest::CloseNesRequest) - .map_err(Into::into), - _ => { - if is_valid_ext_method(method) { - Ok(ClientRequest::ExtMethodRequest(ExtRequest::new( - method, - params.to_owned().into(), - ))) - } else { - Err(Error::method_not_found()) - } - } - } - } - - fn decode_notification(method: &str, params: Option<&RawValue>) -> Result { - let params = params.ok_or_else(Error::invalid_params)?; - - match method { - m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get()) - .map(ClientNotification::CancelNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.document_did_open => serde_json::from_str(params.get()) - .map(ClientNotification::DidOpenDocumentNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.document_did_change => serde_json::from_str(params.get()) - .map(ClientNotification::DidChangeDocumentNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.document_did_close => serde_json::from_str(params.get()) - .map(ClientNotification::DidCloseDocumentNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.document_did_save => serde_json::from_str(params.get()) - .map(ClientNotification::DidSaveDocumentNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.document_did_focus => serde_json::from_str(params.get()) - .map(ClientNotification::DidFocusDocumentNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.nes_accept => serde_json::from_str(params.get()) - .map(ClientNotification::AcceptNesNotification) - .map_err(Into::into), - #[cfg(feature = "unstable_nes")] - m if m == AGENT_METHOD_NAMES.nes_reject => serde_json::from_str(params.get()) - .map(ClientNotification::RejectNesNotification) - .map_err(Into::into), - _ => { - if is_valid_ext_method(method) { - Ok(ClientNotification::ExtNotification(ExtNotification::new( - method, - params.to_owned().into(), - ))) - } else { - Err(Error::method_not_found()) - } - } - } - } -} - -fn is_valid_ext_method(method: &str) -> bool { - method.starts_with('_') && method.len() > 1 -} - #[cfg(test)] mod tests { use super::*; - use crate::ErrorCode; - use serde_json::{Number, Value}; + use crate::{ + AgentNotification, CancelNotification, ClientNotification, ContentBlock, ContentChunk, + SessionId, SessionNotification, SessionUpdate, TextContent, + }; + use serde_json::{Number, Value, json}; #[test] fn id_deserialization() { @@ -470,305 +184,30 @@ mod tests { } #[test] - fn decode_ext_request_preserves_prefix_on_client_side() { - let raw = serde_json::value::RawValue::from_string(r#"{"x":1}"#.to_string()).unwrap(); - let request = ClientSide::decode_request("_vendor/custom_request", Some(&raw)).unwrap(); - assert_eq!(request.method(), "_vendor/custom_request"); - } - - #[test] - fn decode_ext_request_preserves_prefix_on_agent_side() { - let raw = serde_json::value::RawValue::from_string(r#"{"x":1}"#.to_string()).unwrap(); - let request = AgentSide::decode_request("_vendor/custom_request", Some(&raw)).unwrap(); - assert_eq!(request.method(), "_vendor/custom_request"); - } - - #[test] - fn decode_ext_notification_preserves_prefix_on_client_side() { - let raw = serde_json::value::RawValue::from_string(r#"{"x":1}"#.to_string()).unwrap(); - let notification = - ClientSide::decode_notification("_vendor/custom_notification", Some(&raw)).unwrap(); - assert_eq!(notification.method(), "_vendor/custom_notification"); - } - - #[test] - fn decode_ext_notification_preserves_prefix_on_agent_side() { - let raw = serde_json::value::RawValue::from_string(r#"{"x":1}"#.to_string()).unwrap(); - let notification = - AgentSide::decode_notification("_vendor/custom_notification", Some(&raw)).unwrap(); - assert_eq!(notification.method(), "_vendor/custom_notification"); - } - - #[test] - fn decode_rejects_empty_ext_method_name() { - let raw = serde_json::value::RawValue::from_string(r"{}".to_string()).unwrap(); - - let err = ClientSide::decode_request("_", Some(&raw)).unwrap_err(); - assert_eq!(err.code, ErrorCode::MethodNotFound); - - let err = ClientSide::decode_notification("_", Some(&raw)).unwrap_err(); - assert_eq!(err.code, ErrorCode::MethodNotFound); - - let err = AgentSide::decode_request("_", Some(&raw)).unwrap_err(); - assert_eq!(err.code, ErrorCode::MethodNotFound); - - let err = AgentSide::decode_notification("_", Some(&raw)).unwrap_err(); - assert_eq!(err.code, ErrorCode::MethodNotFound); - } -} - -#[cfg(feature = "unstable_nes")] -#[cfg(test)] -mod nes_rpc_tests { - use super::*; - use serde_json::json; - - #[test] - fn test_decode_nes_start_request() { - let params = serde_json::to_string(&json!({ - "workspaceUri": "file:///Users/alice/projects/my-app", - "workspaceFolders": [ - { "uri": "file:///Users/alice/projects/my-app", "name": "my-app" } - ] - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("nes/start", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::StartNesRequest(_))); - } - - #[test] - fn test_decode_nes_suggest_request() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs", - "version": 2, - "position": { "line": 5, "character": 12 }, - "triggerKind": "automatic" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("nes/suggest", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::SuggestNesRequest(_))); - } - - #[test] - fn test_decode_nes_close_request() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("nes/close", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::CloseNesRequest(_))); - } - - #[test] - fn test_decode_document_did_open_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs", - "languageId": "rust", - "version": 1, - "text": "fn main() {}" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("document/didOpen", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::DidOpenDocumentNotification(_) - )); - } - - #[test] - fn test_decode_document_did_change_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs", - "version": 2, - "contentChanges": [{ "text": "fn main() { let x = 1; }" }] - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = - AgentSide::decode_notification("document/didChange", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::DidChangeDocumentNotification(_) - )); - } - - #[test] - fn test_decode_document_did_close_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("document/didClose", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::DidCloseDocumentNotification(_) - )); - } - - #[test] - fn test_decode_document_did_save_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("document/didSave", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::DidSaveDocumentNotification(_) - )); - } - - #[test] - fn test_decode_document_did_focus_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "uri": "file:///path/to/file.rs", - "version": 2, - "position": { "line": 5, "character": 12 }, - "visibleRange": { - "start": { "line": 0, "character": 0 }, - "end": { "line": 45, "character": 0 } - } - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("document/didFocus", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::DidFocusDocumentNotification(_) - )); - } - - #[test] - fn test_decode_nes_accept_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "id": "sugg_001" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("nes/accept", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::AcceptNesNotification(_) - )); - } - - #[test] - fn test_decode_nes_reject_notification() { - let params = serde_json::to_string(&json!({ - "sessionId": "session_123", - "id": "sugg_001", - "reason": "rejected" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let notification = AgentSide::decode_notification("nes/reject", Some(&raw)).unwrap(); - assert!(matches!( - notification, - ClientNotification::RejectNesNotification(_) - )); - } -} - -#[cfg(feature = "unstable_llm_providers")] -#[cfg(test)] -mod providers_rpc_tests { - use super::*; - use serde_json::json; - - #[test] - fn test_decode_providers_list_request() { - let params = serde_json::to_string(&json!({})).unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("providers/list", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::ListProvidersRequest(_))); - } - - #[test] - fn test_decode_providers_set_request() { - let params = serde_json::to_string(&json!({ - "id": "main", - "apiType": "anthropic", - "baseUrl": "https://api.anthropic.com" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("providers/set", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::SetProvidersRequest(_))); - } - - #[test] - fn test_decode_providers_set_request_with_headers() { - let params = serde_json::to_string(&json!({ - "id": "main", - "apiType": "openai", - "baseUrl": "https://api.openai.com/v1", - "headers": { - "Authorization": "Bearer sk-test" - } - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("providers/set", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::SetProvidersRequest(_))); - } - - #[test] - fn test_decode_providers_disable_request() { - let params = serde_json::to_string(&json!({ - "id": "secondary" - })) - .unwrap(); - let raw = serde_json::value::RawValue::from_string(params).unwrap(); - let request = AgentSide::decode_request("providers/disable", Some(&raw)).unwrap(); - assert!(matches!(request, ClientRequest::DisableProvidersRequest(_))); - } -} - -#[test] -fn test_notification_wire_format() { - use super::*; - - use serde_json::{Value, json}; - - // Test client -> agent notification wire format - let outgoing_msg = JsonRpcMessage::wrap( - OutgoingMessage::::Notification(Notification { + fn notification_wire_format() { + // Test client -> agent notification wire format + let outgoing_msg = JsonRpcMessage::wrap(Notification { method: "cancel".into(), params: Some(ClientNotification::CancelNotification(CancelNotification { session_id: SessionId("test-123".into()), meta: None, })), - }), - ); - - let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); - assert_eq!( - serialized, - json!({ - "jsonrpc": "2.0", - "method": "cancel", - "params": { - "sessionId": "test-123" - }, - }) - ); + }); + + let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); + assert_eq!( + serialized, + json!({ + "jsonrpc": "2.0", + "method": "cancel", + "params": { + "sessionId": "test-123" + }, + }) + ); - // Test agent -> client notification wire format - let outgoing_msg = JsonRpcMessage::wrap( - OutgoingMessage::::Notification(Notification { + // Test agent -> client notification wire format + let outgoing_msg = JsonRpcMessage::wrap(Notification { method: "sessionUpdate".into(), params: Some(AgentNotification::SessionNotification( SessionNotification { @@ -786,25 +225,25 @@ fn test_notification_wire_format() { meta: None, }, )), - }), - ); - - let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); - assert_eq!( - serialized, - json!({ - "jsonrpc": "2.0", - "method": "sessionUpdate", - "params": { - "sessionId": "test-456", - "update": { - "sessionUpdate": "agent_message_chunk", - "content": { - "type": "text", - "text": "Hello" + }); + + let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); + assert_eq!( + serialized, + json!({ + "jsonrpc": "2.0", + "method": "sessionUpdate", + "params": { + "sessionId": "test-456", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "Hello" + } } } - } - }) - ); + }) + ); + } }