Skip to content

Commit 3ee9a44

Browse files
committed
Forwards all MCP non-transport headers, to downstream methods
1 parent df99408 commit 3ee9a44

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

springdoc-openapi-starter-common-mcp/src/main/java/org/springdoc/ai/dashboard/McpDashboardController.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
import java.util.LinkedHashMap;
3131
import java.util.List;
3232
import java.util.Map;
33-
import java.util.Set;
3433

3534
import io.swagger.v3.oas.annotations.Operation;
35+
import org.springdoc.ai.mcp.McpRequestContextHolder;
3636

3737
import org.springframework.http.ResponseEntity;
3838
import org.springframework.web.bind.annotation.DeleteMapping;
@@ -89,13 +89,6 @@ public List<McpToolInfo> listTools() {
8989
return List.copyOf(toolsByName.values());
9090
}
9191

92-
/**
93-
* Headers to skip when forwarding to tool execution.
94-
*/
95-
private static final Set<String> SKIP_HEADERS = Set.of("content-type", "content-length", "host", "connection",
96-
"accept", "accept-encoding", "accept-language", "user-agent", "origin", "referer", "cookie",
97-
"sec-fetch-dest", "sec-fetch-mode", "sec-fetch-site", "sec-ch-ua", "sec-ch-ua-mobile",
98-
"sec-ch-ua-platform");
9992

10093
/**
10194
* Executes an MCP tool by name, forwarding any authentication headers (Authorization,
@@ -110,7 +103,7 @@ public ResponseEntity<McpToolExecutionResponse> executeTool(@RequestBody McpTool
110103
@RequestHeader Map<String, String> allHeaders) {
111104
Map<String, String> extraHeaders = new HashMap<>();
112105
for (Map.Entry<String, String> entry : allHeaders.entrySet()) {
113-
if (!SKIP_HEADERS.contains(entry.getKey().toLowerCase()) && entry.getValue() != null
106+
if (!McpRequestContextHolder.SKIP_HEADERS.contains(entry.getKey().toLowerCase()) && entry.getValue() != null
114107
&& !entry.getValue().isEmpty()) {
115108
extraHeaders.put(entry.getKey(), entry.getValue());
116109
}

springdoc-openapi-starter-common-mcp/src/main/java/org/springdoc/ai/mcp/OpenApiToolCallback.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,12 @@ private HttpResponse<String> executeHttp(String toolInput, Map<String, String> e
402402
if (extraHeaders != null) {
403403
extraHeaders.forEach(requestBuilder::header);
404404
}
405+
else {
406+
Map<String, String> contextHeaders = McpRequestContextHolder.getHeaders();
407+
if (contextHeaders != null) {
408+
contextHeaders.forEach(requestBuilder::header);
409+
}
410+
}
405411

406412
String bodyString = buildBodyString(input);
407413
HttpRequest.BodyPublisher bodyPublisher = bodyString != null

springdoc-openapi-starter-webflux-mcp/src/main/java/org/springdoc/webflux/ai/McpWebFluxAiAutoConfiguration.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@
6666
@ConditionalOnWebApplication(type = Type.REACTIVE)
6767
public class McpWebFluxAiAutoConfiguration {
6868

69+
/**
70+
* Creates the {@link McpAuditMdcWebFilter} that populates MDC and
71+
* {@link org.springdoc.ai.mcp.McpRequestContextHolder} for MCP requests.
72+
* @param aiProperties the AI properties (used to scope the filter to the MCP path)
73+
* @return the WebFilter bean
74+
*/
75+
@Bean
76+
@ConditionalOnMissingBean(McpAuditMdcWebFilter.class)
77+
McpAuditMdcWebFilter mcpAuditMdcWebFilter(SpringDocAiProperties aiProperties) {
78+
return new McpAuditMdcWebFilter(aiProperties.getMcpEndpoint());
79+
}
80+
6981
/**
7082
* Creates the {@link McpToolDescriptionCustomizer} bean that scans WebFlux handler
7183
* methods for {@link org.springdoc.ai.annotations.McpToolDescription} annotations.

springdoc-openapi-starter-webmvc-mcp/src/main/java/org/springdoc/webmvc/ai/McpAuditMdcFilter.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,16 @@
2727
package org.springdoc.webmvc.ai;
2828

2929
import java.io.IOException;
30+
import java.util.Enumeration;
31+
import java.util.HashMap;
32+
import java.util.Map;
3033

3134
import jakarta.servlet.FilterChain;
3235
import jakarta.servlet.ServletException;
3336
import jakarta.servlet.http.HttpServletRequest;
3437
import jakarta.servlet.http.HttpServletResponse;
3538
import org.slf4j.MDC;
39+
import org.springdoc.ai.mcp.McpRequestContextHolder;
3640

3741
import org.springframework.web.filter.OncePerRequestFilter;
3842

@@ -90,15 +94,38 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
9094
if (sessionId != null && !sessionId.isBlank()) {
9195
MDC.put(MDC_SESSION_ID, sessionId);
9296
}
97+
McpRequestContextHolder.setHeaders(extractForwardableHeaders(request));
9398
try {
9499
filterChain.doFilter(request, response);
95100
}
96101
finally {
97102
MDC.remove(MDC_CLIENT_IP);
98103
MDC.remove(MDC_SESSION_ID);
104+
McpRequestContextHolder.clear();
99105
}
100106
}
101107

108+
/**
109+
* Extracts headers from the servlet request that should be forwarded to downstream
110+
* REST API calls, filtering out standard HTTP transport/browser headers.
111+
* @param request the HTTP servlet request
112+
* @return the forwardable headers map
113+
*/
114+
private Map<String, String> extractForwardableHeaders(HttpServletRequest request) {
115+
Map<String, String> headers = new HashMap<>();
116+
Enumeration<String> headerNames = request.getHeaderNames();
117+
while (headerNames.hasMoreElements()) {
118+
String name = headerNames.nextElement();
119+
if (!McpRequestContextHolder.SKIP_HEADERS.contains(name.toLowerCase())) {
120+
String value = request.getHeader(name);
121+
if (value != null && !value.isEmpty()) {
122+
headers.put(name, value);
123+
}
124+
}
125+
}
126+
return headers;
127+
}
128+
102129
/**
103130
* Resolves the originating client IP from the request. Prefers the first value in
104131
* {@code X-Forwarded-For} when set, otherwise falls back to the direct remote

0 commit comments

Comments
 (0)