forked from paiml/rust-mcp-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrefactored_server_example.rs
More file actions
137 lines (116 loc) · 4.11 KB
/
refactored_server_example.rs
File metadata and controls
137 lines (116 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Example demonstrating the refactored server architecture with protocol/transport split.
//!
//! This example shows how to use the new transport-independent ServerCore
//! with different transport adapters.
use async_trait::async_trait;
use pmcp::server::adapters::{StdioAdapter, TransportAdapter};
use pmcp::server::builder::ServerCoreBuilder;
use pmcp::server::cancellation::RequestHandlerExtra;
use pmcp::server::core::ProtocolHandler;
use pmcp::server::ToolHandler;
use pmcp::Result;
use serde_json::{json, Value};
use std::sync::Arc;
/// Example tool that echoes input back
struct EchoTool;
#[async_trait]
impl ToolHandler for EchoTool {
async fn handle(&self, args: Value, _extra: RequestHandlerExtra) -> Result<Value> {
Ok(json!({
"echo": args,
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
}
/// Example tool that performs calculations
struct CalculatorTool;
#[async_trait]
impl ToolHandler for CalculatorTool {
async fn handle(&self, args: Value, _extra: RequestHandlerExtra) -> Result<Value> {
let operation = args
.get("operation")
.and_then(|v| v.as_str())
.unwrap_or("add");
let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
let result = match operation {
"add" => a + b,
"subtract" => a - b,
"multiply" => a * b,
"divide" => {
if b != 0.0 {
a / b
} else {
return Ok(json!({"error": "Division by zero"}));
}
},
_ => return Ok(json!({"error": "Unknown operation"})),
};
Ok(json!({ "result": result }))
}
}
/// Entry point: builds a server core with the example tools and serves it over
/// the transport appropriate for the compilation target.
///
/// On native targets the server speaks MCP over STDIO. stdout carries the
/// protocol stream there, so human-readable status output goes to stderr.
/// Any stray text on stdout would corrupt the JSON-RPC messages the client parses.
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize logging
    env_logger::init();

    // Build the server core using the builder pattern
    let server_core = ServerCoreBuilder::new()
        .name("refactored-example-server")
        .version("0.1.0")
        .tool("echo", EchoTool)
        .tool("calculator", CalculatorTool)
        .build()?;

    // Convert to Arc for sharing with transport adapter
    let handler: Arc<dyn ProtocolHandler> = Arc::new(server_core);

    // Choose transport adapter based on environment
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Must NOT use stdout here: the STDIO transport owns it for protocol messages.
        eprintln!("Starting server with STDIO transport...");
        let adapter = StdioAdapter::new();
        adapter.serve(handler).await?;
    }

    #[cfg(target_arch = "wasm32")]
    {
        println!("WASM environment detected - use WASI HTTP adapter");
        // In WASM, you would use the WasiHttpAdapter
        // This would typically be integrated with the WASI HTTP world.
        // Explicitly discard the handler so wasm builds don't warn about an unused variable.
        let _ = handler;
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use pmcp::server::adapters::MockAdapter;
    use pmcp::types::{ClientRequest, Implementation, InitializeParams, Request, RequestId};

    /// Sends a single `initialize` request through the server core via the
    /// in-memory `MockAdapter`. It checks that exactly one response is produced
    /// and that the response carries the request's id.
    #[tokio::test]
    async fn test_refactored_server() {
        let core = ServerCoreBuilder::new()
            .name("test-server")
            .version("1.0.0")
            .tool("echo", EchoTool)
            .build()
            .unwrap();
        let protocol: Arc<dyn ProtocolHandler> = Arc::new(core);

        let mock = MockAdapter::new();

        let params = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: pmcp::types::ClientCapabilities::default(),
            client_info: Implementation {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let initialize = Request::Client(Box::new(ClientRequest::Initialize(params)));
        mock.add_request(RequestId::from(1i64), initialize).await;

        // Process every queued request, then inspect what the server emitted.
        mock.serve(protocol).await.unwrap();

        let responses = mock.get_responses().await;
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].id, RequestId::from(1i64));
    }
}