Skip to content

Commit ad17a4e

Browse files
committed
Resource subscriptions
Signed-off-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent 46bacda commit ad17a4e

12 files changed

+538
-30
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
package io.modelcontextprotocol.server;
66

77
import java.time.Duration;
8+
import java.util.Collections;
89
import java.util.HashMap;
910
import java.util.List;
1011
import java.util.Map;
1112
import java.util.Optional;
13+
import java.util.Set;
1214
import java.util.UUID;
1315
import java.util.concurrent.ConcurrentHashMap;
1416
import java.util.concurrent.CopyOnWriteArrayList;
@@ -25,7 +27,6 @@
2527
import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion;
2628
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
2729
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
28-
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
2930
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
3031
import io.modelcontextprotocol.spec.McpSchema.ResourceReference;
3132
import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest;
@@ -111,12 +112,10 @@ public class McpAsyncServer {
111112

112113
private final ConcurrentHashMap<String, McpServerFeatures.AsyncPromptSpecification> prompts = new ConcurrentHashMap<>();
113114

114-
// FIXME: this field is deprecated and should be remvoed together with the
115-
// broadcasting loggingNotification.
116-
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
117-
118115
private final ConcurrentHashMap<McpSchema.CompleteReference, McpServerFeatures.AsyncCompletionSpecification> completions = new ConcurrentHashMap<>();
119116

117+
private final ConcurrentHashMap<String, Set<String>> resourceSubscriptions = new ConcurrentHashMap<>();
118+
120119
private List<String> protocolVersions;
121120

122121
private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory();
@@ -149,8 +148,11 @@ public class McpAsyncServer {
149148

150149
this.protocolVersions = mcpTransportProvider.protocolVersions();
151150

152-
mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(),
153-
requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers));
151+
mcpTransportProvider.setSessionFactory(transport -> {
152+
String sessionId = UUID.randomUUID().toString();
153+
return new McpServerSession(sessionId, requestTimeout, transport, this::asyncInitializeRequestHandler,
154+
requestHandlers, notificationHandlers, () -> this.cleanupForSession(sessionId));
155+
});
154156
}
155157

156158
McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper,
@@ -174,8 +176,9 @@ public class McpAsyncServer {
174176

175177
this.protocolVersions = mcpTransportProvider.protocolVersions();
176178

177-
mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout,
178-
this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers));
179+
mcpTransportProvider.setSessionFactory(
180+
new DefaultMcpStreamableServerSessionFactory(requestTimeout, this::asyncInitializeRequestHandler,
181+
requestHandlers, notificationHandlers, sessionId -> this.cleanupForSession(sessionId)));
179182
}
180183

181184
private Map<String, McpNotificationHandler> prepareNotificationHandlers(McpServerFeatures.Async features) {
@@ -215,6 +218,10 @@ private Map<String, McpRequestHandler<?>> prepareRequestHandlers() {
215218
requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler());
216219
requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler());
217220
requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler());
221+
if (Boolean.TRUE.equals(this.serverCapabilities.resources().subscribe())) {
222+
requestHandlers.put(McpSchema.METHOD_RESOURCES_SUBSCRIBE, resourcesSubscribeRequestHandler());
223+
requestHandlers.put(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, resourcesUnsubscribeRequestHandler());
224+
}
218225
}
219226

220227
// Add prompts API handlers if provider exists
@@ -685,12 +692,73 @@ public Mono<Void> notifyResourcesListChanged() {
685692
}
686693

687694
/**
688-
* Notifies clients that the resources have updated.
689-
* @return A Mono that completes when all clients have been notified
695+
* Notifies only the sessions that have subscribed to the updated resource URI.
696+
* @param resourcesUpdatedNotification the notification containing the updated
697+
* resource URI
698+
* @return A Mono that completes when all subscribed sessions have been notified
690699
*/
691700
public Mono<Void> notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) {
692-
return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED,
693-
resourcesUpdatedNotification);
701+
return Mono.defer(() -> {
702+
String uri = resourcesUpdatedNotification.uri();
703+
Set<String> subscribedSessions = this.resourceSubscriptions.get(uri);
704+
if (subscribedSessions == null || subscribedSessions.isEmpty()) {
705+
logger.debug("No sessions subscribed to resource URI: {}", uri);
706+
return Mono.empty();
707+
}
708+
return Flux.fromIterable(subscribedSessions)
709+
.flatMap(sessionId -> this.mcpTransportProvider
710+
.notifyClient(sessionId, McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED,
711+
resourcesUpdatedNotification)
712+
.doOnError(e -> logger.error("Failed to notify session {} of resource update for {}", sessionId,
713+
uri, e))
714+
.onErrorComplete())
715+
.then();
716+
});
717+
}
718+
719+
private Mono<Void> cleanupForSession(String sessionId) {
720+
return Mono.fromRunnable(() -> {
721+
removeSessionSubscriptions(sessionId);
722+
});
723+
}
724+
725+
private void removeSessionSubscriptions(String sessionId) {
726+
this.resourceSubscriptions.forEach((uri, sessions) -> sessions.remove(sessionId));
727+
this.resourceSubscriptions.entrySet().removeIf(entry -> entry.getValue().isEmpty());
728+
}
729+
730+
private McpRequestHandler<Object> resourcesSubscribeRequestHandler() {
731+
return (exchange, params) -> Mono.defer(() -> {
732+
McpSchema.SubscribeRequest subscribeRequest = jsonMapper.convertValue(params,
733+
new TypeRef<McpSchema.SubscribeRequest>() {
734+
});
735+
String uri = subscribeRequest.uri();
736+
String sessionId = exchange.sessionId();
737+
this.resourceSubscriptions.computeIfAbsent(uri, k -> Collections.newSetFromMap(new ConcurrentHashMap<>()))
738+
.add(sessionId);
739+
logger.debug("Session {} subscribed to resource URI: {}", sessionId, uri);
740+
741+
return Mono.just(Map.of());
742+
});
743+
}
744+
745+
private McpRequestHandler<Object> resourcesUnsubscribeRequestHandler() {
746+
return (exchange, params) -> Mono.defer(() -> {
747+
McpSchema.UnsubscribeRequest unsubscribeRequest = jsonMapper.convertValue(params,
748+
new TypeRef<McpSchema.UnsubscribeRequest>() {
749+
});
750+
String uri = unsubscribeRequest.uri();
751+
String sessionId = exchange.sessionId();
752+
Set<String> sessions = this.resourceSubscriptions.get(uri);
753+
if (sessions != null) {
754+
sessions.remove(sessionId);
755+
if (sessions.isEmpty()) {
756+
this.resourceSubscriptions.remove(uri, sessions);
757+
}
758+
}
759+
logger.debug("Session {} unsubscribed from resource URI: {}", sessionId, uri);
760+
return Mono.just(Map.of());
761+
});
694762
}
695763

696764
private McpRequestHandler<McpSchema.ListResourcesResult> resourcesListRequestHandler() {
@@ -878,10 +946,6 @@ private McpRequestHandler<Object> setLoggerRequestHandler() {
878946

879947
exchange.setMinLoggingLevel(newMinLoggingLevel.level());
880948

881-
// FIXME: this field is deprecated and should be removed together
882-
// with the broadcasting loggingNotification.
883-
this.minLoggingLevel = newMinLoggingLevel.level();
884-
885949
return Mono.just(Map.of());
886950
});
887951
};

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,24 @@ public Mono<Void> notifyClients(String method, Object params) {
228228
.then();
229229
}
230230

231+
@Override
232+
public Mono<Void> notifyClient(String sessionId, String method, Object params) {
233+
return Mono.defer(() -> {
234+
// Need to iterate in O(n) because the transport session id
235+
// is different from the server-logical session id
236+
McpServerSession session = sessions.values()
237+
.stream()
238+
.filter(s -> sessionId.equals(s.getId()))
239+
.findFirst()
240+
.orElse(null);
241+
if (session == null) {
242+
logger.debug("Session {} not found", sessionId);
243+
return Mono.empty();
244+
}
245+
return session.sendNotification(method, params);
246+
});
247+
}
248+
231249
/**
232250
* Handles GET requests to establish SSE connections.
233251
* <p>

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ public Mono<Void> notifyClients(String method, Object params) {
206206
});
207207
}
208208

209+
@Override
210+
public Mono<Void> notifyClient(String sessionId, String method, Object params) {
211+
return Mono.defer(() -> {
212+
McpStreamableServerSession session = this.sessions.get(sessionId);
213+
if (session == null) {
214+
logger.debug("Session {} not found", sessionId);
215+
return Mono.empty();
216+
}
217+
return session.sendNotification(method, params);
218+
});
219+
}
220+
209221
/**
210222
* Initiates a graceful shutdown of the transport.
211223
* @return A Mono that completes when all cleanup operations are finished

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,26 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
9898
@Override
9999
public Mono<Void> notifyClients(String method, Object params) {
100100
if (this.session == null) {
101-
return Mono.error(new IllegalStateException("No session to close"));
101+
return Mono.error(new IllegalStateException("No session to notify"));
102102
}
103103
return this.session.sendNotification(method, params)
104104
.doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage()));
105105
}
106106

107+
@Override
108+
public Mono<Void> notifyClient(String sessionId, String method, Object params) {
109+
return Mono.defer(() -> {
110+
if (this.session == null) {
111+
return Mono.error(new IllegalStateException("No session to notify"));
112+
}
113+
if (!this.session.getId().equals(sessionId)) {
114+
return Mono.error(new IllegalStateException("Existing session id " + this.session.getId()
115+
+ " doesn't match the notification target: " + sessionId));
116+
}
117+
return this.session.sendNotification(method, params);
118+
});
119+
}
120+
107121
@Override
108122
public Mono<Void> closeGracefully() {
109123
if (this.session == null) {

mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import java.time.Duration;
1111
import java.util.Map;
1212
import java.util.UUID;
13+
import java.util.function.Function;
14+
15+
import reactor.core.publisher.Mono;
1316

1417
/**
1518
* A default implementation of {@link McpStreamableServerSession.Factory}.
@@ -26,29 +29,53 @@ public class DefaultMcpStreamableServerSessionFactory implements McpStreamableSe
2629

2730
Map<String, McpNotificationHandler> notificationHandlers;
2831

32+
private final Function<String, Mono<Void>> onClose;
33+
2934
/**
30-
* Constructs an instance
35+
* Constructs an instance.
3136
* @param requestTimeout timeout for requests
3237
* @param initRequestHandler initialization request handler
3338
* @param requestHandlers map of MCP request handlers keyed by method name
3439
* @param notificationHandlers map of MCP notification handlers keyed by method name
40+
* @param onClose reactive callback invoked with the session ID when a session is
41+
* closed
3542
*/
3643
public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout,
3744
McpStreamableServerSession.InitRequestHandler initRequestHandler,
38-
Map<String, McpRequestHandler<?>> requestHandlers,
39-
Map<String, McpNotificationHandler> notificationHandlers) {
45+
Map<String, McpRequestHandler<?>> requestHandlers, Map<String, McpNotificationHandler> notificationHandlers,
46+
Function<String, Mono<Void>> onClose) {
4047
this.requestTimeout = requestTimeout;
4148
this.initRequestHandler = initRequestHandler;
4249
this.requestHandlers = requestHandlers;
4350
this.notificationHandlers = notificationHandlers;
51+
this.onClose = onClose;
52+
}
53+
54+
/**
55+
* Constructs an instance.
56+
* @param requestTimeout timeout for requests
57+
* @param initRequestHandler initialization request handler
58+
* @param requestHandlers map of MCP request handlers keyed by method name
59+
* @param notificationHandlers map of MCP notification handlers keyed by method name
60+
* @deprecated Use
61+
* {@link #DefaultMcpStreamableServerSessionFactory(Duration, McpStreamableServerSession.InitRequestHandler, Map, Map, Function)}
62+
* instead
63+
*/
64+
@Deprecated
65+
public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout,
66+
McpStreamableServerSession.InitRequestHandler initRequestHandler,
67+
Map<String, McpRequestHandler<?>> requestHandlers,
68+
Map<String, McpNotificationHandler> notificationHandlers) {
69+
this(requestTimeout, initRequestHandler, requestHandlers, notificationHandlers, sessionId -> Mono.empty());
4470
}
4571

4672
@Override
4773
public McpStreamableServerSession.McpStreamableServerSessionInit startSession(
4874
McpSchema.InitializeRequest initializeRequest) {
49-
return new McpStreamableServerSession.McpStreamableServerSessionInit(
50-
new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(),
51-
initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers),
75+
String sessionId = UUID.randomUUID().toString();
76+
return new McpStreamableServerSession.McpStreamableServerSessionInit(new McpStreamableServerSession(sessionId,
77+
initializeRequest.capabilities(), initializeRequest.clientInfo(), requestTimeout, requestHandlers,
78+
notificationHandlers, () -> this.onClose.apply(sessionId)),
5279
this.initRequestHandler.handle(initializeRequest));
5380
}
5481

mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.time.Duration;
88
import java.util.Map;
99
import java.util.concurrent.ConcurrentHashMap;
10+
import java.util.function.Supplier;
1011
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicLong;
1213
import java.util.concurrent.atomic.AtomicReference;
@@ -65,25 +66,47 @@ public class McpServerSession implements McpLoggableSession {
6566

6667
private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO;
6768

69+
private final Supplier<Mono<Void>> onClose;
70+
6871
/**
6972
* Creates a new server session with the given parameters and the transport to use.
7073
* @param id session id
74+
* @param requestTimeout duration to wait for request responses before timing out
7175
* @param transport the transport to use
7276
* @param initHandler called when a
7377
* {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the
7478
* server
7579
* @param requestHandlers map of request handlers to use
7680
* @param notificationHandlers map of notification handlers to use
81+
* @param onClose supplier of a reactive callback invoked when the session is closed
7782
*/
7883
public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport,
7984
McpInitRequestHandler initHandler, Map<String, McpRequestHandler<?>> requestHandlers,
80-
Map<String, McpNotificationHandler> notificationHandlers) {
85+
Map<String, McpNotificationHandler> notificationHandlers, Supplier<Mono<Void>> onClose) {
8186
this.id = id;
8287
this.requestTimeout = requestTimeout;
8388
this.transport = transport;
8489
this.initRequestHandler = initHandler;
8590
this.requestHandlers = requestHandlers;
8691
this.notificationHandlers = notificationHandlers;
92+
this.onClose = onClose;
93+
}
94+
95+
/**
96+
* Creates a new server session with the given parameters and the transport to use.
97+
* @param id session id
98+
* @param requestTimeout duration to wait for request responses before timing out
99+
* @param transport the transport to use
100+
* @param initHandler called when a
101+
* {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the
102+
* server
103+
* @param requestHandlers map of request handlers to use
104+
* @param notificationHandlers map of notification handlers to use
105+
*/
106+
public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport,
107+
McpInitRequestHandler initHandler, Map<String, McpRequestHandler<?>> requestHandlers,
108+
Map<String, McpNotificationHandler> notificationHandlers) {
109+
this(id, requestTimeout, transport, initHandler, requestHandlers, notificationHandlers, Mono::empty);
87110
}
88111

89112
/**
@@ -318,12 +341,13 @@ private MethodNotFoundError getMethodNotFoundError(String method) {
318341
@Override
319342
public Mono<Void> closeGracefully() {
320343
// TODO: clear pendingResponses and emit errors?
321-
return this.transport.closeGracefully();
344+
return this.onClose.get().onErrorComplete().then(this.transport.closeGracefully());
322345
}
323346

324347
@Override
325348
public void close() {
326349
// TODO: clear pendingResponses and emit errors?
350+
this.onClose.get().onErrorComplete().subscribe();
327351
this.transport.close();
328352
}
329353

mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,23 @@ public interface McpServerTransportProviderBase {
4545
*/
4646
Mono<Void> notifyClients(String method, Object params);
4747

48+
/**
49+
* Sends a notification to a specific client session. Transport providers that support
50+
* resource subscriptions must override this method to enable per-session
51+
* notifications. The default implementation returns an error indicating that this
52+
* operation is not supported.
53+
* @param sessionId the id of the session to notify
54+
* @param method the name of the notification method to be called on the client
55+
* @param params parameters to be sent with the notification
56+
* @return a Mono that completes when the notification has been sent, or empty if the
57+
* session is not found
58+
*/
59+
default Mono<Void> notifyClient(String sessionId, String method, Object params) {
60+
return Mono.error(
61+
new UnsupportedOperationException("This transport provider does not support per-session notifications. "
62+
+ "Override notifyClient() to enable resource subscription support."));
63+
}
64+
4865
/**
4966
* Immediately closes all the transports with connected clients and releases any
5067
* associated resources.

0 commit comments

Comments
 (0)