forked from modelcontextprotocol/java-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathWebMvcSseSyncServerTransportTests.java
More file actions
115 lines (90 loc) · 3.34 KB
/
WebMvcSseSyncServerTransportTests.java
File metadata and controls
115 lines (90 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
* Copyright 2024-2024 the original author or authors.
*/
package io.modelcontextprotocol.server;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.startup.Tomcat;
import org.junit.jupiter.api.Timeout;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerResponse;
/**
 * Runs the shared {@link AbstractMcpSyncServerTests} suite against a
 * {@link WebMvcSseServerTransportProvider} that is served by an embedded Tomcat
 * instance through a Spring {@link DispatcherServlet}.
 *
 * <p>
 * Each test is limited by {@code @Timeout(15)} (JUnit's default timeout unit is
 * seconds).
 */
@Timeout(15)
class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests {

	/** Endpoint path handed to the transport provider for incoming messages. */
	private static final String MESSAGE_ENDPOINT = "/mcp/message";

	/**
	 * Port the embedded Tomcat listens on; obtained once, when this class is
	 * initialized, from {@code TomcatTestUtil.findAvailablePort()}.
	 */
	private static final int PORT = TomcatTestUtil.findAvailablePort();

	/** Embedded servlet container; created in {@link #createMcpTransportProvider()}. */
	private Tomcat tomcat;

	/** Transport provider bean looked up from {@link #appContext}. */
	private WebMvcSseServerTransportProvider transportProvider;

	/** Spring web application context that owns the transport provider bean. */
	private AnnotationConfigWebApplicationContext appContext;

	/**
	 * Spring configuration that exposes the transport provider and the router
	 * function it supplies.
	 */
	@Configuration
	@EnableWebMvc
	static class TestConfig {

		/**
		 * Creates the transport provider under test.
		 * @return a provider built with a fresh {@link ObjectMapper} and
		 * {@code MESSAGE_ENDPOINT}
		 */
		@Bean
		public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
			return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT);
		}

		/**
		 * Registers the routes of the transport provider with Spring WebMvc.
		 * @param transportProvider the provider bean declared above
		 * @return the router function returned by the provider
		 */
		@Bean
		public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportProvider transportProvider) {
			return transportProvider.getRouterFunction();
		}

	}

	/**
	 * Boots an embedded Tomcat with a Spring WebMvc context and returns the
	 * transport provider bean from that context.
	 * @return the transport provider served by the started Tomcat instance
	 * @throws RuntimeException if Tomcat fails to start
	 */
	@Override
	protected WebMvcSseServerTransportProvider createMcpTransportProvider() {
		// Set up Tomcat first
		tomcat = new Tomcat();
		tomcat.setPort(PORT);

		// Set Tomcat base directory to java.io.tmpdir to avoid permission issues
		String baseDir = System.getProperty("java.io.tmpdir");
		tomcat.setBaseDir(baseDir);

		// Use the same directory for document base
		Context context = tomcat.addContext("", baseDir);

		// Create and configure Spring WebMvc context
		appContext = new AnnotationConfigWebApplicationContext();
		appContext.register(TestConfig.class);
		appContext.setServletContext(context.getServletContext());
		appContext.refresh();

		// Get the transport from Spring context
		transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class);

		// Create DispatcherServlet with our Spring context
		DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);

		// Add servlet to Tomcat and get the wrapper
		var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
		wrapper.setLoadOnStartup(1);
		context.addServletMappingDecoded("/*", "dispatcherServlet");

		try {
			tomcat.start();
			tomcat.getConnector(); // Create and start the connector
		}
		catch (LifecycleException e) {
			throw new RuntimeException("Failed to start Tomcat", e);
		}

		return transportProvider;
	}

	/** No start-up work is needed; the server is created on demand per test. */
	@Override
	protected void onStart() {
	}

	/**
	 * Releases everything created by {@link #createMcpTransportProvider()}: the
	 * transport provider, the Spring context and the Tomcat instance, in that order.
	 *
	 * <p>
	 * The steps are chained with {@code finally} so that a failure in an earlier
	 * step does not leave Tomcat running and bound to {@code PORT}.
	 * @throws RuntimeException if stopping or destroying Tomcat fails
	 */
	@Override
	protected void onClose() {
		try {
			if (transportProvider != null) {
				transportProvider.closeGracefully().block();
			}
		}
		finally {
			try {
				if (appContext != null) {
					appContext.close();
				}
			}
			finally {
				stopTomcat();
			}
		}
	}

	/**
	 * Stops and destroys the embedded Tomcat, if one was created. {@code destroy()}
	 * is attempted even when {@code stop()} fails.
	 */
	private void stopTomcat() {
		if (tomcat == null) {
			return;
		}
		try {
			try {
				tomcat.stop();
			}
			finally {
				tomcat.destroy();
			}
		}
		catch (LifecycleException e) {
			throw new RuntimeException("Failed to stop Tomcat", e);
		}
	}

}