diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d0f4e5..82eaeef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 0.4.8 (2025-10-16) + +- Export `acp::Result` for easier indication of ACP errors. + ## 0.4.7 (2025-10-13) - Depend on `agent-client-protocol-schema` for schema types diff --git a/Cargo.lock b/Cargo.lock index afc1373..def472a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,7 +19,7 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "agent-client-protocol" -version = "0.4.7" +version = "0.4.8" dependencies = [ "agent-client-protocol-schema", "anyhow", @@ -33,7 +33,6 @@ dependencies = [ "piper", "pretty_assertions", "rustyline", - "schemars", "serde", "serde_json", "tokio", @@ -42,9 +41,9 @@ dependencies = [ [[package]] name = "agent-client-protocol-schema" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db625c1cb83729cb7d04080fff35cabe7f68870f31e003546aa99565beccb11a" +checksum = "36bd1f574102fc8611d1018e06246307c3d16ed862d4259cab55bbf4be804e84" dependencies = [ "anyhow", "schemars", diff --git a/Cargo.toml b/Cargo.toml index 33e6367..c30bc19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "agent-client-protocol" authors = ["Zed "] -version = "0.4.7" +version = "0.4.8" edition = "2024" license = "Apache-2.0" description = "A protocol for standardizing communication between code editors and AI coding agents" @@ -17,14 +17,13 @@ include = ["/src/**/*.rs", "/README.md", "/LICENSE", "/Cargo.toml"] unstable = ["agent-client-protocol-schema/unstable"] [dependencies] -agent-client-protocol-schema = "0.4.9" +agent-client-protocol-schema = "0.4.10" anyhow = "1" async-broadcast = "0.7" async-trait = "0.1" futures = { version = "0.3" } log = "0.4" parking_lot = "0.12" -schemars = { version = "1" } serde = { version = "1", features = ["derive", "rc"] } serde_json = { version = "1", features = ["raw_value"] } diff --git a/examples/agent.rs b/examples/agent.rs index 6434dd9..51d98f4 100644 --- a/examples/agent.rs +++ b/examples/agent.rs @@ -154,7 +154,7 @@ impl acp::Agent for ExampleAgent { } #[tokio::main(flavor = "current_thread")] -async fn main() -> anyhow::Result<()> { +async fn main() -> acp::Result<()> { env_logger::init(); let outgoing = tokio::io::stdout().compat_write(); diff --git a/examples/client.rs b/examples/client.rs index da5a0d0..f3e9998 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -13,7 +13,6 @@ //! ``` use agent_client_protocol::{self as acp, Agent as _}; -use anyhow::bail; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; struct ExampleClient {} @@ -23,21 +22,21 @@ impl acp::Client for ExampleClient { async fn request_permission( &self, _args: acp::RequestPermissionRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn write_text_file( &self, _args: acp::WriteTextFileRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn read_text_file( &self, _args: acp::ReadTextFileRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } @@ -51,35 +50,35 @@ impl acp::Client for ExampleClient { async fn terminal_output( &self, _args: acp::TerminalOutputRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn release_terminal( &self, _args: acp::ReleaseTerminalRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn wait_for_terminal_exit( &self, _args: acp::WaitForTerminalExitRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn kill_terminal_command( &self, _args: acp::KillTerminalCommandRequest, - ) -> anyhow::Result { + ) -> acp::Result { Err(acp::Error::method_not_found()) } async fn session_notification( &self, args: acp::SessionNotification, - ) -> anyhow::Result<(), acp::Error> { + ) -> acp::Result<(), acp::Error> { match args.update { acp::SessionUpdate::AgentMessageChunk { content } => { let text = match content { @@ -102,11 +101,11 @@ impl acp::Client for ExampleClient { Ok(()) } - async fn ext_method(&self, _args: acp::ExtRequest) -> Result { + async fn ext_method(&self, _args: acp::ExtRequest) -> acp::Result { Err(acp::Error::method_not_found()) } - async fn ext_notification(&self, _args: acp::ExtNotification) -> Result<(), acp::Error> { + async fn ext_notification(&self, _args: acp::ExtNotification) -> acp::Result<()> { Err(acp::Error::method_not_found()) } } @@ -130,7 +129,7 @@ async fn main() -> anyhow::Result<()> { child, ) } - _ => bail!("Usage: client AGENT_PROGRAM AGENT_ARG..."), + _ => anyhow::bail!("Usage: client AGENT_PROGRAM AGENT_ARG..."), }; // The ClientSideConnection will spawn futures onto our Tokio runtime. diff --git a/src/agent.rs b/src/agent.rs index 4775a3f..19e5c16 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,15 +1,14 @@ use std::{rc::Rc, sync::Arc}; -use serde_json::value::RawValue; - use agent_client_protocol_schema::{ AuthenticateRequest, AuthenticateResponse, CancelNotification, Error, ExtNotification, ExtRequest, ExtResponse, InitializeRequest, InitializeResponse, LoadSessionRequest, LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse, - SetSessionModeRequest, SetSessionModeResponse, + Result, SetSessionModeRequest, SetSessionModeResponse, }; #[cfg(feature = "unstable")] use agent_client_protocol_schema::{SetSessionModelRequest, SetSessionModelResponse}; +use serde_json::value::RawValue; /// Defines the interface that all ACP-compliant agents must implement. /// @@ -27,7 +26,7 @@ pub trait Agent { /// The agent should respond with its supported protocol version and capabilities. /// /// See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) - async fn initialize(&self, args: InitializeRequest) -> Result; + async fn initialize(&self, args: InitializeRequest) -> Result; /// Authenticates the client using the specified authentication method. /// @@ -38,7 +37,7 @@ pub trait Agent { /// `new_session` without receiving an `auth_required` error. /// /// See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) - async fn authenticate(&self, args: AuthenticateRequest) -> Result; + async fn authenticate(&self, args: AuthenticateRequest) -> Result; /// Creates a new conversation session with the agent. /// @@ -52,7 +51,7 @@ pub trait Agent { /// May return an `auth_required` error if the agent requires authentication. /// /// See protocol docs: [Session Setup](https://agentclientprotocol.com/protocol/session-setup) - async fn new_session(&self, args: NewSessionRequest) -> Result; + async fn new_session(&self, args: NewSessionRequest) -> Result; /// Processes a user prompt within a session. /// @@ -65,7 +64,7 @@ pub trait Agent { /// - Returns when the turn is complete with a stop reason /// /// See protocol docs: [Prompt Turn](https://agentclientprotocol.com/protocol/prompt-turn) - async fn prompt(&self, args: PromptRequest) -> Result; + async fn prompt(&self, args: PromptRequest) -> Result; /// Cancels ongoing operations for a session. /// @@ -78,7 +77,7 @@ pub trait Agent { /// - Respond to the original `session/prompt` request with `StopReason::Cancelled` /// /// See protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation) - async fn cancel(&self, args: CancelNotification) -> Result<(), Error>; + async fn cancel(&self, args: CancelNotification) -> Result<()>; /// Loads an existing session to resume a previous conversation. /// @@ -90,7 +89,7 @@ pub trait Agent { /// - Stream the entire conversation history back to the client via notifications /// /// See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) - async fn load_session(&self, _args: LoadSessionRequest) -> Result { + async fn load_session(&self, _args: LoadSessionRequest) -> Result { Err(Error::method_not_found()) } @@ -110,7 +109,7 @@ pub trait Agent { async fn set_session_mode( &self, _args: SetSessionModeRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -123,7 +122,7 @@ pub trait Agent { async fn set_session_model( &self, _args: SetSessionModelRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -133,7 +132,7 @@ pub trait Agent { /// protocol compatibility. /// /// See protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility) - async fn ext_method(&self, _args: ExtRequest) -> Result { + async fn ext_method(&self, _args: ExtRequest) -> Result { Ok(RawValue::NULL.to_owned().into()) } @@ -143,89 +142,89 @@ pub trait Agent { /// while maintaining protocol compatibility. /// /// See protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility) - async fn ext_notification(&self, _args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, _args: ExtNotification) -> Result<()> { Ok(()) } } #[async_trait::async_trait(?Send)] impl Agent for Rc { - async fn initialize(&self, args: InitializeRequest) -> Result { + async fn initialize(&self, args: InitializeRequest) -> Result { self.as_ref().initialize(args).await } - async fn authenticate(&self, args: AuthenticateRequest) -> Result { + async fn authenticate(&self, args: AuthenticateRequest) -> Result { self.as_ref().authenticate(args).await } - async fn new_session(&self, args: NewSessionRequest) -> Result { + async fn new_session(&self, args: NewSessionRequest) -> Result { self.as_ref().new_session(args).await } - async fn load_session(&self, args: LoadSessionRequest) -> Result { + async fn load_session(&self, args: LoadSessionRequest) -> Result { self.as_ref().load_session(args).await } async fn set_session_mode( &self, args: SetSessionModeRequest, - ) -> Result { + ) -> Result { self.as_ref().set_session_mode(args).await } - async fn prompt(&self, args: PromptRequest) -> Result { + async fn prompt(&self, args: PromptRequest) -> Result { self.as_ref().prompt(args).await } - async fn cancel(&self, args: CancelNotification) -> Result<(), Error> { + async fn cancel(&self, args: CancelNotification) -> Result<()> { self.as_ref().cancel(args).await } #[cfg(feature = "unstable")] async fn set_session_model( &self, args: SetSessionModelRequest, - ) -> Result { + ) -> Result { self.as_ref().set_session_model(args).await } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.as_ref().ext_notification(args).await } } #[async_trait::async_trait(?Send)] impl Agent for Arc { - async fn initialize(&self, args: InitializeRequest) -> Result { + async fn initialize(&self, args: InitializeRequest) -> Result { self.as_ref().initialize(args).await } - async fn authenticate(&self, args: AuthenticateRequest) -> Result { + async fn authenticate(&self, args: AuthenticateRequest) -> Result { self.as_ref().authenticate(args).await } - async fn new_session(&self, args: NewSessionRequest) -> Result { + async fn new_session(&self, args: NewSessionRequest) -> Result { self.as_ref().new_session(args).await } - async fn load_session(&self, args: LoadSessionRequest) -> Result { + async fn load_session(&self, args: LoadSessionRequest) -> Result { self.as_ref().load_session(args).await } async fn set_session_mode( &self, args: SetSessionModeRequest, - ) -> Result { + ) -> Result { self.as_ref().set_session_mode(args).await } - async fn prompt(&self, args: PromptRequest) -> Result { + async fn prompt(&self, args: PromptRequest) -> Result { self.as_ref().prompt(args).await } - async fn cancel(&self, args: CancelNotification) -> Result<(), Error> { + async fn cancel(&self, args: CancelNotification) -> Result<()> { self.as_ref().cancel(args).await } #[cfg(feature = "unstable")] async fn set_session_model( &self, args: SetSessionModelRequest, - ) -> Result { + ) -> Result { self.as_ref().set_session_model(args).await } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.as_ref().ext_notification(args).await } } diff --git a/src/client.rs b/src/client.rs index a5aae76..66ff187 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,15 +1,14 @@ use std::{rc::Rc, sync::Arc}; -use serde_json::value::RawValue; - use agent_client_protocol_schema::{ CreateTerminalRequest, CreateTerminalResponse, Error, ExtNotification, ExtRequest, ExtResponse, KillTerminalCommandRequest, KillTerminalCommandResponse, ReadTextFileRequest, ReadTextFileResponse, ReleaseTerminalRequest, ReleaseTerminalResponse, - RequestPermissionRequest, RequestPermissionResponse, SessionNotification, + RequestPermissionRequest, RequestPermissionResponse, Result, SessionNotification, TerminalOutputRequest, TerminalOutputResponse, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, }; +use serde_json::value::RawValue; /// Defines the interface that ACP-compliant clients must implement. /// @@ -31,7 +30,7 @@ pub trait Client { async fn request_permission( &self, args: RequestPermissionRequest, - ) -> Result; + ) -> Result; /// Handles session update notifications from the agent. /// @@ -44,7 +43,7 @@ pub trait Client { /// updates before responding with the cancelled stop reason. /// /// See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) - async fn session_notification(&self, args: SessionNotification) -> Result<(), Error>; + async fn session_notification(&self, args: SessionNotification) -> Result<()>; /// Writes content to a text file in the client's file system. /// @@ -52,10 +51,7 @@ pub trait Client { /// Allows the agent to create or modify files within the client's environment. /// /// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client) - async fn write_text_file( - &self, - _args: WriteTextFileRequest, - ) -> Result { + async fn write_text_file(&self, _args: WriteTextFileRequest) -> Result { Err(Error::method_not_found()) } @@ -65,10 +61,7 @@ pub trait Client { /// Allows the agent to access file contents within the client's environment. /// /// See protocol docs: [Client](https://agentclientprotocol.com/protocol/overview#client) - async fn read_text_file( - &self, - _args: ReadTextFileRequest, - ) -> Result { + async fn read_text_file(&self, _args: ReadTextFileRequest) -> Result { Err(Error::method_not_found()) } @@ -89,7 +82,7 @@ pub trait Client { async fn create_terminal( &self, _args: CreateTerminalRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -102,7 +95,7 @@ pub trait Client { async fn terminal_output( &self, _args: TerminalOutputRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -121,7 +114,7 @@ pub trait Client { async fn release_terminal( &self, _args: ReleaseTerminalRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -131,7 +124,7 @@ pub trait Client { async fn wait_for_terminal_exit( &self, _args: WaitForTerminalExitRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -150,7 +143,7 @@ pub trait Client { async fn kill_terminal_command( &self, _args: KillTerminalCommandRequest, - ) -> Result { + ) -> Result { Err(Error::method_not_found()) } @@ -161,7 +154,7 @@ pub trait Client { /// protocol compatibility. /// /// See protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility) - async fn ext_method(&self, _args: ExtRequest) -> Result { + async fn ext_method(&self, _args: ExtRequest) -> Result { Ok(RawValue::NULL.to_owned().into()) } @@ -172,7 +165,7 @@ pub trait Client { /// while maintaining protocol compatibility. /// /// See protocol docs: [Extensibility](https://agentclientprotocol.com/protocol/extensibility) - async fn ext_notification(&self, _args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, _args: ExtNotification) -> Result<()> { Ok(()) } } @@ -182,58 +175,46 @@ impl Client for Rc { async fn request_permission( &self, args: RequestPermissionRequest, - ) -> Result { + ) -> Result { self.as_ref().request_permission(args).await } - async fn write_text_file( - &self, - args: WriteTextFileRequest, - ) -> Result { + async fn write_text_file(&self, args: WriteTextFileRequest) -> Result { self.as_ref().write_text_file(args).await } - async fn read_text_file( - &self, - args: ReadTextFileRequest, - ) -> Result { + async fn read_text_file(&self, args: ReadTextFileRequest) -> Result { self.as_ref().read_text_file(args).await } - async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> { + async fn session_notification(&self, args: SessionNotification) -> Result<()> { self.as_ref().session_notification(args).await } - async fn create_terminal( - &self, - args: CreateTerminalRequest, - ) -> Result { + async fn create_terminal(&self, args: CreateTerminalRequest) -> Result { self.as_ref().create_terminal(args).await } - async fn terminal_output( - &self, - args: TerminalOutputRequest, - ) -> Result { + async fn terminal_output(&self, args: TerminalOutputRequest) -> Result { self.as_ref().terminal_output(args).await } async fn release_terminal( &self, args: ReleaseTerminalRequest, - ) -> Result { + ) -> Result { self.as_ref().release_terminal(args).await } async fn wait_for_terminal_exit( &self, args: WaitForTerminalExitRequest, - ) -> Result { + ) -> Result { self.as_ref().wait_for_terminal_exit(args).await } async fn kill_terminal_command( &self, args: KillTerminalCommandRequest, - ) -> Result { + ) -> Result { self.as_ref().kill_terminal_command(args).await } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.as_ref().ext_notification(args).await } } @@ -243,58 +224,46 @@ impl Client for Arc { async fn request_permission( &self, args: RequestPermissionRequest, - ) -> Result { + ) -> Result { self.as_ref().request_permission(args).await } - async fn write_text_file( - &self, - args: WriteTextFileRequest, - ) -> Result { + async fn write_text_file(&self, args: WriteTextFileRequest) -> Result { self.as_ref().write_text_file(args).await } - async fn read_text_file( - &self, - args: ReadTextFileRequest, - ) -> Result { + async fn read_text_file(&self, args: ReadTextFileRequest) -> Result { self.as_ref().read_text_file(args).await } - async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> { + async fn session_notification(&self, args: SessionNotification) -> Result<()> { self.as_ref().session_notification(args).await } - async fn create_terminal( - &self, - args: CreateTerminalRequest, - ) -> Result { + async fn create_terminal(&self, args: CreateTerminalRequest) -> Result { self.as_ref().create_terminal(args).await } - async fn terminal_output( - &self, - args: TerminalOutputRequest, - ) -> Result { + async fn terminal_output(&self, args: TerminalOutputRequest) -> Result { self.as_ref().terminal_output(args).await } async fn release_terminal( &self, args: ReleaseTerminalRequest, - ) -> Result { + ) -> Result { self.as_ref().release_terminal(args).await } async fn wait_for_terminal_exit( &self, args: WaitForTerminalExitRequest, - ) -> Result { + ) -> Result { self.as_ref().wait_for_terminal_exit(args).await } async fn kill_terminal_command( &self, args: KillTerminalCommandRequest, - ) -> Result { + ) -> Result { self.as_ref().kill_terminal_command(args).await } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.as_ref().ext_notification(args).await } } diff --git a/src/lib.rs b/src/lib.rs index 8322541..a0ca495 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use futures::{AsyncRead, AsyncWrite, future::LocalBoxFuture}; use rpc::{MessageHandler, RpcConnection, Side}; @@ -75,7 +74,7 @@ impl ClientSideConnection { #[async_trait::async_trait(?Send)] impl Agent for ClientSideConnection { - async fn initialize(&self, args: InitializeRequest) -> Result { + async fn initialize(&self, args: InitializeRequest) -> Result { self.conn .request( AGENT_METHOD_NAMES.initialize, @@ -84,7 +83,7 @@ impl Agent for ClientSideConnection { .await } - async fn authenticate(&self, args: AuthenticateRequest) -> Result { + async fn authenticate(&self, args: AuthenticateRequest) -> Result { self.conn .request::>( AGENT_METHOD_NAMES.authenticate, @@ -94,7 +93,7 @@ impl Agent for ClientSideConnection { .map(Option::unwrap_or_default) } - async fn new_session(&self, args: NewSessionRequest) -> Result { + async fn new_session(&self, args: NewSessionRequest) -> Result { self.conn .request( AGENT_METHOD_NAMES.session_new, @@ -103,7 +102,7 @@ impl Agent for ClientSideConnection { .await } - async fn load_session(&self, args: LoadSessionRequest) -> Result { + async fn load_session(&self, args: LoadSessionRequest) -> Result { self.conn .request::>( AGENT_METHOD_NAMES.session_load, @@ -116,7 +115,7 @@ impl Agent for ClientSideConnection { async fn set_session_mode( &self, args: SetSessionModeRequest, - ) -> Result { + ) -> Result { self.conn .request( AGENT_METHOD_NAMES.session_set_mode, @@ -125,7 +124,7 @@ impl Agent for ClientSideConnection { .await } - async fn prompt(&self, args: PromptRequest) -> Result { + async fn prompt(&self, args: PromptRequest) -> Result { self.conn .request( AGENT_METHOD_NAMES.session_prompt, @@ -134,7 +133,7 @@ impl Agent for ClientSideConnection { .await } - async fn cancel(&self, args: CancelNotification) -> Result<(), Error> { + async fn cancel(&self, args: CancelNotification) -> Result<()> { self.conn.notify( AGENT_METHOD_NAMES.session_cancel, Some(ClientNotification::CancelNotification(args)), @@ -145,7 +144,7 @@ impl Agent for ClientSideConnection { async fn set_session_model( &self, args: SetSessionModelRequest, - ) -> Result { + ) -> Result { self.conn .request( AGENT_METHOD_NAMES.session_set_model, @@ -154,7 +153,7 @@ impl Agent for ClientSideConnection { .await } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.conn .request( format!("_{}", args.method), @@ -163,7 +162,7 @@ impl Agent for ClientSideConnection { .await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.conn.notify( format!("_{}", args.method), Some(ClientNotification::ExtNotification(args)), @@ -185,7 +184,7 @@ impl Side for ClientSide { type InRequest = AgentRequest; type OutResponse = ClientResponse; - fn decode_request(method: &str, params: Option<&RawValue>) -> Result { + fn decode_request(method: &str, params: Option<&RawValue>) -> Result { let params = params.ok_or_else(Error::invalid_params)?; match method { @@ -230,10 +229,7 @@ impl Side for ClientSide { } } - fn decode_notification( - method: &str, - params: Option<&RawValue>, - ) -> Result { + fn decode_notification(method: &str, params: Option<&RawValue>) -> Result { let params = params.ok_or_else(Error::invalid_params)?; match method { @@ -255,7 +251,7 @@ impl Side for ClientSide { } impl MessageHandler for T { - async fn handle_request(&self, request: AgentRequest) -> Result { + async fn handle_request(&self, request: AgentRequest) -> Result { match request { AgentRequest::RequestPermissionRequest(args) => { let response = self.request_permission(args).await?; @@ -296,7 +292,7 @@ impl MessageHandler for T { } } - async fn handle_notification(&self, notification: AgentNotification) -> Result<(), Error> { + async fn handle_notification(&self, notification: AgentNotification) -> Result<()> { match notification { AgentNotification::SessionNotification(args) => { self.session_notification(args).await?; @@ -371,7 +367,7 @@ impl Client for AgentSideConnection { async fn request_permission( &self, args: RequestPermissionRequest, - ) -> Result { + ) -> Result { self.conn .request( CLIENT_METHOD_NAMES.session_request_permission, @@ -380,10 +376,7 @@ impl Client for AgentSideConnection { .await } - async fn write_text_file( - &self, - args: WriteTextFileRequest, - ) -> Result { + async fn write_text_file(&self, args: WriteTextFileRequest) -> Result { self.conn .request::>( CLIENT_METHOD_NAMES.fs_write_text_file, @@ -393,10 +386,7 @@ impl Client for AgentSideConnection { .map(Option::unwrap_or_default) } - async fn read_text_file( - &self, - args: ReadTextFileRequest, - ) -> Result { + async fn read_text_file(&self, args: ReadTextFileRequest) -> Result { self.conn .request( CLIENT_METHOD_NAMES.fs_read_text_file, @@ -405,10 +395,7 @@ impl Client for AgentSideConnection { .await } - async fn create_terminal( - &self, - args: CreateTerminalRequest, - ) -> Result { + async fn create_terminal(&self, args: CreateTerminalRequest) -> Result { self.conn .request( CLIENT_METHOD_NAMES.terminal_create, @@ -417,10 +404,7 @@ impl Client for AgentSideConnection { .await } - async fn terminal_output( - &self, - args: TerminalOutputRequest, - ) -> Result { + async fn terminal_output(&self, args: TerminalOutputRequest) -> Result { self.conn .request( CLIENT_METHOD_NAMES.terminal_output, @@ -432,7 +416,7 @@ impl Client for AgentSideConnection { async fn release_terminal( &self, args: ReleaseTerminalRequest, - ) -> Result { + ) -> Result { self.conn .request::>( CLIENT_METHOD_NAMES.terminal_release, @@ -445,7 +429,7 @@ impl Client for AgentSideConnection { async fn wait_for_terminal_exit( &self, args: WaitForTerminalExitRequest, - ) -> Result { + ) -> Result { self.conn .request( CLIENT_METHOD_NAMES.terminal_wait_for_exit, @@ -457,7 +441,7 @@ impl Client for AgentSideConnection { async fn kill_terminal_command( &self, args: KillTerminalCommandRequest, - ) -> Result { + ) -> Result { self.conn .request::>( CLIENT_METHOD_NAMES.terminal_kill, @@ -467,14 +451,14 @@ impl Client for AgentSideConnection { .map(Option::unwrap_or_default) } - async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> { + async fn session_notification(&self, args: SessionNotification) -> Result<()> { self.conn.notify( CLIENT_METHOD_NAMES.session_update, Some(AgentNotification::SessionNotification(args)), ) } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { self.conn .request( format!("_{}", args.method), @@ -483,7 +467,7 @@ impl Client for AgentSideConnection { .await } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.conn.notify( format!("_{}", args.method), Some(AgentNotification::ExtNotification(args)), @@ -505,7 +489,7 @@ impl Side for AgentSide { type InNotification = ClientNotification; type OutResponse = AgentResponse; - fn decode_request(method: &str, params: Option<&RawValue>) -> Result { + fn decode_request(method: &str, params: Option<&RawValue>) -> Result { let params = params.ok_or_else(Error::invalid_params)?; match method { @@ -544,10 +528,7 @@ impl Side for AgentSide { } } - fn decode_notification( - method: &str, - params: Option<&RawValue>, - ) -> Result { + fn decode_notification(method: &str, params: Option<&RawValue>) -> Result { let params = params.ok_or_else(Error::invalid_params)?; match method { @@ -569,7 +550,7 @@ impl Side for AgentSide { } impl MessageHandler for T { - async fn handle_request(&self, request: ClientRequest) -> Result { + async fn handle_request(&self, request: ClientRequest) -> Result { match request { ClientRequest::InitializeRequest(args) => { let response = self.initialize(args).await?; @@ -607,7 +588,7 @@ impl MessageHandler for T { } } - async fn handle_notification(&self, notification: ClientNotification) -> Result<(), Error> { + async fn handle_notification(&self, notification: ClientNotification) -> Result<()> { match notification { ClientNotification::CancelNotification(args) => { self.cancel(args).await?; diff --git a/src/rpc.rs b/src/rpc.rs index a500302..edb8f86 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -8,8 +8,7 @@ use std::{ }, }; -use agent_client_protocol_schema::Error; -use anyhow::Result; +use agent_client_protocol_schema::{Error, Result}; use futures::{ AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _, StreamExt as _, @@ -35,8 +34,8 @@ pub struct RpcConnection { } struct PendingResponse { - deserialize: fn(&serde_json::value::RawValue) -> Result, Error>, - respond: oneshot::Sender, Error>>, + deserialize: fn(&serde_json::value::RawValue) -> Result>, + respond: oneshot::Sender>>, } impl RpcConnection @@ -96,7 +95,7 @@ where &self, method: impl Into>, params: Option, - ) -> Result<(), Error> { + ) -> Result<()> { self.outgoing_tx .unbounded_send(OutgoingMessage::Notification { method: method.into(), @@ -109,7 +108,7 @@ where &self, method: impl Into>, params: Option, - ) -> impl Future> { + ) -> impl Future> { let (tx, rx) = oneshot::channel(); let id = self.next_id.fetch_add(1, Ordering::SeqCst); let id = Id::Number(id); @@ -378,8 +377,8 @@ pub enum ResponseResult { Error(Error), } -impl From> for ResponseResult { - fn from(result: Result) -> Self { +impl From> for ResponseResult { + fn from(result: Result) -> Self { match result { Ok(value) => ResponseResult::Result(value), Err(error) => ResponseResult::Error(error), @@ -392,24 +391,22 @@ pub trait Side: Clone { type OutResponse: Clone + Serialize + DeserializeOwned + 'static; type InNotification: Clone + Serialize + DeserializeOwned + 'static; - fn decode_request(method: &str, params: Option<&RawValue>) -> Result; + fn decode_request(method: &str, params: Option<&RawValue>) -> Result; - fn decode_notification( - method: &str, - params: Option<&RawValue>, - ) -> Result; + fn decode_notification(method: &str, params: Option<&RawValue>) + -> Result; } pub trait MessageHandler { fn handle_request( &self, request: Local::InRequest, - ) -> impl Future>; + ) -> impl Future>; fn handle_notification( &self, notification: Local::InNotification, - ) -> impl Future>; + ) -> impl Future>; } #[cfg(test)] diff --git a/src/rpc_tests.rs b/src/rpc_tests.rs index 720e63e..43a11c5 100644 --- a/src/rpc_tests.rs +++ b/src/rpc_tests.rs @@ -13,7 +13,6 @@ use agent_client_protocol_schema::{ VERSION, WaitForTerminalExitRequest, WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, }; -use anyhow::Result; use serde_json::json; use std::sync::{Arc, Mutex}; @@ -60,7 +59,7 @@ impl Client for TestClient { async fn request_permission( &self, _arguments: RequestPermissionRequest, - ) -> Result { + ) -> Result { let responses = self.permission_responses.clone(); let mut responses = responses.lock().unwrap(); let outcome = responses @@ -75,7 +74,7 @@ impl Client for TestClient { async fn write_text_file( &self, arguments: WriteTextFileRequest, - ) -> Result { + ) -> Result { self.written_files .lock() .unwrap() @@ -83,10 +82,7 @@ impl Client for TestClient { Ok(WriteTextFileResponse::default()) } - async fn read_text_file( - &self, - arguments: ReadTextFileRequest, - ) -> Result { + async fn read_text_file(&self, arguments: ReadTextFileRequest) -> Result { let contents = self.file_contents.lock().unwrap(); let content = contents .get(&arguments.path) @@ -98,7 +94,7 @@ impl Client for TestClient { }) } - async fn session_notification(&self, args: SessionNotification) -> Result<(), Error> { + async fn session_notification(&self, args: SessionNotification) -> Result<()> { self.session_notifications.lock().unwrap().push(args); Ok(()) } @@ -106,39 +102,39 @@ impl Client for TestClient { async fn create_terminal( &self, _args: CreateTerminalRequest, - ) -> Result { + ) -> Result { unimplemented!() } async fn terminal_output( &self, _args: TerminalOutputRequest, - ) -> Result { + ) -> Result { unimplemented!() } async fn kill_terminal_command( &self, _args: KillTerminalCommandRequest, - ) -> Result { + ) -> Result { unimplemented!() } async fn release_terminal( &self, _args: ReleaseTerminalRequest, - ) -> Result { + ) -> Result { unimplemented!() } async fn wait_for_terminal_exit( &self, _args: WaitForTerminalExitRequest, - ) -> Result { + ) -> Result { unimplemented!() } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { match dbg!(args.method.as_ref()) { "example.com/ping" => Ok(raw_json!({ "response": "pong", @@ -148,7 +144,7 @@ impl Client for TestClient { } } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.extension_notifications .lock() .unwrap() @@ -180,7 +176,7 @@ impl TestAgent { #[async_trait::async_trait(?Send)] impl Agent for TestAgent { - async fn initialize(&self, arguments: InitializeRequest) -> Result { + async fn initialize(&self, arguments: InitializeRequest) -> Result { Ok(InitializeResponse { protocol_version: arguments.protocol_version, agent_capabilities: AgentCapabilities::default(), @@ -189,17 +185,11 @@ impl Agent for TestAgent { }) } - async fn authenticate( - &self, - _arguments: AuthenticateRequest, - ) -> Result { + async fn authenticate(&self, _arguments: AuthenticateRequest) -> Result { Ok(AuthenticateResponse::default()) } - async fn new_session( - &self, - _arguments: NewSessionRequest, - ) -> Result { + async fn new_session(&self, _arguments: NewSessionRequest) -> Result { let session_id = SessionId(Arc::from("test-session-123")); self.sessions.lock().unwrap().insert(session_id.clone()); Ok(NewSessionResponse { @@ -211,7 +201,7 @@ impl Agent for TestAgent { }) } - async fn load_session(&self, _: LoadSessionRequest) -> Result { + async fn load_session(&self, _: LoadSessionRequest) -> Result { Ok(LoadSessionResponse { modes: None, #[cfg(feature = "unstable")] @@ -223,11 +213,11 @@ impl Agent for TestAgent { async fn set_session_mode( &self, _arguments: SetSessionModeRequest, - ) -> Result { + ) -> Result { Ok(SetSessionModeResponse { meta: None }) } - async fn prompt(&self, arguments: PromptRequest) -> Result { + async fn prompt(&self, arguments: PromptRequest) -> Result { self.prompts_received .lock() .unwrap() @@ -238,7 +228,7 @@ impl Agent for TestAgent { }) } - async fn cancel(&self, args: CancelNotification) -> Result<(), Error> { + async fn cancel(&self, args: CancelNotification) -> Result<()> { self.cancellations_received .lock() .unwrap() @@ -250,12 +240,12 @@ impl Agent for TestAgent { async fn set_session_model( &self, args: agent_client_protocol_schema::SetSessionModelRequest, - ) -> Result { + ) -> Result { log::info!("Received select model request {args:?}"); Ok(agent_client_protocol_schema::SetSessionModelResponse::default()) } - async fn ext_method(&self, args: ExtRequest) -> Result { + async fn ext_method(&self, args: ExtRequest) -> Result { dbg!(); match dbg!(args.method.as_ref()) { "example.com/echo" => { @@ -268,7 +258,7 @@ impl Agent for TestAgent { } } - async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error> { + async fn ext_notification(&self, args: ExtNotification) -> Result<()> { self.extension_notifications .lock() .unwrap() diff --git a/src/stream_broadcast.rs b/src/stream_broadcast.rs index c639bb2..5d52496 100644 --- a/src/stream_broadcast.rs +++ b/src/stream_broadcast.rs @@ -6,8 +6,7 @@ use std::sync::Arc; -use agent_client_protocol_schema::Error; -use anyhow::Result; +use agent_client_protocol_schema::{Error, Result}; use serde::Serialize; use serde_json::value::RawValue; @@ -59,7 +58,7 @@ pub enum StreamMessageContent { /// The ID of the request this response is for. id: Id, /// The result of the request (success or error). - result: Result, Error>, + result: Result>, }, /// A JSON-RPC notification message. Notification { @@ -101,7 +100,10 @@ impl StreamReceiver { /// - `Ok(StreamMessage)` when a message is received /// - `Err` when the sender is dropped or the receiver is lagged pub async fn recv(&mut self) -> Result { - Ok(self.0.recv().await?) + self.0 + .recv() + .await + .map_err(|e| Error::internal_error().with_data(e.to_string())) } }