Skip to content

Commit 8648603

Browse files
committed
test: add long-running tool resume tests
1 parent d4ea2e9 commit 8648603

1 file changed

Lines changed: 172 additions & 0 deletions

File tree

crates/rmcp/tests/test_streamable_http_priming.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,175 @@ async fn test_priming_on_stream_close() -> anyhow::Result<()> {
396396

397397
Ok(())
398398
}
399+
400+
#[cfg(test)]
401+
mod test_priming_resume {
402+
use rmcp::{
403+
ServerHandler,
404+
handler::server::router::tool::ToolRouter,
405+
model::{CallToolResult, Content, ErrorData as McpError, ServerCapabilities, ServerInfo},
406+
service::{RoleClient, RunningService, serve_client},
407+
tool, tool_handler, tool_router,
408+
transport::{
409+
StreamableHttpClientTransport,
410+
streamable_http_client::StreamableHttpClientTransportConfig,
411+
streamable_http_server::{
412+
StreamableHttpServerConfig, StreamableHttpService,
413+
session::local::LocalSessionManager,
414+
},
415+
},
416+
};
417+
use tokio_util::sync::CancellationToken;
418+
419+
const CALL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
420+
const REQWEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);
421+
const LONG_TASK_DURATION: std::time::Duration = std::time::Duration::from_secs(5);
422+
423+
async fn setup_server() -> anyhow::Result<(
424+
std::net::SocketAddr,
425+
CancellationToken,
426+
tokio::task::JoinHandle<()>,
427+
)> {
428+
let ct = CancellationToken::new();
429+
let service: StreamableHttpService<LongRunning, LocalSessionManager> =
430+
StreamableHttpService::new(
431+
|| Ok(LongRunning::new()),
432+
Default::default(),
433+
StreamableHttpServerConfig::default()
434+
.with_sse_keep_alive(None)
435+
.with_cancellation_token(ct.child_token()),
436+
);
437+
let router = axum::Router::new().nest_service("/mcp", service);
438+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
439+
let addr = tcp_listener.local_addr()?;
440+
let server_handle = tokio::spawn({
441+
let ct = ct.clone();
442+
async move {
443+
let _ = axum::serve(tcp_listener, router)
444+
.with_graceful_shutdown(ct.cancelled_owned())
445+
.await;
446+
}
447+
});
448+
Ok((addr, ct, server_handle))
449+
}
450+
451+
async fn setup_client(
452+
addr: std::net::SocketAddr,
453+
) -> anyhow::Result<RunningService<RoleClient, ()>> {
454+
let reqwest_client = reqwest::Client::builder()
455+
.timeout(REQWEST_TIMEOUT)
456+
.connection_verbose(true)
457+
.build()?;
458+
let transport = StreamableHttpClientTransport::with_client(
459+
reqwest_client,
460+
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
461+
);
462+
Ok(serve_client((), transport).await?)
463+
}
464+
465+
fn assert_tool_success(label: &str, result: &CallToolResult) {
466+
assert!(
467+
result.is_error != Some(true),
468+
"{label} call_tool expected success, got: {result:?}"
469+
);
470+
assert_eq!(
471+
result.content.len(),
472+
1,
473+
"{label} call_tool expected 1 content item"
474+
);
475+
assert_eq!(
476+
result.content[0].as_text().unwrap().text,
477+
"Long task completed"
478+
);
479+
}
480+
481+
#[derive(Debug, Clone, Default)]
482+
pub struct LongRunning {
483+
tool_router: ToolRouter<Self>,
484+
}
485+
486+
impl LongRunning {
487+
pub fn new() -> Self {
488+
Self {
489+
tool_router: Self::tool_router(),
490+
}
491+
}
492+
}
493+
494+
#[tool_router]
495+
impl LongRunning {
496+
#[tool(description = "Run a long running tool call")]
497+
async fn long_task(&self) -> Result<CallToolResult, McpError> {
498+
tokio::time::sleep(LONG_TASK_DURATION).await;
499+
Ok(CallToolResult::success(vec![Content::text(
500+
"Long task completed",
501+
)]))
502+
}
503+
}
504+
505+
#[tool_handler(router = self.tool_router)]
506+
impl ServerHandler for LongRunning {
507+
fn get_info(&self) -> ServerInfo {
508+
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
509+
}
510+
}
511+
512+
#[tokio::test]
513+
async fn test_long_running_tool_single_via_mcp_client() -> anyhow::Result<()> {
514+
let (addr, ct, server_handle) = setup_server().await?;
515+
let client = setup_client(addr).await?;
516+
517+
let result = tokio::time::timeout(
518+
CALL_TIMEOUT,
519+
client.call_tool(rmcp::model::CallToolRequestParams::new("long_task")),
520+
)
521+
.await;
522+
523+
let _ = client.cancel().await;
524+
ct.cancel();
525+
server_handle.await?;
526+
527+
let result = result.expect("call_tool timed out - client may be stuck in endless loop")?;
528+
assert_tool_success("single", &result);
529+
530+
Ok(())
531+
}
532+
533+
#[tokio::test]
534+
async fn test_long_running_tool_parallel_via_mcp_client() -> anyhow::Result<()> {
535+
let (addr, ct, server_handle) = setup_server().await?;
536+
let client = setup_client(addr).await?;
537+
538+
let parallel_handle = tokio::spawn({
539+
let client = client.clone();
540+
async move {
541+
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
542+
client
543+
.call_tool(rmcp::model::CallToolRequestParams::new("long_task"))
544+
.await
545+
}
546+
});
547+
548+
let (main_result, parallel_result) = tokio::join!(
549+
tokio::time::timeout(
550+
CALL_TIMEOUT,
551+
client.call_tool(rmcp::model::CallToolRequestParams::new("long_task")),
552+
),
553+
tokio::time::timeout(CALL_TIMEOUT, parallel_handle),
554+
);
555+
556+
let _ = client.cancel().await;
557+
ct.cancel();
558+
server_handle.await?;
559+
560+
let result =
561+
main_result.expect("main call_tool timed out - client may be stuck in endless loop")?;
562+
let parallel = parallel_result
563+
.expect("parallel call_tool timed out - client may be stuck in endless loop")?;
564+
565+
assert_tool_success("parallel", &parallel?);
566+
assert_tool_success("main", &result);
567+
568+
Ok(())
569+
}
570+
}

0 commit comments

Comments
 (0)