Skip to content

Commit 602c930

Browse files
committed
feat: implement support for elicitation
1 parent 07e7b8f commit 602c930

File tree

13 files changed

+1525
-31
lines changed

13 files changed

+1525
-31
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.junit.jupiter.api.BeforeEach;
3434
import org.junit.jupiter.params.ParameterizedTest;
3535
import org.junit.jupiter.params.provider.ValueSource;
36+
import reactor.core.publisher.Mono;
3637
import reactor.netty.DisposableServer;
3738
import reactor.netty.http.server.HttpServer;
3839

@@ -41,6 +42,7 @@
4142
import org.springframework.web.client.RestClient;
4243
import org.springframework.web.reactive.function.client.WebClient;
4344
import org.springframework.web.reactive.function.server.RouterFunctions;
45+
import reactor.test.StepVerifier;
4446

4547
import static org.assertj.core.api.Assertions.assertThat;
4648
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -331,6 +333,229 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
331333
mcpServer.closeGracefully().block();
332334
}
333335

336+
// ---------------------------------------
337+
// Elicitation Tests
338+
// ---------------------------------------
339+
@ParameterizedTest(name = "{0} : {displayName} ")
340+
@ValueSource(strings = { "httpclient", "webflux" })
341+
void testCreateElicitationWithoutElicitationCapabilities(String clientType) {
342+
343+
var clientBuilder = clientBuilders.get(clientType);
344+
345+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
346+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
347+
348+
exchange.createElicitation(mock(ElicitRequest.class)).block();
349+
350+
return Mono.just(mock(CallToolResult.class));
351+
});
352+
353+
var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build();
354+
355+
try (
356+
// Create client without sampling capabilities
357+
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
358+
359+
assertThat(client.initialize()).isNotNull();
360+
361+
try {
362+
client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
363+
}
364+
catch (McpError e) {
365+
assertThat(e).isInstanceOf(McpError.class)
366+
.hasMessage("Client must be configured with elicitation capabilities");
367+
}
368+
}
369+
server.closeGracefully().block();
370+
}
371+
372+
@ParameterizedTest(name = "{0} : {displayName} ")
373+
@ValueSource(strings = { "httpclient", "webflux" })
374+
void testCreateElicitationSuccess(String clientType) {
375+
376+
var clientBuilder = clientBuilders.get(clientType);
377+
378+
Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
379+
assertThat(request.message()).isNotEmpty();
380+
assertThat(request.requestedSchema()).isNotNull();
381+
382+
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
383+
};
384+
385+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
386+
null);
387+
388+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
389+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
390+
391+
var elicitationRequest = ElicitRequest.builder()
392+
.message("Test message")
393+
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
394+
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
395+
.build())
396+
.build();
397+
398+
StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
399+
assertThat(result).isNotNull();
400+
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
401+
assertThat(result.content().get("message")).isEqualTo("Test message");
402+
}).verifyComplete();
403+
404+
return Mono.just(callResponse);
405+
});
406+
407+
var mcpServer = McpServer.async(mcpServerTransportProvider)
408+
.serverInfo("test-server", "1.0.0")
409+
.tools(tool)
410+
.build();
411+
412+
try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
413+
.capabilities(ClientCapabilities.builder().elicitation().build())
414+
.elicitation(elicitationHandler)
415+
.build()) {
416+
417+
InitializeResult initResult = mcpClient.initialize();
418+
assertThat(initResult).isNotNull();
419+
420+
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
421+
422+
assertThat(response).isNotNull();
423+
assertThat(response).isEqualTo(callResponse);
424+
}
425+
mcpServer.closeGracefully().block();
426+
}
427+
428+
@ParameterizedTest(name = "{0} : {displayName} ")
429+
@ValueSource(strings = { "httpclient", "webflux" })
430+
void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {
431+
432+
// Client
433+
var clientBuilder = clientBuilders.get(clientType);
434+
435+
Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
436+
assertThat(request.message()).isNotEmpty();
437+
assertThat(request.requestedSchema()).isNotNull();
438+
try {
439+
TimeUnit.SECONDS.sleep(2);
440+
}
441+
catch (InterruptedException e) {
442+
throw new RuntimeException(e);
443+
}
444+
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
445+
};
446+
447+
var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
448+
.capabilities(ClientCapabilities.builder().elicitation().build())
449+
.elicitation(elicitationHandler)
450+
.build();
451+
452+
// Server
453+
454+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
455+
null);
456+
457+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
458+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
459+
460+
var elicitationRequest = ElicitRequest.builder()
461+
.message("Test message")
462+
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
463+
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
464+
.build())
465+
.build();
466+
467+
StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
468+
assertThat(result).isNotNull();
469+
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
470+
assertThat(result.content().get("message")).isEqualTo("Test message");
471+
}).verifyComplete();
472+
473+
return Mono.just(callResponse);
474+
});
475+
476+
var mcpServer = McpServer.async(mcpServerTransportProvider)
477+
.serverInfo("test-server", "1.0.0")
478+
.requestTimeout(Duration.ofSeconds(3))
479+
.tools(tool)
480+
.build();
481+
482+
InitializeResult initResult = mcpClient.initialize();
483+
assertThat(initResult).isNotNull();
484+
485+
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
486+
487+
assertThat(response).isNotNull();
488+
assertThat(response).isEqualTo(callResponse);
489+
490+
mcpClient.closeGracefully();
491+
mcpServer.closeGracefully().block();
492+
}
493+
494+
@ParameterizedTest(name = "{0} : {displayName} ")
495+
@ValueSource(strings = { "httpclient", "webflux" })
496+
void testCreateElicitationWithRequestTimeoutFail(String clientType) {
497+
498+
// Client
499+
var clientBuilder = clientBuilders.get(clientType);
500+
501+
Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
502+
assertThat(request.message()).isNotEmpty();
503+
assertThat(request.requestedSchema()).isNotNull();
504+
try {
505+
TimeUnit.SECONDS.sleep(2);
506+
}
507+
catch (InterruptedException e) {
508+
throw new RuntimeException(e);
509+
}
510+
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
511+
};
512+
513+
var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
514+
.capabilities(ClientCapabilities.builder().elicitation().build())
515+
.elicitation(elicitationHandler)
516+
.build();
517+
518+
// Server
519+
520+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
521+
null);
522+
523+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
524+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
525+
526+
var elicitationRequest = ElicitRequest.builder()
527+
.message("Test message")
528+
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
529+
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
530+
.build())
531+
.build();
532+
533+
StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
534+
assertThat(result).isNotNull();
535+
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
536+
assertThat(result.content().get("message")).isEqualTo("Test message");
537+
}).verifyComplete();
538+
539+
return Mono.just(callResponse);
540+
});
541+
542+
var mcpServer = McpServer.async(mcpServerTransportProvider)
543+
.serverInfo("test-server", "1.0.0")
544+
.requestTimeout(Duration.ofSeconds(1))
545+
.tools(tool)
546+
.build();
547+
548+
InitializeResult initResult = mcpClient.initialize();
549+
assertThat(initResult).isNotNull();
550+
551+
assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
552+
mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
553+
}).withMessageContaining("within 1000ms");
554+
555+
mcpClient.closeGracefully();
556+
mcpServer.closeGracefully().block();
557+
}
558+
334559
// ---------------------------------------
335560
// Roots Tests
336561
// ---------------------------------------

0 commit comments

Comments
 (0)