Skip to content

Commit 806a1a6

Browse files
committed
test: add long-running tool resume tests
1 parent d4ea2e9 commit 806a1a6

1 file changed

Lines changed: 175 additions & 1 deletion

File tree

crates/rmcp/tests/test_streamable_http_priming.rs

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ async fn test_resume_after_request_wise_channel_completed() -> anyhow::Result<()
278278
.header("mcp-session-id", session_id.to_string())
279279
.header("Mcp-Protocol-Version", "2025-06-18")
280280
.header("last-event-id", "0/0")
281-
.timeout(std::time::Duration::from_secs(5))
281+
.timeout(std::time::Duration::from_secs(10))
282282
.send()
283283
.await?;
284284
assert_eq!(resume_response.status(), 200);
@@ -396,3 +396,177 @@ async fn test_priming_on_stream_close() -> anyhow::Result<()> {
396396

397397
Ok(())
398398
}
399+
400+
#[cfg(test)]
401+
mod test_priming_resume {
402+
use rmcp::{
403+
ServerHandler,
404+
handler::server::router::tool::ToolRouter,
405+
model::{CallToolResult, Content, ErrorData as McpError, ServerCapabilities, ServerInfo},
406+
service::{RoleClient, RunningService, serve_client},
407+
tool, tool_handler, tool_router,
408+
transport::{
409+
StreamableHttpClientTransport,
410+
streamable_http_client::StreamableHttpClientTransportConfig,
411+
streamable_http_server::{
412+
StreamableHttpServerConfig, StreamableHttpService,
413+
session::local::LocalSessionManager,
414+
},
415+
},
416+
};
417+
use tokio_util::sync::CancellationToken;
418+
419+
const CALL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
420+
const REQWEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);
421+
const LONG_TASK_DURATION: std::time::Duration = std::time::Duration::from_secs(5);
422+
423+
async fn setup_server() -> anyhow::Result<(
424+
std::net::SocketAddr,
425+
CancellationToken,
426+
tokio::task::JoinHandle<()>,
427+
)> {
428+
let ct = CancellationToken::new();
429+
let service: StreamableHttpService<LongRunning, LocalSessionManager> =
430+
StreamableHttpService::new(
431+
|| Ok(LongRunning::new()),
432+
Default::default(),
433+
StreamableHttpServerConfig::default()
434+
.with_sse_keep_alive(None)
435+
.with_cancellation_token(ct.child_token()),
436+
);
437+
let router = axum::Router::new().nest_service("/mcp", service);
438+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
439+
let addr = tcp_listener.local_addr()?;
440+
let server_handle = tokio::spawn({
441+
let ct = ct.clone();
442+
async move {
443+
let _ = axum::serve(tcp_listener, router)
444+
.with_graceful_shutdown(ct.cancelled_owned())
445+
.await;
446+
}
447+
});
448+
Ok((addr, ct, server_handle))
449+
}
450+
451+
async fn setup_client(
452+
addr: std::net::SocketAddr,
453+
) -> anyhow::Result<RunningService<RoleClient, ()>> {
454+
let reqwest_client = reqwest::Client::builder()
455+
.timeout(REQWEST_TIMEOUT)
456+
.connection_verbose(true)
457+
.build()?;
458+
let transport = StreamableHttpClientTransport::with_client(
459+
reqwest_client,
460+
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
461+
);
462+
Ok(serve_client((), transport).await?)
463+
}
464+
465+
fn assert_tool_success(label: &str, result: &CallToolResult) {
466+
assert!(
467+
result.is_error != Some(true),
468+
"{label} call_tool expected success, got: {result:?}"
469+
);
470+
assert_eq!(
471+
result.content.len(),
472+
1,
473+
"{label} call_tool expected 1 content item"
474+
);
475+
assert_eq!(
476+
result.content[0].as_text().unwrap().text,
477+
"Long task completed"
478+
);
479+
}
480+
481+
#[derive(Debug, Clone, Default)]
482+
pub struct LongRunning {
483+
tool_router: ToolRouter<Self>,
484+
}
485+
486+
impl LongRunning {
487+
pub fn new() -> Self {
488+
Self {
489+
tool_router: Self::tool_router(),
490+
}
491+
}
492+
}
493+
494+
#[tool_router]
495+
impl LongRunning {
496+
#[tool(description = "Run a long running tool call")]
497+
async fn long_task(&self) -> Result<CallToolResult, McpError> {
498+
tokio::time::sleep(LONG_TASK_DURATION).await;
499+
Ok(CallToolResult::success(vec![Content::text(
500+
"Long task completed",
501+
)]))
502+
}
503+
}
504+
505+
#[tool_handler(router = self.tool_router)]
506+
impl ServerHandler for LongRunning {
507+
fn get_info(&self) -> ServerInfo {
508+
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
509+
}
510+
}
511+
512+
#[tokio::test]
513+
#[ignore = "timing-sensitive: requires reqwest timeout shorter than tool duration"]
514+
async fn test_long_running_tool_single_via_mcp_client() -> anyhow::Result<()> {
515+
let (addr, ct, server_handle) = setup_server().await?;
516+
let client = setup_client(addr).await?;
517+
518+
let result = tokio::time::timeout(
519+
CALL_TIMEOUT,
520+
client.call_tool(rmcp::model::CallToolRequestParams::new("long_task")),
521+
)
522+
.await;
523+
524+
let _ = client.cancel().await;
525+
ct.cancel();
526+
server_handle.await?;
527+
528+
let result = result.expect("call_tool timed out - client may be stuck in endless loop")?;
529+
assert_tool_success("single", &result);
530+
531+
Ok(())
532+
}
533+
534+
#[tokio::test]
535+
#[ignore = "timing-sensitive: requires reqwest timeout shorter than tool duration"]
536+
async fn test_long_running_tool_parallel_via_mcp_client() -> anyhow::Result<()> {
537+
let (addr, ct, server_handle) = setup_server().await?;
538+
let client = setup_client(addr).await?;
539+
540+
let parallel_handle = tokio::spawn({
541+
let client = client.clone();
542+
async move {
543+
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
544+
client
545+
.call_tool(rmcp::model::CallToolRequestParams::new("long_task"))
546+
.await
547+
}
548+
});
549+
550+
let (main_result, parallel_result) = tokio::join!(
551+
tokio::time::timeout(
552+
CALL_TIMEOUT,
553+
client.call_tool(rmcp::model::CallToolRequestParams::new("long_task")),
554+
),
555+
tokio::time::timeout(CALL_TIMEOUT, parallel_handle),
556+
);
557+
558+
let _ = client.cancel().await;
559+
ct.cancel();
560+
server_handle.await?;
561+
562+
let result =
563+
main_result.expect("main call_tool timed out - client may be stuck in endless loop")?;
564+
let parallel = parallel_result
565+
.expect("parallel call_tool timed out - client may be stuck in endless loop")?;
566+
567+
assert_tool_success("parallel", &parallel?);
568+
assert_tool_success("main", &result);
569+
570+
Ok(())
571+
}
572+
}

0 commit comments

Comments
 (0)