Skip to content

Commit 14df32f

Browse files
committed
Forwards all MCP non-transport headers, to downstream methods
1 parent 3ee9a44 commit 14df32f

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
*
3+
* *
4+
* * *
5+
* * * *
6+
* * * * *
7+
* * * * * * Copyright 2019-2026 the original author or authors.
8+
* * * * * *
9+
* * * * * * Licensed under the Apache License, Version 2.0 (the "License");
10+
* * * * * * you may not use this file except in compliance with the License.
11+
* * * * * * You may obtain a copy of the License at
12+
* * * * * *
13+
* * * * * * https://www.apache.org/licenses/LICENSE-2.0
14+
* * * * * *
15+
* * * * * * Unless required by applicable law or agreed to in writing, software
16+
* * * * * * distributed under the License is distributed on an "AS IS" BASIS,
17+
* * * * * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
* * * * * * See the License for the specific language governing permissions and
19+
* * * * * * limitations under the License.
20+
* * * * *
21+
* * * *
22+
* * *
23+
* *
24+
*
25+
*/
26+
27+
package org.springdoc.ai.mcp;
28+
29+
import java.util.Map;
30+
import java.util.Set;
31+
32+
/**
33+
* Thread-local holder for HTTP headers captured from inbound MCP requests. Platform-specific
34+
* filters ({@code McpAuditMdcFilter} for WebMVC, {@code McpAuditMdcWebFilter} for WebFlux)
35+
* populate this holder so that {@link OpenApiToolCallback} can propagate headers
36+
* (e.g. {@code Authorization}) to downstream REST API calls.
37+
*
38+
* <p>Uses a plain {@link ThreadLocal} (not {@link InheritableThreadLocal}) to avoid
39+
* leaking credentials to child threads.
40+
*
41+
* @author bnasslahsen
42+
*/
43+
public final class McpRequestContextHolder {
44+
45+
/**
46+
* Headers to skip when forwarding from inbound requests to downstream API calls.
47+
* These are standard HTTP transport/browser headers that should not be propagated.
48+
*/
49+
public static final Set<String> SKIP_HEADERS = Set.of("content-type", "content-length", "host", "connection",
50+
"accept", "accept-encoding", "accept-language", "user-agent", "origin", "referer", "cookie",
51+
"sec-fetch-dest", "sec-fetch-mode", "sec-fetch-site", "sec-ch-ua", "sec-ch-ua-mobile",
52+
"sec-ch-ua-platform");
53+
54+
private static final ThreadLocal<Map<String, String>> HEADERS = new ThreadLocal<>();
55+
56+
private McpRequestContextHolder() {
57+
}
58+
59+
/**
60+
* Stores the forwardable headers from the current MCP request.
61+
* @param headers the filtered headers to propagate
62+
*/
63+
public static void setHeaders(Map<String, String> headers) {
64+
HEADERS.set(headers);
65+
}
66+
67+
/**
68+
* Returns the forwardable headers captured from the current MCP request,
69+
* or {@code null} if none were set.
70+
* @return the captured headers, or {@code null}
71+
*/
72+
public static Map<String, String> getHeaders() {
73+
return HEADERS.get();
74+
}
75+
76+
/**
77+
* Removes the stored headers. Must be called in a {@code finally} block to
78+
* prevent leaking values across thread-pool reuse.
79+
*/
80+
public static void clear() {
81+
HEADERS.remove();
82+
}
83+
84+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
*
3+
* *
4+
* * *
5+
* * * *
6+
* * * * *
7+
* * * * * * Copyright 2019-2026 the original author or authors.
8+
* * * * * *
9+
* * * * * * Licensed under the Apache License, Version 2.0 (the "License");
10+
* * * * * * you may not use this file except in compliance with the License.
11+
* * * * * * You may obtain a copy of the License at
12+
* * * * * *
13+
* * * * * * https://www.apache.org/licenses/LICENSE-2.0
14+
* * * * * *
15+
* * * * * * Unless required by applicable law or agreed to in writing, software
16+
* * * * * * distributed under the License is distributed on an "AS IS" BASIS,
17+
* * * * * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
* * * * * * See the License for the specific language governing permissions and
19+
* * * * * * limitations under the License.
20+
* * * * *
21+
* * * *
22+
* * *
23+
* *
24+
*
25+
*/
26+
27+
package org.springdoc.webflux.ai;
28+
29+
import java.net.InetAddress;
30+
import java.net.InetSocketAddress;
31+
import java.util.HashMap;
32+
import java.util.Map;
33+
34+
import org.slf4j.MDC;
35+
import org.springdoc.ai.mcp.McpRequestContextHolder;
36+
import reactor.core.publisher.Mono;
37+
38+
import org.springframework.http.HttpHeaders;
39+
import org.springframework.http.server.reactive.ServerHttpRequest;
40+
import org.springframework.web.server.ServerWebExchange;
41+
import org.springframework.web.server.WebFilter;
42+
import org.springframework.web.server.WebFilterChain;
43+
44+
/**
45+
* Reactive WebFilter that populates SLF4J MDC with the caller's IP address and MCP session
46+
* ID, and stores forwardable headers in {@link McpRequestContextHolder} for propagation
47+
* to downstream REST API calls.
48+
*
49+
* <p>This is the WebFlux equivalent of
50+
* {@code org.springdoc.webmvc.ai.McpAuditMdcFilter}.
51+
*
52+
* @author bnasslahsen
53+
*/
54+
public class McpAuditMdcWebFilter implements WebFilter {
55+
56+
/**
57+
* MDC key used by {@link org.springdoc.ai.mcp.McpAuditLogger} to read the client IP
58+
* address.
59+
*/
60+
private static final String MDC_CLIENT_IP = "clientIp";
61+
62+
/**
63+
* MDC key used by {@link org.springdoc.ai.mcp.McpAuditLogger} to read the MCP
64+
* session ID.
65+
*/
66+
private static final String MDC_SESSION_ID = "sessionId";
67+
68+
/**
69+
* HTTP header set by reverse proxies carrying the originating client IP.
70+
*/
71+
private static final String HEADER_X_FORWARDED_FOR = "X-Forwarded-For";
72+
73+
/**
74+
* MCP Streamable HTTP transport header carrying the session identifier.
75+
*/
76+
private static final String HEADER_MCP_SESSION_ID = "mcp-session-id";
77+
78+
/**
79+
* The MCP endpoint path to match against.
80+
*/
81+
private final String mcpEndpoint;
82+
83+
/**
84+
* Creates a new filter scoped to the given MCP endpoint path.
85+
* @param mcpEndpoint the MCP endpoint path (e.g. {@code /mcp})
86+
*/
87+
public McpAuditMdcWebFilter(String mcpEndpoint) {
88+
this.mcpEndpoint = mcpEndpoint;
89+
}
90+
91+
@Override
92+
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
93+
String path = exchange.getRequest().getPath().value();
94+
if (!path.startsWith(mcpEndpoint)) {
95+
return chain.filter(exchange);
96+
}
97+
ServerHttpRequest request = exchange.getRequest();
98+
MDC.put(MDC_CLIENT_IP, resolveClientIp(request));
99+
String sessionId = request.getHeaders().getFirst(HEADER_MCP_SESSION_ID);
100+
if (sessionId != null && !sessionId.isBlank()) {
101+
MDC.put(MDC_SESSION_ID, sessionId);
102+
}
103+
McpRequestContextHolder.setHeaders(extractForwardableHeaders(request.getHeaders()));
104+
return chain.filter(exchange)
105+
.doFinally(signal -> {
106+
MDC.remove(MDC_CLIENT_IP);
107+
MDC.remove(MDC_SESSION_ID);
108+
McpRequestContextHolder.clear();
109+
});
110+
}
111+
112+
/**
113+
* Extracts headers that should be forwarded to downstream REST API calls,
114+
* filtering out standard HTTP transport/browser headers.
115+
* @param httpHeaders the reactive request headers
116+
* @return the forwardable headers map
117+
*/
118+
private Map<String, String> extractForwardableHeaders(HttpHeaders httpHeaders) {
119+
Map<String, String> headers = new HashMap<>();
120+
httpHeaders.forEach((name, values) -> {
121+
if (!McpRequestContextHolder.SKIP_HEADERS.contains(name.toLowerCase()) && !values.isEmpty()) {
122+
String value = values.get(0);
123+
if (value != null && !value.isEmpty()) {
124+
headers.put(name, value);
125+
}
126+
}
127+
});
128+
return headers;
129+
}
130+
131+
/**
132+
* Resolves the originating client IP from the request. Prefers the first value in
133+
* {@code X-Forwarded-For} when set, otherwise falls back to the remote address.
134+
* @param request the reactive server request
135+
* @return the resolved client IP address string
136+
*/
137+
private String resolveClientIp(ServerHttpRequest request) {
138+
String forwarded = request.getHeaders().getFirst(HEADER_X_FORWARDED_FOR);
139+
if (forwarded != null && !forwarded.isBlank()) {
140+
int comma = forwarded.indexOf(',');
141+
return (comma >= 0 ? forwarded.substring(0, comma) : forwarded).strip();
142+
}
143+
InetSocketAddress remoteAddress = request.getRemoteAddress();
144+
if (remoteAddress != null) {
145+
InetAddress address = remoteAddress.getAddress();
146+
return address != null ? address.getHostAddress() : remoteAddress.getHostString();
147+
}
148+
return "unknown";
149+
}
150+
151+
}

0 commit comments

Comments
 (0)