From 77c88c4c6376c31c95125a77fe7631168137ab14 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Mon, 12 May 2025 13:24:04 -0700 Subject: [PATCH 1/4] remove mcp_client crate, install.rs, and fix small bug --- Cargo.lock | 15 - Cargo.toml | 1 - crates/chat-cli/src/auth/builder_id.rs | 2 +- crates/chat-cli/src/install.rs | 31 - crates/chat-cli/src/main.rs | 1 - crates/mcp_client/Cargo.toml | 30 - crates/mcp_client/src/client.rs | 772 ------------------ crates/mcp_client/src/error.rs | 66 -- crates/mcp_client/src/facilitator_types.rs | 229 ------ crates/mcp_client/src/lib.rs | 9 - crates/mcp_client/src/server.rs | 293 ------- .../mcp_client/src/transport/base_protocol.rs | 108 --- crates/mcp_client/src/transport/mod.rs | 56 -- crates/mcp_client/src/transport/stdio.rs | 277 ------- crates/mcp_client/src/transport/websocket.rs | 0 .../mcp_client/test_mcp_server/test_server.rs | 354 -------- crates/q_cli/Cargo.toml | 1 - 17 files changed, 1 insertion(+), 2244 deletions(-) delete mode 100644 crates/chat-cli/src/install.rs delete mode 100644 crates/mcp_client/Cargo.toml delete mode 100644 crates/mcp_client/src/client.rs delete mode 100644 crates/mcp_client/src/error.rs delete mode 100644 crates/mcp_client/src/facilitator_types.rs delete mode 100644 crates/mcp_client/src/lib.rs delete mode 100644 crates/mcp_client/src/server.rs delete mode 100644 crates/mcp_client/src/transport/base_protocol.rs delete mode 100644 crates/mcp_client/src/transport/mod.rs delete mode 100644 crates/mcp_client/src/transport/stdio.rs delete mode 100644 crates/mcp_client/src/transport/websocket.rs delete mode 100644 crates/mcp_client/test_mcp_server/test_server.rs diff --git a/Cargo.lock b/Cargo.lock index 4073cfc71a..a11c7c2259 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5247,20 +5247,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "mcp_client" -version = "1.10.0" -dependencies = [ - "async-trait", - "nix 0.29.0", - "serde", - "serde_json", - "thiserror 2.0.12", - "tokio", - "tracing", - "uuid", -] - [[package]] name = "memchr" version = "2.7.4" @@ -7060,7 +7046,6 @@ dependencies = [ "indoc", "insta", "macos-utils", - "mcp_client", "mimalloc", "nix 0.29.0", "objc2 0.5.2", diff --git a/Cargo.toml b/Cargo.toml index 499df2b4a8..011326081c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,7 +87,6 @@ indicatif = "0.17.11" indoc = "2.0.6" insta = "1.43.1" libc = "0.2.172" -mcp_client = { path = "crates/mcp_client" } mimalloc = "0.1.46" nix = { version = "0.29.0", features = [ "feature", diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index e277c1410b..beb99c768c 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -62,7 +62,7 @@ use crate::database::secret_store::{ pub enum OAuthFlow { DeviceCode, // This must remain backwards compatible - #[serde(rename = "PKCE")] + #[serde(alias = "PKCE")] Pkce, } diff --git a/crates/chat-cli/src/install.rs b/crates/chat-cli/src/install.rs deleted file mode 100644 index b806856df8..0000000000 --- a/crates/chat-cli/src/install.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::time::SystemTimeError; - -use thiserror::Error; -use tracing::error; - -#[derive(Debug, Error)] -pub enum Error { - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Util(#[from] crate::util::UtilError), - #[error(transparent)] - Settings(#[from] crate::database::DatabaseError), - #[error(transparent)] - Reqwest(#[from] reqwest::Error), - #[error(transparent)] - Semver(#[from] semver::Error), - #[error(transparent)] - SystemTime(#[from] SystemTimeError), - #[error(transparent)] - Strum(#[from] strum::ParseError), - #[cfg(target_os = "macos")] - #[error("failed to update due to auth error: `{0}`")] - SecurityFramework(#[from] security_framework::base::Error), -} - -impl From for Error { - fn from(err: crate::util::directories::DirectoryError) -> Self { - crate::util::UtilError::Directory(err).into() - } -} diff --git a/crates/chat-cli/src/main.rs b/crates/chat-cli/src/main.rs index e28354bfb4..a00fc087f2 100644 --- a/crates/chat-cli/src/main.rs +++ b/crates/chat-cli/src/main.rs @@ -3,7 +3,6 @@ mod auth; mod aws_common; mod cli; mod database; -mod install; mod logging; mod mcp_client; mod platform; diff --git a/crates/mcp_client/Cargo.toml b/crates/mcp_client/Cargo.toml deleted file mode 100644 index a95692d102..0000000000 --- a/crates/mcp_client/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -name = "mcp_client" -authors.workspace = true -edition.workspace = true -homepage.workspace = true -publish.workspace = true -version.workspace = true -license.workspace = true - -[lints] -workspace = true - -[features] -default = [] - -[[bin]] -name = "test_mcp_server" -path = "test_mcp_server/test_server.rs" -test = true -doc = false - -[dependencies] -tokio.workspace = true -serde.workspace = true -serde_json.workspace = true -async-trait.workspace = true -tracing.workspace = true -thiserror.workspace = true -uuid.workspace = true -nix.workspace = true diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs deleted file mode 100644 index 01b3794013..0000000000 --- a/crates/mcp_client/src/client.rs +++ /dev/null @@ -1,772 +0,0 @@ -use std::collections::HashMap; -use std::process::Stdio; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; -use std::time::Duration; - -use nix::sys::signal::Signal; -use nix::unistd::Pid; -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::time; -use tokio::time::error::Elapsed; - -use crate::transport::base_protocol::{ - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcVersion, -}; -use crate::transport::stdio::JsonRpcStdioTransport; -use crate::transport::{ - self, - Transport, - TransportError, -}; -use crate::{ - JsonRpcResponse, - Listener as _, - LogListener, - PaginationSupportedOps, - PromptGet, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; - -pub type ServerCapabilities = serde_json::Value; -pub type ClientInfo = serde_json::Value; -pub type StdioTransport = JsonRpcStdioTransport; - -/// Represents the capabilities of a client in the Model Context Protocol. -/// This structure is sent to the server during initialization to communicate -/// what features the client supports and provide information about the client. -/// When features are added to the client, these should be declared in the [From] trait implemented -/// for the struct. -#[derive(Default, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientCapabilities { - protocol_version: JsonRpcVersion, - capabilities: HashMap, - client_info: serde_json::Value, -} - -impl From for ClientCapabilities { - fn from(client_info: ClientInfo) -> Self { - ClientCapabilities { - client_info, - ..Default::default() - } - } -} - -#[derive(Debug, Deserialize)] -pub struct ClientConfig { - pub server_name: String, - pub bin_path: String, - pub args: Vec, - pub timeout: u64, - pub client_info: serde_json::Value, - pub env: Option>, -} - -#[derive(Debug, Error)] -pub enum ClientError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Operation timed out: {context}")] - RuntimeError { - #[source] - source: tokio::time::error::Elapsed, - context: String, - }, - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error("Failed to obtain process id")] - MissingProcessId, - #[error("Invalid path received")] - InvalidPath, - #[error("{0}")] - ProcessKillError(String), - #[error("{0}")] - PoisonError(String), -} - -impl From<(tokio::time::error::Elapsed, String)> for ClientError { - fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { - ClientError::RuntimeError { source: error, context } - } -} - -#[derive(Debug)] -pub struct Client { - server_name: String, - transport: Arc, - timeout: u64, - server_process_id: Option, - client_info: serde_json::Value, - current_id: Arc, - pub prompt_gets: Arc>>, - pub is_prompts_out_of_date: Arc, -} - -impl Clone for Client { - fn clone(&self) -> Self { - Self { - server_name: self.server_name.clone(), - transport: self.transport.clone(), - timeout: self.timeout, - // Note that we cannot have an id for the clone because we would kill the original - // process when we drop the clone - server_process_id: None, - client_info: self.client_info.clone(), - current_id: self.current_id.clone(), - prompt_gets: self.prompt_gets.clone(), - is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), - } - } -} - -impl Client { - pub fn from_config(config: ClientConfig) -> Result { - let ClientConfig { - server_name, - bin_path, - args, - timeout, - client_info, - env, - } = config; - let child = { - let mut command = tokio::process::Command::new(bin_path); - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .envs(std::env::vars()); - if let Some(env) = env { - for (env_name, env_value) in env { - command.env(env_name, env_value); - } - } - command.args(args).spawn()? - }; - let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; - #[allow(clippy::map_err_ignore)] - let server_process_id = Pid::from_raw( - server_process_id - .try_into() - .map_err(|_| ClientError::MissingProcessId)?, - ); - let server_process_id = Some(server_process_id); - let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); - Ok(Self { - server_name, - transport, - timeout, - server_process_id, - client_info, - current_id: Arc::new(AtomicU64::new(0)), - prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), - is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), - }) - } -} - -impl Drop for Client -where - T: Transport, -{ - // IF the servers are implemented well, they will shutdown once the pipe closes. - // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. - fn drop(&mut self) { - if let Some(process_id) = self.server_process_id { - let _ = nix::sys::signal::kill(process_id, Signal::SIGTERM); - } - } -} - -impl Client -where - T: Transport, -{ - /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - /// - /// Also done is the spawn of a background task that constantly listens for incoming messages - /// from the server. - pub async fn init(&self) -> Result { - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - if method.as_str() == "notifications/message" || method.as_str() == "message" { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - }, - } - } - }); - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - // Spawning a task to listen and log stderr output - tokio::spawn(async move { - let mut log_listener = transport_ref.get_log_listener(); - loop { - match log_listener.recv().await { - Ok(msg) => { - tracing::trace!(target: "mcp", "{server_name} logged {}", msg); - }, - Err(e) => { - tracing::error!( - "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", - e - ); - break; - }, - } - } - }); - - let init_params = Some({ - let client_cap = ClientCapabilities::from(self.client_info.clone()); - serde_json::json!(client_cap) - }); - let server_capabilities = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&server_capabilities) { - return Err(ClientError::NegotiationError(format!( - "Client {} has failed to negotiate server capabilities with server: {:?}", - self.server_name, e - ))); - } - self.notify("initialized", None).await?; - - // TODO: group this into examine_server_capabilities - // Prefetch prompts in the background. We should only do this after the server has been - // initialized - if let Some(res) = &server_capabilities.result { - if let Some(cap) = res.get("capabilities") { - if cap.get("prompts").is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } - }); - } - } - } - - Ok(serde_json::to_value(server_capabilities)?) - } - - /// Sends a request to the server associated. - /// This call will yield until a response is received. - pub async fn request( - &self, - method: &str, - params: Option, - ) -> Result { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); - let mut id = self.get_id(); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); - let msg = JsonRpcMessage::Request(request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let mut listener = self.transport.get_listener(); - let mut resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - loop { - if let JsonRpcMessage::Response(resp) = listener.recv().await? { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model - let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); - if next_cursor.is_some() { - let mut current_resp = resp.clone(); - let mut results = Vec::::new(); - let pagination_supported_ops = { - let maybe_pagination_supported_op: Result = method.try_into(); - maybe_pagination_supported_op.ok() - }; - if let Some(ops) = pagination_supported_ops { - loop { - let result = current_resp.result.as_ref().cloned().unwrap(); - let mut list: Vec = match ops { - PaginationSupportedOps::ResourcesList => { - let ResourcesListResult { resources: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ResourceTemplatesList => { - let ResourceTemplatesListResult { - resource_templates: list, - .. - } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::PromptsList => { - let PromptsListResult { prompts: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ToolsList => { - let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - }; - results.append(&mut list); - if next_cursor.is_none() { - break; - } - id = self.get_id(); - let next_request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params: Some(serde_json::json!({ - "cursor": next_cursor, - })), - }; - let msg = JsonRpcMessage::Request(next_request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - loop { - if let JsonRpcMessage::Response(resp) = listener.recv().await? { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - current_resp = resp; - next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); - } - resp.result = Some({ - let mut map = serde_json::Map::new(); - map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); - serde_json::to_value(map)? - }); - } - } - tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); - Ok(resp) - } - - /// Sends a notification to the server associated. - /// Notifications are requests that expect no responses. - pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let notification = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: format!("notifications/{}", method), - params, - }; - let msg = JsonRpcMessage::Notification(notification); - Ok( - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??, - ) - } - - pub async fn shutdown(&self) -> Result<(), ClientError> { - Ok(self.transport.shutdown().await?) - } - - fn get_id(&self) -> u64 { - self.current_id.fetch_add(1, Ordering::SeqCst) - } -} - -fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { - // Check the jrpc version. - // Currently we are only proceeding if the versions are EXACTLY the same. - let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); - let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); - for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { - if sv != cv { - return Err(ClientError::NegotiationError( - "Incompatible jrpc version between server and client".to_owned(), - )); - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use serde_json::Value; - - use super::*; - const TEST_BIN_OUT_DIR: &str = "target/debug"; - const TEST_SERVER_NAME: &str = "test_mcp_server"; - - fn get_workspace_root() -> PathBuf { - let output = std::process::Command::new("cargo") - .args(["metadata", "--format-version=1", "--no-deps"]) - .output() - .expect("Failed to execute cargo metadata"); - - let metadata: serde_json::Value = - serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); - - let workspace_root = metadata["workspace_root"] - .as_str() - .expect("Failed to find workspace_root in metadata"); - - PathBuf::from(workspace_root) - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_client_stdio() { - std::process::Command::new("cargo") - .args(["build", "--bin", TEST_SERVER_NAME]) - .status() - .expect("Failed to build binary"); - let workspace_root = get_workspace_root(); - let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); - println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); - - // Testing 2 concurrent sessions to make sure transport layer does not overlap. - let client_info_one = serde_json::json!({ - "name": "TestClientOne", - "version": "1.0.0" - }); - let client_config_one = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["1".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_one.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let client_info_two = serde_json::json!({ - "name": "TestClientTwo", - "version": "1.0.0" - }); - let client_config_two = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["2".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_two.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); - let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); - let client_one_cap = ClientCapabilities::from(client_info_one); - let client_two_cap = ClientCapabilities::from(client_info_two); - - let (res_one, res_two) = tokio::join!( - time::timeout( - time::Duration::from_secs(5), - test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) - ), - time::timeout( - time::Duration::from_secs(5), - test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) - ) - ); - let res_one = res_one.expect("Client one timed out"); - let res_two = res_two.expect("Client two timed out"); - assert!(res_one.is_ok()); - assert!(res_two.is_ok()); - } - - async fn test_client_routine( - client: &mut Client, - cap_sent: serde_json::Value, - ) -> Result<(), Box> { - // Test init - let _ = client.init().await.expect("Client init failed"); - tokio::time::sleep(time::Duration::from_millis(1500)).await; - let client_capabilities_sent = client - .request("verify_init_ack_sent", None) - .await - .expect("Verify init ack mock request failed"); - let has_server_recvd_init_ack = client_capabilities_sent - .result - .expect("Failed to retrieve client capabilities sent."); - assert_eq!(has_server_recvd_init_ack.to_string(), "true"); - let cap_recvd = client - .request("verify_init_params_sent", None) - .await - .expect("Verify init params mock request failed"); - let cap_recvd = cap_recvd - .result - .expect("Verify init params mock request does not contain required field (result)"); - assert!(are_json_values_equal(&cap_sent, &cap_recvd)); - - // test list tools - let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; - let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); - let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); - let mock_tool_specs_prep_param = mock_result_spec - .iter() - .zip(fake_tool_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_tool_specs_prep_param = - serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); - let _ = client - .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) - .await - .expect("Mock tool spec prep failed"); - let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); - assert!(are_json_values_equal( - tool_spec_recvd - .result - .as_ref() - .and_then(|v| v.get("tools")) - .expect("Failed to retrieve tool specs from result received"), - &mock_tool_specs_for_verify - )); - - // Test list prompts directly - let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock prompt prep failed"); - let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); - assert!(are_json_values_equal( - prompts_recvd - .result - .as_ref() - .and_then(|v| v.get("prompts")) - .expect("Failed to retrieve prompts from results received"), - &mock_prompts_for_verify - )); - - // Test env var inclusion - let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); - let env_one = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_ONE") - .expect("Failed to retrieve env one from env var request"); - let env_two = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_TWO") - .expect("Failed to retrieve env two from env var request"); - let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); - let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); - assert_eq!(env_one_as_str, "\"1\"".to_string()); - assert_eq!(env_two_as_str, "\"2\"".to_string()); - - let shutdown_result = client.shutdown().await; - assert!(shutdown_result.is_ok()); - Ok(()) - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } - - fn create_fake_tool_spec(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - }) - } - - fn create_fake_prompts(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - }) - } -} diff --git a/crates/mcp_client/src/error.rs b/crates/mcp_client/src/error.rs deleted file mode 100644 index d05e7efa4d..0000000000 --- a/crates/mcp_client/src/error.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - UnknownErrorCode = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for ErrorCode { - fn from(code: i32) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::UnknownErrorCode, - -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::UnknownErrorCode, - } - } -} - -impl From for i32 { - fn from(code: ErrorCode) -> Self { - code as i32 - } -} - -impl std::fmt::Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/mcp_client/src/facilitator_types.rs b/crates/mcp_client/src/facilitator_types.rs deleted file mode 100644 index ba56982046..0000000000 --- a/crates/mcp_client/src/facilitator_types.rs +++ /dev/null @@ -1,229 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; - -/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaginationSupportedOps { - ResourcesList, - ResourceTemplatesList, - PromptsList, - ToolsList, -} - -impl PaginationSupportedOps { - pub fn as_key(&self) -> &str { - match self { - PaginationSupportedOps::ResourcesList => "resources", - PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", - PaginationSupportedOps::PromptsList => "prompts", - PaginationSupportedOps::ToolsList => "tools", - } - } -} - -impl TryFrom<&str> for PaginationSupportedOps { - type Error = OpsConversionError; - - fn try_from(value: &str) -> Result { - match value { - "resources/list" => Ok(PaginationSupportedOps::ResourcesList), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), - "prompts/list" => Ok(PaginationSupportedOps::PromptsList), - "tools/list" => Ok(PaginationSupportedOps::ToolsList), - _ => Err(OpsConversionError::InvalidMethod), - } - } -} - -#[derive(Error, Debug)] -pub enum OpsConversionError { - #[error("Invalid method encountered")] - InvalidMethod, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "camelCase")] -/// Role assumed for a particular message -pub enum Role { - User, - Assistant, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing resources operation -pub struct ResourcesListResult { - /// List of resources - pub resources: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -/// Result of listing resource templates operation -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceTemplatesListResult { - /// List of resource templates - pub resource_templates: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of prompt listing query -pub struct PromptsListResult { - /// List of prompts - pub prompts: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents an argument to be supplied to a [PromptGet] -pub struct PromptGetArg { - /// The name identifier of the prompt - pub name: String, - /// Optional description providing context about the prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Indicates whether a response to this prompt is required - /// If not specified, defaults to false - #[serde(skip_serializing_if = "Option::is_none")] - pub required: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents a request to get a prompt from a mcp server -pub struct PromptGet { - /// Unique identifier for the prompt - pub name: String, - /// Optional description providing context about the prompt's purpose - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Optional list of arguments that define the structure of information to be collected - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// `result` field in [JsonRpcResponse] from a `prompts/get` request -pub struct PromptGetResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Completed prompt from `prompts/get` to be returned by a mcp server -pub struct Prompt { - pub role: Role, - pub content: MessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing tools operation -pub struct ToolsListResult { - /// List of tools - pub tools: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -/// Content of a message -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum MessageContent { - /// Text content - Text { - /// The text content - text: String, - }, - /// Image content - #[serde(rename_all = "camelCase")] - Image { - /// base64-encoded-data - data: String, - mime_type: String, - }, - /// Resource content - Resource { - /// The resource - resource: Resource, - }, -} - -impl From for String { - fn from(val: MessageContent) -> Self { - match val { - MessageContent::Text { text } => text, - MessageContent::Image { data, mime_type } => serde_json::json!({ - "data": data, - "mime_type": mime_type - }) - .to_string(), - MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), - } - } -} - -impl std::fmt::Display for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text { text } => write!(f, "{}", text), - MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), - MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), - } - } -} - -/// Resource contents -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum ResourceContents { - Text { text: String }, - Blob { data: Vec }, -} - -/// A resource in the system -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { - /// Unique identifier for the resource - pub uri: String, - /// Human-readable title - pub title: String, - /// Optional description - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Resource contents - pub contents: ResourceContents, -} diff --git a/crates/mcp_client/src/lib.rs b/crates/mcp_client/src/lib.rs deleted file mode 100644 index d631f70654..0000000000 --- a/crates/mcp_client/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod client; -pub mod error; -pub mod facilitator_types; -pub mod server; -pub mod transport; - -pub use client::*; -pub use facilitator_types::*; -pub use transport::*; diff --git a/crates/mcp_client/src/server.rs b/crates/mcp_client/src/server.rs deleted file mode 100644 index 1ba92b154d..0000000000 --- a/crates/mcp_client/src/server.rs +++ /dev/null @@ -1,293 +0,0 @@ -use std::collections::HashMap; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tokio::io::{ - Stdin, - Stdout, -}; -use tokio::task::JoinHandle; - -use crate::Listener as _; -use crate::client::StdioTransport; -use crate::error::ErrorCode; -use crate::transport::base_protocol::{ - JsonRpcError, - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcResponse, -}; -use crate::transport::stdio::JsonRpcStdioTransport; -use crate::transport::{ - JsonRpcVersion, - Transport, - TransportError, -}; - -pub type Request = serde_json::Value; -pub type Response = Option; -pub type InitializedServer = JoinHandle>; - -pub trait PreServerRequestHandler { - fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ); -} - -#[async_trait::async_trait] -pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { - async fn handle_initialize(&self, params: Option) -> Result; - async fn handle_incoming(&self, method: &str, params: Option) -> Result; - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; - async fn handle_shutdown(&self) -> Result<(), ServerError>; -} - -pub struct Server { - transport: Option>, - handler: Option, - #[allow(dead_code)] - pending_requests: Arc>>, - #[allow(dead_code)] - current_id: Arc, -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error(transparent)] - TokioJoinError(#[from] tokio::task::JoinError), - #[error("Failed to obtain mutex lock")] - MutexError, - #[error("Failed to obtain request method")] - MissingMethod, - #[error("Failed to obtain request id")] - MissingId, - #[error("Failed to initialize server. Missing transport")] - MissingTransport, - #[error("Failed to initialize server. Missing handler")] - MissingHandler, -} - -impl Server -where - H: ServerRequestHandler, -{ - pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { - let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); - let pending_requests = Arc::new(Mutex::new(HashMap::::new())); - let pending_requests_clone_one = pending_requests.clone(); - let current_id = Arc::new(AtomicU64::new(0)); - let pending_request_getter = move |id: u64| -> Option { - match pending_requests_clone_one.lock() { - Ok(mut p) => p.remove(&id), - Err(_) => None, - } - }; - handler.register_pending_request_callback(pending_request_getter); - let transport_clone = transport.clone(); - let pending_request_clone_two = pending_requests.clone(); - let current_id_clone = current_id.clone(); - let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { - let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Request(request.clone()); - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); - Ok(()) - }; - handler.register_send_request_callback(request_sender); - let server = Self { - transport: Some(transport), - handler: Some(handler), - pending_requests, - current_id, - }; - Ok(server) - } -} - -impl Server -where - T: Transport, - H: ServerRequestHandler, -{ - pub fn init(mut self) -> Result { - let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; - let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); - let has_initialized = Arc::new(AtomicBool::new(false)); - let listener = tokio::spawn(async move { - let mut listener = transport.get_listener(); - loop { - let request = listener.recv().await; - let transport_clone = transport.clone(); - let has_init_clone = has_initialized.clone(); - let handler_clone = handler.clone(); - tokio::task::spawn(async move { - process_request(has_init_clone, transport_clone, handler_clone, request).await; - }); - } - }); - Ok(listener) - } -} - -async fn process_request( - has_initialized: Arc, - transport: Arc, - handler: Arc, - request: Result, -) where - T: Transport, - H: ServerRequestHandler, -{ - match request { - Ok(msg) if msg.is_initialize() => { - let id = msg.id().unwrap_or_default(); - if has_initialized.load(Ordering::SeqCst) { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Server has already been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - } - let JsonRpcMessage::Request(req) = msg else { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Invalid method for initialization (use request)".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - }; - let JsonRpcRequest { params, .. } = req; - match handler.handle_initialize(params).await { - Ok(result) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - id, - result, - ..Default::default() - }); - let _ = transport.send(&resp).await; - has_initialized.store(true, Ordering::SeqCst); - }, - Err(_e) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InternalError.into(), - message: "Error producing initialization response".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - } - }, - Ok(msg) if msg.is_shutdown() => { - // TODO: add shutdown routine - }, - Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { - JsonRpcMessage::Request(req) => { - let JsonRpcRequest { - id, - jsonrpc, - params, - ref method, - } = req; - let resp = handler.handle_incoming(method, params).await.map_or_else( - |error| { - let err = JsonRpcError { - code: ErrorCode::InternalError.into(), - message: error.to_string(), - data: None, - }; - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result: None, - error: Some(err), - }; - JsonRpcMessage::Response(resp) - }, - |result| { - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result, - error: None, - }; - JsonRpcMessage::Response(resp) - }, - ); - let _ = transport.send(&resp).await; - }, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { ref method, params, .. } = notif; - let _ = handler.handle_incoming(method, params).await; - }, - JsonRpcMessage::Response(resp) => { - let _ = handler.handle_response(resp).await; - }, - }, - Ok(msg) => { - let id = msg.id().unwrap_or_default(); - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::ServerNotInitialized.into(), - message: "Server has not been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - Err(_e) => { - // TODO: error handling - }, - } -} diff --git a/crates/mcp_client/src/transport/base_protocol.rs b/crates/mcp_client/src/transport/base_protocol.rs deleted file mode 100644 index b0394e6e0c..0000000000 --- a/crates/mcp_client/src/transport/base_protocol.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ -//! Protocol Revision 2024-11-05 -use serde::{ - Deserialize, - Serialize, -}; - -pub type RequestId = u64; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct JsonRpcVersion(String); - -impl Default for JsonRpcVersion { - fn default() -> Self { - JsonRpcVersion("2.0".to_owned()) - } -} - -impl JsonRpcVersion { - pub fn as_u32_vec(&self) -> Vec { - self.0 - .split(".") - .map(|n| n.parse::().unwrap()) - .collect::>() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) -// The categorization of the deserialization depends on the order in which the variants are -// declared. -pub enum JsonRpcMessage { - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Request(JsonRpcRequest), -} - -impl JsonRpcMessage { - pub fn is_initialize(&self) -> bool { - match self { - JsonRpcMessage::Request(req) => req.method == "initialize", - _ => false, - } - } - - pub fn is_shutdown(&self) -> bool { - match self { - JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", - _ => false, - } - } - - pub fn id(&self) -> Option { - match self { - JsonRpcMessage::Request(req) => Some(req.id), - JsonRpcMessage::Response(resp) => Some(resp.id), - JsonRpcMessage::Notification(_) => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcRequest { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcResponse { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcNotification { - pub jsonrpc: JsonRpcVersion, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -pub enum TransportType { - #[default] - Stdio, - Websocket, -} diff --git a/crates/mcp_client/src/transport/mod.rs b/crates/mcp_client/src/transport/mod.rs deleted file mode 100644 index 5796ba5323..0000000000 --- a/crates/mcp_client/src/transport/mod.rs +++ /dev/null @@ -1,56 +0,0 @@ -pub mod base_protocol; -pub mod stdio; - -use std::fmt::Debug; - -pub use base_protocol::*; -pub use stdio::*; -use thiserror::Error; - -#[derive(Clone, Debug, Error)] -pub enum TransportError { - #[error("Serialization error: {0}")] - Serialization(String), - #[error("IO error: {0}")] - Stdio(String), - #[error("{0}")] - Custom(String), - #[error(transparent)] - RecvError(#[from] tokio::sync::broadcast::error::RecvError), -} - -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::Serialization(err.to_string()) - } -} - -impl From for TransportError { - fn from(err: std::io::Error) -> Self { - TransportError::Stdio(err.to_string()) - } -} - -#[async_trait::async_trait] -pub trait Transport: Send + Sync + Debug + 'static { - /// Sends a message over the transport layer. - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; - /// Listens to awaits for a response. This is a call that should be used after `send` is called - /// to listen for a response from the message recipient. - fn get_listener(&self) -> impl Listener; - /// Gracefully terminates the transport connection, cleaning up any resources. - /// This should be called when the transport is no longer needed to ensure proper cleanup. - async fn shutdown(&self) -> Result<(), TransportError>; - /// Listener that listens for logging messages. - fn get_log_listener(&self) -> impl LogListener; -} - -#[async_trait::async_trait] -pub trait Listener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} - -#[async_trait::async_trait] -pub trait LogListener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} diff --git a/crates/mcp_client/src/transport/stdio.rs b/crates/mcp_client/src/transport/stdio.rs deleted file mode 100644 index ab4c6a2a07..0000000000 --- a/crates/mcp_client/src/transport/stdio.rs +++ /dev/null @@ -1,277 +0,0 @@ -use std::sync::Arc; - -use tokio::io::{ - AsyncBufReadExt, - AsyncRead, - AsyncWriteExt as _, - BufReader, - Stdin, - Stdout, -}; -use tokio::process::{ - Child, - ChildStdin, -}; -use tokio::sync::{ - Mutex, - broadcast, -}; - -use super::base_protocol::JsonRpcMessage; -use super::{ - Listener, - LogListener, - Transport, - TransportError, -}; - -#[derive(Debug)] -pub enum JsonRpcStdioTransport { - Client { - stdin: Arc>, - receiver: broadcast::Receiver>, - log_receiver: broadcast::Receiver, - }, - Server { - stdout: Arc>, - receiver: broadcast::Receiver>, - }, -} - -impl JsonRpcStdioTransport { - fn spawn_reader( - reader: R, - tx: broadcast::Sender>, - ) { - tokio::spawn(async move { - let mut buffer = Vec::::new(); - let mut buf_reader = BufReader::new(reader); - loop { - buffer.clear(); - // Messages are delimited by newlines and assumed to contain no embedded newlines - // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio - match buf_reader.read_until(b'\n', &mut buffer).await { - Ok(0) => continue, - Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { - Ok(msg) => { - let _ = tx.send(Ok(msg)); - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - } - } - }); - } - - pub fn client(child_process: Child) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - let Some(stdout) = child_process.stdout else { - return Err(TransportError::Custom("No stdout found on child process".to_owned())); - }; - let Some(stdin) = child_process.stdin else { - return Err(TransportError::Custom("No stdin found on child process".to_owned())); - }; - let Some(stderr) = child_process.stderr else { - return Err(TransportError::Custom("No stderr found on child process".to_owned())); - }; - let (log_tx, log_receiver) = broadcast::channel::(100); - tokio::task::spawn(async move { - let stderr = tokio::io::BufReader::new(stderr); - let mut lines = stderr.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let _ = log_tx.send(line); - } - }); - let stdin = Arc::new(Mutex::new(stdin)); - Self::spawn_reader(stdout, tx); - Ok(JsonRpcStdioTransport::Client { - stdin, - receiver, - log_receiver, - }) - } - - pub fn server(stdin: Stdin, stdout: Stdout) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - Self::spawn_reader(stdin, tx); - let stdout = Arc::new(Mutex::new(stdout)); - Ok(JsonRpcStdioTransport::Server { stdout, receiver }) - } -} - -#[async_trait::async_trait] -impl Transport for JsonRpcStdioTransport { - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdin = stdin.lock().await; - stdin - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - stdin - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - Ok(()) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdout = stdout.lock().await; - stdout - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - stdout - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - Ok(()) - }, - } - } - - fn get_listener(&self) -> impl Listener { - match self { - JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { - StdioListener { - receiver: receiver.resubscribe(), - } - }, - } - } - - async fn shutdown(&self) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut stdin = stdin.lock().await; - Ok(stdin.shutdown().await?) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut stdout = stdout.lock().await; - Ok(stdout.shutdown().await?) - }, - } - } - - fn get_log_listener(&self) -> impl LogListener { - match self { - JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { - receiver: log_receiver.resubscribe(), - }, - JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), - } - } -} - -pub struct StdioListener { - pub receiver: broadcast::Receiver>, -} - -#[async_trait::async_trait] -impl Listener for StdioListener { - async fn recv(&mut self) -> Result { - self.receiver.recv().await? - } -} - -pub struct StdioLogListener { - pub receiver: broadcast::Receiver, -} - -#[async_trait::async_trait] -impl LogListener for StdioLogListener { - async fn recv(&mut self) -> Result { - Ok(self.receiver.recv().await?) - } -} - -#[cfg(test)] -mod tests { - use std::process::Stdio; - - use serde_json::{ - Value, - json, - }; - use tokio::process::Command; - - use crate::{ - JsonRpcMessage, - JsonRpcStdioTransport, - Listener, - Transport, - }; - - // Helpers for testing - fn create_test_message() -> JsonRpcMessage { - serde_json::from_value(json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test_method", - "params": { - "test_param": "test_value" - } - })) - .unwrap() - } - - #[tokio::test] - async fn test_client_transport() { - let mut cmd = Command::new("cat"); - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); - - // Inject our mock transport instead - let child = cmd.spawn().expect("Failed to spawn command"); - let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); - - let message = create_test_message(); - let result = transport.send(&message).await; - assert!(result.is_ok(), "Failed to send message: {:?}", result); - - let echo = transport - .get_listener() - .recv() - .await - .expect("Failed to receive message"); - let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); - let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); - assert!(are_json_values_equal(&echo_value, &message_value)); - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } -} diff --git a/crates/mcp_client/src/transport/websocket.rs b/crates/mcp_client/src/transport/websocket.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/mcp_client/test_mcp_server/test_server.rs b/crates/mcp_client/test_mcp_server/test_server.rs deleted file mode 100644 index 486048bad2..0000000000 --- a/crates/mcp_client/test_mcp_server/test_server.rs +++ /dev/null @@ -1,354 +0,0 @@ -//! This is a bin used solely for testing the client -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::{ - AtomicU8, - Ordering, -}; - -use mcp_client::server::{ - self, - PreServerRequestHandler, - Response, - ServerError, - ServerRequestHandler, -}; -use mcp_client::transport::{ - JsonRpcRequest, - JsonRpcResponse, - JsonRpcStdioTransport, -}; -use tokio::sync::Mutex; - -#[derive(Default)] -struct Handler { - pending_request: Option Option + Send + Sync>>, - #[allow(clippy::type_complexity)] - send_request: Option) -> Result<(), ServerError> + Send + Sync>>, - storage: Mutex>, - tool_spec: Mutex>, - tool_spec_key_list: Mutex>, - prompts: Mutex>, - prompt_key_list: Mutex>, - prompt_list_call_no: AtomicU8, -} - -impl PreServerRequestHandler for Handler { - fn register_pending_request_callback( - &mut self, - cb: impl Fn(u64) -> Option + Send + Sync + 'static, - ) { - self.pending_request = Some(Box::new(cb)); - } - - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ) { - self.send_request = Some(Box::new(cb)); - } -} - -#[async_trait::async_trait] -impl ServerRequestHandler for Handler { - async fn handle_initialize(&self, params: Option) -> Result { - let mut storage = self.storage.lock().await; - if let Some(params) = params { - storage.insert("client_cap".to_owned(), params); - } - let capabilities = serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "TestServer", - "version": "1.0.0" - } - }); - Ok(Some(capabilities)) - } - - async fn handle_incoming(&self, method: &str, params: Option) -> Result { - match method { - "notifications/initialized" => { - { - let mut storage = self.storage.lock().await; - storage.insert( - "init_ack_sent".to_owned(), - serde_json::Value::from_str("true").expect("Failed to convert string to value"), - ); - } - Ok(None) - }, - "verify_init_params_sent" => { - let client_capabilities = { - let storage = self.storage.lock().await; - storage.get("client_cap").cloned() - }; - Ok(client_capabilities) - }, - "verify_init_ack_sent" => { - let result = { - let storage = self.storage.lock().await; - storage.get("init_ack_sent").cloned() - }; - Ok(result) - }, - "store_mock_tool_spec" => { - let Some(params) = params else { - eprintln!("Params missing from store mock tool spec"); - return Ok(None); - }; - // expecting a mock_specs: { key: String, value: serde_json::Value }[]; - let Ok(mock_specs) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_tool_specs = self.tool_spec.lock().await; - let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_tool_spec_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "tools/list" => { - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let self_tool_spec = self.tool_spec.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_tool_spec_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_tool_spec_key_list.get(i + 1).cloned(), - self_tool_spec.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let first_key = self - .tool_spec_key_list - .lock() - .await - .first() - .expect("First key missing from tool specs") - .clone(); - let first_value = self - .tool_spec - .lock() - .await - .get(&first_key) - .expect("First value missing from tool specs") - .clone(); - let second_key = self - .tool_spec_key_list - .lock() - .await - .get(1) - .expect("Second key missing from tool specs") - .clone(); - return Ok(Some(serde_json::json!({ - "tools": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_env_vars" => { - let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }); - Ok(Some(serde_json::json!(kv))) - }, - // This is a test path relevant only to sampling - "trigger_server_request" => { - let Some(ref send_request) = self.send_request else { - return Err(ServerError::MissingMethod); - }; - let params = Some(serde_json::json!({ - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - })); - send_request("sampling/createMessage", params)?; - Ok(None) - }, - "store_mock_prompts" => { - let Some(params) = params else { - eprintln!("Params missing from store mock prompts"); - return Ok(None); - }; - // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; - let Ok(mock_prompts) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_prompts = self.prompts.lock().await; - let mut self_prompt_key_list = self.prompt_key_list.lock().await; - let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_prompt_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "prompts/list" => { - self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_prompt_key_list = self.prompt_key_list.lock().await; - let self_prompts = self.prompts.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_prompt_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_prompt_key_list.get(i + 1).cloned(), - self_prompts.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let first_key = self - .prompt_key_list - .lock() - .await - .first() - .expect("First key missing from prompts") - .clone(); - let first_value = self - .prompts - .lock() - .await - .get(&first_key) - .expect("First value missing from prompts") - .clone(); - let second_key = self - .prompt_key_list - .lock() - .await - .get(1) - .expect("Second key missing from prompts") - .clone(); - return Ok(Some(serde_json::json!({ - "prompts": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_prompt_list_call_no" => Ok(Some( - serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) - .expect("Failed to convert list call no to u8"), - )), - _ => Err(ServerError::MissingMethod), - } - } - - // This is a test path relevant only to sampling - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { - let JsonRpcResponse { id, .. } = resp; - let _pending = self.pending_request.as_ref().and_then(|f| f(id)); - Ok(()) - } - - async fn handle_shutdown(&self) -> Result<(), ServerError> { - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let handler = Handler::default(); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - let test_server = - server::Server::::new(handler, stdin, stdout).expect("Failed to create server"); - let _ = test_server.init().expect("Test server failed to init").await; -} diff --git a/crates/q_cli/Cargo.toml b/crates/q_cli/Cargo.toml index 4460fa8ce0..9401547269 100644 --- a/crates/q_cli/Cargo.toml +++ b/crates/q_cli/Cargo.toml @@ -56,7 +56,6 @@ glob.workspace = true globset.workspace = true indicatif.workspace = true indoc.workspace = true -mcp_client.workspace = true mimalloc.workspace = true owo-colors = "4.2.0" parking_lot.workspace = true From 08863759c548e35fe83c59ef2ace17d618b28417 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Mon, 12 May 2025 13:53:41 -0700 Subject: [PATCH 2/4] fix tests --- crates/chat-cli/src/auth/builder_id.rs | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index beb99c768c..331c21baf8 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -66,15 +66,6 @@ pub enum OAuthFlow { Pkce, } -impl std::fmt::Display for OAuthFlow { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self { - OAuthFlow::DeviceCode => write!(f, "DeviceCode"), - OAuthFlow::Pkce => write!(f, "PKCE"), - } - } -} - /// Indicates if an expiration time has passed, there is a small 1 min window that is removed /// so the token will not expire in transit fn is_expired(expiration_time: &OffsetDateTime) -> bool { @@ -565,20 +556,10 @@ mod tests { const US_EAST_1: Region = Region::from_static("us-east-1"); const US_WEST_2: Region = Region::from_static("us-west-2"); - macro_rules! test_ser_deser { - ($ty:ident, $variant:expr, $text:expr) => { - let quoted = format!("\"{}\"", $text); - assert_eq!(quoted, serde_json::to_string(&$variant).unwrap()); - assert_eq!($variant, serde_json::from_str("ed).unwrap()); - - assert_eq!($text, format!("{}", $variant)); - }; - } - #[test] - fn test_oauth_flow_ser_deser() { - test_ser_deser!(OAuthFlow, OAuthFlow::DeviceCode, "DeviceCode"); - test_ser_deser!(OAuthFlow, OAuthFlow::Pkce, "PKCE"); + fn test_oauth_flow_deser() { + assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"PKCE\"").unwrap()); + assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"Pkce\"").unwrap()); } #[tokio::test] From aa519815e714d9ec113cba0ff9350ac9f7ed795c Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Mon, 12 May 2025 15:07:37 -0700 Subject: [PATCH 3/4] store token to db --- crates/chat-cli/src/mcp_client/client.rs | 1 + crates/fig_auth/src/builder_id.rs | 2 +- crates/fig_auth/src/secret_store/macos.rs | 4 ++++ crates/q_cli/src/cli/mod.rs | 16 ++++++++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 9dc9f6bc98..d840bcd9ed 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -541,6 +541,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] + #[ignore] async fn test_client_stdio() { std::process::Command::new("cargo") .args(["build", "--bin", TEST_SERVER_NAME]) diff --git a/crates/fig_auth/src/builder_id.rs b/crates/fig_auth/src/builder_id.rs index e99a13f3dd..41e6d3f24a 100644 --- a/crates/fig_auth/src/builder_id.rs +++ b/crates/fig_auth/src/builder_id.rs @@ -138,7 +138,7 @@ impl DeviceRegistration { } /// Loads the OIDC registered client from the secret store, deleting it if it is expired. - async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result> { + pub async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result> { let device_registration = secret_store.get(Self::SECRET_KEY).await?; if let Some(device_registration) = device_registration { diff --git a/crates/fig_auth/src/secret_store/macos.rs b/crates/fig_auth/src/secret_store/macos.rs index b463a115b9..52e78c71a8 100644 --- a/crates/fig_auth/src/secret_store/macos.rs +++ b/crates/fig_auth/src/secret_store/macos.rs @@ -1,3 +1,5 @@ +use fig_settings::sqlite::database; + use super::Secret; use crate::{ Error, @@ -29,6 +31,8 @@ impl SecretStoreImpl { if !output.status.success() { let stderr = std::str::from_utf8(&output.stderr)?; return Err(Error::Security(stderr.into())); + } else { + database()?.set_auth_value(key, password)?; } Ok(()) diff --git a/crates/q_cli/src/cli/mod.rs b/crates/q_cli/src/cli/mod.rs index ddd19a64a5..329fd3373c 100644 --- a/crates/q_cli/src/cli/mod.rs +++ b/crates/q_cli/src/cli/mod.rs @@ -45,13 +45,16 @@ use eyre::{ bail, }; use feed::Feed; +use fig_auth::builder_id::BuilderIdToken; use fig_auth::is_logged_in; +use fig_auth::secret_store::SecretStore; use fig_ipc::local::open_ui_element; use fig_log::{ LogArgs, initialize_logging, }; use fig_proto::local::UiElement; +use fig_settings::sqlite::database; use fig_util::directories::home_local_bin; use fig_util::{ CHAT_BINARY_NAME, @@ -346,7 +349,20 @@ impl Cli { } async fn execute_chat(args: Option>) -> Result { + let secret_store = SecretStore::new().await.ok(); + if let Some(secret_store) = secret_store { + if let Ok(database) = database() { + if let Ok(token) = BuilderIdToken::load(&secret_store, false).await { + if let Ok(token) = serde_json::to_string(&token) { + database.set_auth_value("codewhisperer:odic:token", token).ok(); + } + } + } + } + let mut cmd = tokio::process::Command::new(home_local_bin()?.join(CHAT_BINARY_NAME)); + cmd.arg("chat"); + if let Some(args) = args { cmd.args(args); } From 8c180b0adfbd0e6afcf3e65c5c47278d1d1e4829 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Tue, 13 May 2025 11:52:00 -0700 Subject: [PATCH 4/4] small fixes --- Cargo.lock | 1 - crates/chat-cli/Cargo.toml | 1 - .../src/cli/chat/conversation_state.rs | 4 +- crates/chat-cli/src/cli/chat/mod.rs | 105 +++++++++--------- crates/chat-cli/src/cli/user.rs | 5 +- crates/fig_auth/src/builder_id.rs | 2 +- 6 files changed, 60 insertions(+), 58 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a11c7c2259..2e6dc2da28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1521,7 +1521,6 @@ dependencies = [ "bytes", "camino", "cfg-if", - "chrono", "clap", "clap_complete", "clap_complete_fig", diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index 10ed45d145..2b7ae4d42b 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -45,7 +45,6 @@ bstr = "1.12.0" bytes = "1.10.1" camino = { version = "1.1.3", features = ["serde1"] } cfg-if = "1.0.0" -chrono = { version = "0.4.41", default-features = false, features = ["std"] } clap = { version = "4.5.32", features = [ "deprecated", "derive", diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index bb269c286c..c1424d48ae 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -213,9 +213,7 @@ impl ConversationState { warn!("input must not be empty when adding new messages"); "Empty prompt".to_string() } else { - let now = chrono::Utc::now(); - let formatted_time = now.format("%Y-%m-%d %H:%M:%S").to_string(); - format!("{}\n\n\n{}\n", input, formatted_time) + input }; let msg = UserMessage::new_prompt(input); diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 68a62e088c..4c2c29c60a 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -72,62 +72,13 @@ use hooks::{ Hook, HookTrigger, }; +use input_source::InputSource; use message::{ AssistantMessage, AssistantToolUse, ToolUseResult, ToolUseResultBlock, }; -use rand::distr::{ - Alphanumeric, - SampleString, -}; -use tokio::signal::ctrl_c; -use util::shared_writer::SharedWriter; -use util::ui::draw_box; - -use crate::api_client::StreamingClient; -use crate::api_client::clients::SendMessageOutput; -use crate::api_client::model::{ - ChatResponseStream, - Tool as FigTool, - ToolResultStatus, -}; -use crate::database::Database; -use crate::database::settings::Setting; -use crate::platform::Context; -use crate::telemetry::TelemetryThread; -use crate::telemetry::core::ToolUseEventBuilder; - -/// Help text for the compact command -fn compact_help_text() -> String { - color_print::cformat!( - r#" -Conversation Compaction - -The /compact command summarizes the conversation history to free up context space -while preserving essential information. This is useful for long-running conversations -that may eventually reach memory constraints. - -Usage - /compact Summarize the conversation and clear history - /compact [prompt] Provide custom guidance for summarization - -When to use -• When you see the memory constraint warning message -• When a conversation has been running for a long time -• Before starting a new topic within the same session -• After completing complex tool operations - -How it works -• Creates an AI-generated summary of your conversation -• Retains key information, code, and tool executions in the summary -• Clears the conversation history to free up space -• The assistant will reference the summary context in future responses -"# - ) -} -use input_source::InputSource; use parse::{ ParseState, interpret_markdown, @@ -136,6 +87,10 @@ use parser::{ RecvErrorKind, ResponseParser, }; +use rand::distr::{ + Alphanumeric, + SampleString, +}; use regex::Regex; use serde_json::Map; use spinners::{ @@ -147,6 +102,7 @@ use token_counter::{ TokenCount, TokenCounter, }; +use tokio::signal::ctrl_c; use tool_manager::{ GetPromptError, McpServerConfig, @@ -171,6 +127,8 @@ use tracing::{ }; use unicode_width::UnicodeWidthStr; use util::images::RichImageBlock; +use util::shared_writer::SharedWriter; +use util::ui::draw_box; use util::{ animate_output, drop_matched_context_files, @@ -181,10 +139,52 @@ use uuid::Uuid; use winnow::Partial; use winnow::stream::Offset; +use crate::api_client::StreamingClient; +use crate::api_client::clients::SendMessageOutput; +use crate::api_client::model::{ + ChatResponseStream, + Tool as FigTool, + ToolResultStatus, +}; +use crate::database::Database; +use crate::database::settings::Setting; use crate::mcp_client::{ Prompt, PromptGetResult, }; +use crate::platform::Context; +use crate::telemetry::TelemetryThread; +use crate::telemetry::core::ToolUseEventBuilder; +use crate::util::CHAT_BINARY_NAME; + +/// Help text for the compact command +fn compact_help_text() -> String { + color_print::cformat!( + r#" +Conversation Compaction + +The /compact command summarizes the conversation history to free up context space +while preserving essential information. This is useful for long-running conversations +that may eventually reach memory constraints. + +Usage + /compact Summarize the conversation and clear history + /compact [prompt] Provide custom guidance for summarization + +When to use +• When you see the memory constraint warning message +• When a conversation has been running for a long time +• Before starting a new topic within the same session +• After completing complex tool operations + +How it works +• Creates an AI-generated summary of your conversation +• Retains key information, code, and tool executions in the summary +• Clears the conversation history to free up space +• The assistant will reference the summary context in future responses +"# + ) +} const WELCOME_TEXT: &str = color_print::cstr! {" ⢠⣶⣶⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣶⣿⣿⣿⣶⣦⡀⠀ @@ -311,7 +311,10 @@ pub async fn chat( trust_tools: Option>, ) -> Result { if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in(database).await { - bail!("You are not logged in, please log in with {}", "q login".bold()); + bail!( + "You are not logged in, please log in with {}", + format!("{CHAT_BINARY_NAME} login").bold() + ); } region_check("chat")?; diff --git a/crates/chat-cli/src/cli/user.rs b/crates/chat-cli/src/cli/user.rs index a661491708..07cdacf10b 100644 --- a/crates/chat-cli/src/cli/user.rs +++ b/crates/chat-cli/src/cli/user.rs @@ -187,7 +187,10 @@ impl UserSubcommand { }, Self::Profile => { if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in(database).await { - bail!("You are not logged in, please log in with {}", "q login".bold()); + bail!( + "You are not logged in, please log in with {}", + format!("{CHAT_BINARY_NAME} login").bold() + ); } if let Ok(Some(token)) = BuilderIdToken::load(database).await { diff --git a/crates/fig_auth/src/builder_id.rs b/crates/fig_auth/src/builder_id.rs index 41e6d3f24a..e99a13f3dd 100644 --- a/crates/fig_auth/src/builder_id.rs +++ b/crates/fig_auth/src/builder_id.rs @@ -138,7 +138,7 @@ impl DeviceRegistration { } /// Loads the OIDC registered client from the secret store, deleting it if it is expired. - pub async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result> { + async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result> { let device_registration = secret_store.get(Self::SECRET_KEY).await?; if let Some(device_registration) = device_registration {