@@ -396,3 +396,175 @@ async fn test_priming_on_stream_close() -> anyhow::Result<()> {
396396
397397 Ok ( ( ) )
398398}
399+
400+ #[ cfg( test) ]
401+ mod test_priming_resume {
402+ use rmcp:: {
403+ ServerHandler ,
404+ handler:: server:: router:: tool:: ToolRouter ,
405+ model:: { CallToolResult , Content , ErrorData as McpError , ServerCapabilities , ServerInfo } ,
406+ service:: { RoleClient , RunningService , serve_client} ,
407+ tool, tool_handler, tool_router,
408+ transport:: {
409+ StreamableHttpClientTransport ,
410+ streamable_http_client:: StreamableHttpClientTransportConfig ,
411+ streamable_http_server:: {
412+ StreamableHttpServerConfig , StreamableHttpService ,
413+ session:: local:: LocalSessionManager ,
414+ } ,
415+ } ,
416+ } ;
417+ use tokio_util:: sync:: CancellationToken ;
418+
419+ const CALL_TIMEOUT : std:: time:: Duration = std:: time:: Duration :: from_secs ( 15 ) ;
420+ const REQWEST_TIMEOUT : std:: time:: Duration = std:: time:: Duration :: from_secs ( 3 ) ;
421+ const LONG_TASK_DURATION : std:: time:: Duration = std:: time:: Duration :: from_secs ( 5 ) ;
422+
423+ async fn setup_server ( ) -> anyhow:: Result < (
424+ std:: net:: SocketAddr ,
425+ CancellationToken ,
426+ tokio:: task:: JoinHandle < ( ) > ,
427+ ) > {
428+ let ct = CancellationToken :: new ( ) ;
429+ let service: StreamableHttpService < LongRunning , LocalSessionManager > =
430+ StreamableHttpService :: new (
431+ || Ok ( LongRunning :: new ( ) ) ,
432+ Default :: default ( ) ,
433+ StreamableHttpServerConfig :: default ( )
434+ . with_sse_keep_alive ( None )
435+ . with_cancellation_token ( ct. child_token ( ) ) ,
436+ ) ;
437+ let router = axum:: Router :: new ( ) . nest_service ( "/mcp" , service) ;
438+ let tcp_listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
439+ let addr = tcp_listener. local_addr ( ) ?;
440+ let server_handle = tokio:: spawn ( {
441+ let ct = ct. clone ( ) ;
442+ async move {
443+ let _ = axum:: serve ( tcp_listener, router)
444+ . with_graceful_shutdown ( ct. cancelled_owned ( ) )
445+ . await ;
446+ }
447+ } ) ;
448+ Ok ( ( addr, ct, server_handle) )
449+ }
450+
451+ async fn setup_client (
452+ addr : std:: net:: SocketAddr ,
453+ ) -> anyhow:: Result < RunningService < RoleClient , ( ) > > {
454+ let reqwest_client = reqwest:: Client :: builder ( )
455+ . timeout ( REQWEST_TIMEOUT )
456+ . connection_verbose ( true )
457+ . build ( ) ?;
458+ let transport = StreamableHttpClientTransport :: with_client (
459+ reqwest_client,
460+ StreamableHttpClientTransportConfig :: with_uri ( format ! ( "http://{addr}/mcp" ) ) ,
461+ ) ;
462+ Ok ( serve_client ( ( ) , transport) . await ?)
463+ }
464+
465+ fn assert_tool_success ( label : & str , result : & CallToolResult ) {
466+ assert ! (
467+ result. is_error != Some ( true ) ,
468+ "{label} call_tool expected success, got: {result:?}"
469+ ) ;
470+ assert_eq ! (
471+ result. content. len( ) ,
472+ 1 ,
473+ "{label} call_tool expected 1 content item"
474+ ) ;
475+ assert_eq ! (
476+ result. content[ 0 ] . as_text( ) . unwrap( ) . text,
477+ "Long task completed"
478+ ) ;
479+ }
480+
481+ #[ derive( Debug , Clone , Default ) ]
482+ pub struct LongRunning {
483+ tool_router : ToolRouter < Self > ,
484+ }
485+
486+ impl LongRunning {
487+ pub fn new ( ) -> Self {
488+ Self {
489+ tool_router : Self :: tool_router ( ) ,
490+ }
491+ }
492+ }
493+
494+ #[ tool_router]
495+ impl LongRunning {
496+ #[ tool( description = "Run a long running tool call" ) ]
497+ async fn long_task ( & self ) -> Result < CallToolResult , McpError > {
498+ tokio:: time:: sleep ( LONG_TASK_DURATION ) . await ;
499+ Ok ( CallToolResult :: success ( vec ! [ Content :: text(
500+ "Long task completed" ,
501+ ) ] ) )
502+ }
503+ }
504+
505+ #[ tool_handler( router = self . tool_router) ]
506+ impl ServerHandler for LongRunning {
507+ fn get_info ( & self ) -> ServerInfo {
508+ ServerInfo :: new ( ServerCapabilities :: builder ( ) . enable_tools ( ) . build ( ) )
509+ }
510+ }
511+
512+ #[ tokio:: test]
513+ async fn test_long_running_tool_single_via_mcp_client ( ) -> anyhow:: Result < ( ) > {
514+ let ( addr, ct, server_handle) = setup_server ( ) . await ?;
515+ let client = setup_client ( addr) . await ?;
516+
517+ let result = tokio:: time:: timeout (
518+ CALL_TIMEOUT ,
519+ client. call_tool ( rmcp:: model:: CallToolRequestParams :: new ( "long_task" ) ) ,
520+ )
521+ . await ;
522+
523+ let _ = client. cancel ( ) . await ;
524+ ct. cancel ( ) ;
525+ server_handle. await ?;
526+
527+ let result = result. expect ( "call_tool timed out - client may be stuck in endless loop" ) ?;
528+ assert_tool_success ( "single" , & result) ;
529+
530+ Ok ( ( ) )
531+ }
532+
533+ #[ tokio:: test]
534+ async fn test_long_running_tool_parallel_via_mcp_client ( ) -> anyhow:: Result < ( ) > {
535+ let ( addr, ct, server_handle) = setup_server ( ) . await ?;
536+ let client = setup_client ( addr) . await ?;
537+
538+ let parallel_handle = tokio:: spawn ( {
539+ let client = client. clone ( ) ;
540+ async move {
541+ tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 4 ) ) . await ;
542+ client
543+ . call_tool ( rmcp:: model:: CallToolRequestParams :: new ( "long_task" ) )
544+ . await
545+ }
546+ } ) ;
547+
548+ let ( main_result, parallel_result) = tokio:: join!(
549+ tokio:: time:: timeout(
550+ CALL_TIMEOUT ,
551+ client. call_tool( rmcp:: model:: CallToolRequestParams :: new( "long_task" ) ) ,
552+ ) ,
553+ tokio:: time:: timeout( CALL_TIMEOUT , parallel_handle) ,
554+ ) ;
555+
556+ let _ = client. cancel ( ) . await ;
557+ ct. cancel ( ) ;
558+ server_handle. await ?;
559+
560+ let result =
561+ main_result. expect ( "main call_tool timed out - client may be stuck in endless loop" ) ?;
562+ let parallel = parallel_result
563+ . expect ( "parallel call_tool timed out - client may be stuck in endless loop" ) ?;
564+
565+ assert_tool_success ( "parallel" , & parallel?) ;
566+ assert_tool_success ( "main" , & result) ;
567+
568+ Ok ( ( ) )
569+ }
570+ }
0 commit comments