-
Notifications
You must be signed in to change notification settings - Fork 436
Expand file tree
/
Copy pathcustom_tool.rs
More file actions
226 lines (203 loc) · 7.11 KB
/
Copy pathcustom_tool.rs
File metadata and controls
226 lines (203 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
use std::borrow::Cow;
use std::collections::HashMap;
use std::io::Write;
use crossterm::{
queue,
style,
};
use eyre::Result;
use rmcp::model::CallToolRequestParam;
use schemars::JsonSchema;
use serde::{
Deserialize,
Serialize,
};
use tracing::warn;
use super::InvokeOutput;
use crate::cli::agent::{
Agent,
PermissionEvalResult,
};
use crate::cli::chat::CONTINUATION_LINE;
use crate::cli::chat::token_counter::TokenCounter;
use crate::mcp_client::{
RunningService,
oauth_util,
};
use crate::os::Os;
use crate::theme::StyledText;
use crate::util::MCP_SERVER_TOOL_DELIMITER;
// Transport used to talk to an MCP server. Variant names are (de)serialized in camelCase.
//
// NOTE(review): this type derives `JsonSchema`, and doc comments (`///`) are presumably
// surfaced as schema descriptions, so maintenance notes here deliberately use plain `//`
// comments and the existing `///` text is left untouched.
//
// `Default` is derived (with `#[default]` on `Stdio`) instead of hand-written; the result of
// `TransportType::default()` is unchanged.
#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub enum TransportType {
    /// Standard input/output transport (default)
    #[default]
    Stdio,
    /// HTTP transport for web-based communication
    Http,
}
// Per-server OAuth settings; field names are (de)serialized in camelCase.
// The `///` text below is kept verbatim because this type derives `JsonSchema`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct OAuthConfig {
    /// Custom redirect URI for OAuth flow (e.g., "127.0.0.1:7778")
    /// If not specified, a random available port will be assigned by the OS
    // Omitted from serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
}
// Configuration entry describing how to reach and launch one MCP server.
// Field names are (de)serialized in camelCase. The `///` text below is kept verbatim
// because this type derives `JsonSchema`; extra notes use plain `//` comments.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct CustomToolConfig {
    /// The transport type to use for communication with the MCP server
    // Falls back to `TransportType::default()` when the key is absent.
    #[serde(default)]
    pub r#type: TransportType,

    /// The URL for HTTP-based MCP server communication
    // Empty string when the key is absent.
    #[serde(default)]
    pub url: String,

    /// HTTP headers to include when communicating with HTTP-based MCP servers
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Scopes with which oauth is done
    // Populated from `get_default_scopes()` when the key is absent.
    #[serde(default = "get_default_scopes")]
    pub oauth_scopes: Vec<String>,

    /// OAuth configuration for this server
    // Omitted from serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth: Option<OAuthConfig>,

    /// The command string used to initialize the mcp server
    #[serde(default)]
    pub command: String,

    /// A list of arguments to be used to run the command with
    #[serde(default)]
    pub args: Vec<String>,

    /// A list of environment variables to run the command with
    // Omitted from serialized output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,

    /// Timeout for each mcp request in ms
    // Populated from `default_timeout()` when the key is absent.
    #[serde(default = "default_timeout")]
    pub timeout: u64,

    /// A boolean flag to denote whether or not to load this mcp server
    #[serde(default)]
    pub disabled: bool,

    /// A flag to denote whether this is a server from the legacy mcp.json
    // Never read from or written to the serialized form.
    #[serde(skip)]
    pub is_from_legacy_mcp_json: bool,
}
/// Returns the default OAuth scopes as owned strings.
///
/// The values come from `oauth_util::get_default_scopes()`; each entry is converted with
/// `to_string()`. Used as the serde default for `CustomToolConfig::oauth_scopes`.
pub fn get_default_scopes() -> Vec<String> {
    let defaults = oauth_util::get_default_scopes();
    let mut owned_scopes = Vec::new();
    for scope in defaults.iter() {
        owned_scopes.push((*scope).to_string());
    }
    owned_scopes
}
/// Returns the default per-request MCP timeout in milliseconds (120 seconds).
///
/// Used as the serde default for `CustomToolConfig::timeout`.
pub fn default_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 120;
    const MILLIS_PER_SEC: u64 = 1000;
    TIMEOUT_SECS * MILLIS_PER_SEC
}
/// A single tool exposed by an MCP (Model Context Protocol) server, together with the
/// arguments for one pending invocation.
#[derive(Debug, Clone)]
pub struct CustomTool {
    /// Tool name exactly as its MCP server knows it, i.e. without the server-name prefix
    /// that is added when the tool is presented to the model.
    pub name: String,

    /// Name of the MCP server hosting this tool. It is prepended to `name` to build the
    /// namespaced, model-facing tool name.
    pub server_name: String,

    /// Client handle used to send the tool call to the server.
    pub client: RunningService,

    /// Optional JSON object of arguments forwarded to the tool when it is invoked.
    pub params: Option<serde_json::Map<String, serde_json::Value>>,
}
impl CustomTool {
/// Returns the full tool name with server prefix in the format @server_name/tool_name
pub fn namespaced_tool_name(&self) -> String {
format!("@{}{}{}", self.server_name, MCP_SERVER_TOOL_DELIMITER, self.name)
}
pub async fn invoke(&self, _os: &Os, _updates: &mut impl Write) -> Result<InvokeOutput> {
let params = CallToolRequestParam {
name: Cow::from(self.name.clone()),
arguments: self.params.clone(),
};
let resp = self.client.call_tool(params.clone()).await?;
if resp.is_error.is_none_or(|v| !v) {
Ok(InvokeOutput {
output: super::OutputKind::Json(serde_json::json!(resp)),
})
} else {
warn!("Tool call for {} failed", self.name);
Ok(InvokeOutput {
output: super::OutputKind::Json(serde_json::json!(resp)),
})
}
}
pub fn queue_description(&self, output: &mut impl Write) -> Result<()> {
queue!(
output,
style::Print("Running "),
StyledText::success_fg(),
style::Print(&self.name),
StyledText::reset(),
)?;
if let Some(params) = &self.params {
let params = match serde_json::to_string_pretty(params) {
Ok(params) => params
.split("\n")
.map(|p| format!("{CONTINUATION_LINE} {p}"))
.collect::<Vec<_>>()
.join("\n"),
_ => format!("{:?}", params),
};
queue!(
output,
style::Print(" with the param:\n"),
style::Print(params),
style::Print("\n"),
StyledText::reset(),
)?;
} else {
queue!(output, style::Print("\n"))?;
}
Ok(())
}
pub async fn validate(&mut self, _os: &Os) -> Result<()> {
Ok(())
}
pub fn get_input_token_size(&self) -> usize {
TokenCounter::count_tokens(
&serde_json::to_string(self.params.as_ref().unwrap_or(&serde_json::Map::new())).unwrap_or_default(),
)
}
pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult {
use crate::util::tool_permission_checker::is_tool_in_allowlist;
if is_tool_in_allowlist(&agent.allowed_tools, &self.name, Some(&self.server_name)) {
PermissionEvalResult::Allow
} else {
PermissionEvalResult::Ask
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default timeout is 120 seconds expressed in milliseconds.
    #[test]
    fn default_timeout_is_two_minutes_in_ms() {
        let expected_ms: u64 = 120 * 1000;
        assert_eq!(default_timeout(), expected_ms);
    }

    /// At least one default OAuth scope is provided.
    #[test]
    fn get_default_scopes_returns_non_empty() {
        assert!(
            !get_default_scopes().is_empty(),
            "default OAuth scopes must not be empty"
        );
    }

    /// No default OAuth scope is the empty string.
    #[test]
    fn get_default_scopes_are_strings() {
        let scopes = get_default_scopes();
        assert!(
            scopes.iter().all(|scope| !scope.is_empty()),
            "each scope must be a non-empty string"
        );
    }

    /// `TransportType::default()` is the stdio transport.
    #[test]
    fn transport_type_default_is_stdio() {
        assert_eq!(TransportType::default(), TransportType::Stdio);
    }
}