Skip to content

Commit 92fc86c

Browse files
committed
feat: add support for reactive hook in MCP client initialization
- Add connectHook to McpClient Sync and Async builders to allow custom reactive transformations of the connection lifecycle. - Update McpClientSession to propagate connection errors to pending requests, preventing 'onErrorDropped' scenarios. - Add McpClientInitializationTests to verify error propagation and custom hook functionality. Closes #712
1 parent cbb235f commit 92fc86c

File tree

5 files changed

+128
-13
lines changed

5 files changed

+128
-13
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import io.modelcontextprotocol.spec.McpSchema.Root;
3939
import io.modelcontextprotocol.util.Assert;
4040
import io.modelcontextprotocol.util.Utils;
41+
import org.reactivestreams.Publisher;
4142
import org.slf4j.Logger;
4243
import org.slf4j.LoggerFactory;
4344
import reactor.core.publisher.Flux;
@@ -317,9 +318,17 @@ public class McpAsyncClient {
317318
};
318319

319320
this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(),
320-
initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
321-
notificationHandlers, con -> con.contextWrite(ctx)),
322-
postInitializationHook);
321+
initializationTimeout, ctx -> {
322+
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;
323+
if (features.connectHook() != null) {
324+
connectHook = con -> features.connectHook().apply(con.contextWrite(ctx));
325+
}
326+
else {
327+
connectHook = con -> con.contextWrite(ctx);
328+
}
329+
return new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
330+
connectHook);
331+
}, postInitializationHook);
323332

324333
this.transport.setExceptionHandler(this.initializer::handleException);
325334
}

mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import io.modelcontextprotocol.spec.McpSchema.Root;
1919
import io.modelcontextprotocol.spec.McpTransport;
2020
import io.modelcontextprotocol.util.Assert;
21+
import org.reactivestreams.Publisher;
2122
import reactor.core.publisher.Mono;
2223

2324
import java.time.Duration;
@@ -195,6 +196,8 @@ class SyncSpec {
195196

196197
private boolean enableCallToolSchemaCaching = false; // Default to false
197198

199+
private Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;
200+
198201
private SyncSpec(McpClientTransport transport) {
199202
Assert.notNull(transport, "Transport must not be null");
200203
this.transport = transport;
@@ -479,6 +482,17 @@ public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching)
479482
return this;
480483
}
481484

485+
/**
486+
* Allows to add a reactive hook to the connection lifecycle. This hook can be
487+
* used to intercept connection events, add retry logic, or handle errors.
488+
* @param connectHook the connection hook.
489+
* @return this builder instance for method chaining
490+
*/
491+
public SyncSpec connectHook(Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
492+
this.connectHook = connectHook;
493+
return this;
494+
}
495+
482496
/**
483497
* Create an instance of {@link McpSyncClient} with the provided configurations or
484498
* sensible defaults.
@@ -488,7 +502,7 @@ public McpSyncClient build() {
488502
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
489503
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
490504
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
491-
this.elicitationHandler, this.enableCallToolSchemaCaching);
505+
this.elicitationHandler, this.enableCallToolSchemaCaching, this.connectHook);
492506

493507
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
494508

@@ -549,6 +563,8 @@ class AsyncSpec {
549563

550564
private boolean enableCallToolSchemaCaching = false; // Default to false
551565

566+
private Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;
567+
552568
private AsyncSpec(McpClientTransport transport) {
553569
Assert.notNull(transport, "Transport must not be null");
554570
this.transport = transport;
@@ -820,6 +836,17 @@ public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching
820836
return this;
821837
}
822838

839+
/**
840+
* Allows to add a reactive hook to the connection lifecycle. This hook can be
841+
* used to intercept connection events, add retry logic, or handle errors.
842+
* @param connectHook the connection hook.
843+
* @return this builder instance for method chaining
844+
*/
845+
public AsyncSpec connectHook(Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
846+
this.connectHook = connectHook;
847+
return this;
848+
}
849+
823850
/**
824851
* Create an instance of {@link McpAsyncClient} with the provided configurations
825852
* or sensible defaults.
@@ -833,7 +860,8 @@ public McpAsyncClient build() {
833860
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
834861
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
835862
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
836-
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching));
863+
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching,
864+
this.connectHook));
837865
}
838866

839867
}

mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import io.modelcontextprotocol.spec.McpSchema;
1616
import io.modelcontextprotocol.util.Assert;
1717
import io.modelcontextprotocol.util.Utils;
18+
import org.reactivestreams.Publisher;
1819
import reactor.core.publisher.Mono;
1920
import reactor.core.scheduler.Schedulers;
2021

@@ -57,12 +58,14 @@ class McpClientFeatures {
5758
* @param roots the roots.
5859
* @param toolsChangeConsumers the tools change consumers.
5960
* @param resourcesChangeConsumers the resources change consumers.
61+
* @param resourcesUpdateConsumers the resources update consumers.
6062
* @param promptsChangeConsumers the prompts change consumers.
6163
* @param loggingConsumers the logging consumers.
6264
* @param progressConsumers the progress consumers.
6365
* @param samplingHandler the sampling handler.
6466
* @param elicitationHandler the elicitation handler.
6567
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
68+
* @param connectHook the connection hook.
6669
*/
6770
record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
6871
Map<String, McpSchema.Root> roots, List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers,
@@ -73,20 +76,23 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
7376
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
7477
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
7578
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
76-
boolean enableCallToolSchemaCaching) {
79+
boolean enableCallToolSchemaCaching, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
7780

7881
/**
7982
* Create an instance and validate the arguments.
83+
* @param clientInfo the client implementation information.
8084
* @param clientCapabilities the client capabilities.
8185
* @param roots the roots.
8286
* @param toolsChangeConsumers the tools change consumers.
8387
* @param resourcesChangeConsumers the resources change consumers.
88+
* @param resourcesUpdateConsumers the resources update consumers.
8489
* @param promptsChangeConsumers the prompts change consumers.
8590
* @param loggingConsumers the logging consumers.
8691
* @param progressConsumers the progress consumers.
8792
* @param samplingHandler the sampling handler.
8893
* @param elicitationHandler the elicitation handler.
8994
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
95+
* @param connectHook the connection hook.
9096
*/
9197
public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
9298
Map<String, McpSchema.Root> roots,
@@ -98,7 +104,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
98104
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
99105
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
100106
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
101-
boolean enableCallToolSchemaCaching) {
107+
boolean enableCallToolSchemaCaching,
108+
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
102109

103110
Assert.notNull(clientInfo, "Client info must not be null");
104111
this.clientInfo = clientInfo;
@@ -118,6 +125,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
118125
this.samplingHandler = samplingHandler;
119126
this.elicitationHandler = elicitationHandler;
120127
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
128+
this.connectHook = connectHook;
121129
}
122130

123131
/**
@@ -134,7 +142,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
134142
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {
135143
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
136144
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
137-
elicitationHandler, false);
145+
elicitationHandler, false, null);
138146
}
139147

140148
/**
@@ -193,7 +201,7 @@ public static Async fromSync(Sync syncSpec) {
193201
return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
194202
toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers,
195203
loggingConsumers, progressConsumers, samplingHandler, elicitationHandler,
196-
syncSpec.enableCallToolSchemaCaching);
204+
syncSpec.enableCallToolSchemaCaching, syncSpec.connectHook());
197205
}
198206
}
199207

@@ -206,12 +214,14 @@ public static Async fromSync(Sync syncSpec) {
206214
* @param roots the roots.
207215
* @param toolsChangeConsumers the tools change consumers.
208216
* @param resourcesChangeConsumers the resources change consumers.
217+
* @param resourcesUpdateConsumers the resources update consumers.
209218
* @param promptsChangeConsumers the prompts change consumers.
210219
* @param loggingConsumers the logging consumers.
211220
* @param progressConsumers the progress consumers.
212221
* @param samplingHandler the sampling handler.
213222
* @param elicitationHandler the elicitation handler.
214223
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
224+
* @param connectHook the connection hook.
215225
*/
216226
public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
217227
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
@@ -222,7 +232,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
222232
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
223233
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
224234
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
225-
boolean enableCallToolSchemaCaching) {
235+
boolean enableCallToolSchemaCaching, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
226236

227237
/**
228238
* Create an instance and validate the arguments.
@@ -238,6 +248,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
238248
* @param samplingHandler the sampling handler.
239249
* @param elicitationHandler the elicitation handler.
240250
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
251+
* @param connectHook the connection hook.
241252
*/
242253
public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
243254
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
@@ -248,7 +259,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
248259
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
249260
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
250261
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
251-
boolean enableCallToolSchemaCaching) {
262+
boolean enableCallToolSchemaCaching,
263+
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
252264

253265
Assert.notNull(clientInfo, "Client info must not be null");
254266
this.clientInfo = clientInfo;
@@ -268,6 +280,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
268280
this.samplingHandler = samplingHandler;
269281
this.elicitationHandler = elicitationHandler;
270282
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
283+
this.connectHook = connectHook;
271284
}
272285

273286
/**
@@ -283,7 +296,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
283296
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {
284297
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
285298
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
286-
elicitationHandler, false);
299+
elicitationHandler, false, null);
287300
}
288301
}
289302

mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,15 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
119119
this.requestHandlers.putAll(requestHandlers);
120120
this.notificationHandlers.putAll(notificationHandlers);
121121

122-
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
122+
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(v -> {
123+
}, error -> {
124+
logger.error("MCP session connection error", error);
125+
this.pendingResponses.forEach((id, sink) -> {
126+
logger.warn("Terminating exchange for request {} due to connection error", id);
127+
sink.error(new RuntimeException("MCP session connection error", error));
128+
});
129+
this.pendingResponses.clear();
130+
});
123131
}
124132

125133
private void dismissPendingResponses() {
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import java.time.Duration;
4+
import java.util.concurrent.atomic.AtomicReference;
5+
6+
import io.modelcontextprotocol.client.transport.ServerParameters;
7+
import io.modelcontextprotocol.client.transport.StdioClientTransport;
8+
import io.modelcontextprotocol.json.McpJsonDefaults;
9+
import io.modelcontextprotocol.json.McpJsonMapper;
10+
import org.junit.jupiter.api.Test;
11+
12+
import static org.assertj.core.api.Assertions.assertThat;
13+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
14+
15+
public class McpClientInitializationTests {
16+
17+
private static final McpJsonMapper JSON_MAPPER = McpJsonDefaults.getMapper();
18+
19+
@Test
20+
void reproduceInitializeErrorShouldNotBeDropped() {
21+
ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build();
22+
StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER);
23+
24+
McpSyncClient client = McpClient.sync(transport).requestTimeout(Duration.ofSeconds(2)).build();
25+
26+
assertThatExceptionOfType(RuntimeException.class).isThrownBy(client::initialize)
27+
.withMessageContaining("Client failed to initialize")
28+
.satisfies(ex -> {
29+
assertThat(ex.getCause().getMessage()).contains("MCP session connection error");
30+
assertThat(ex.getCause().getCause().getMessage()).contains("Failed to start process");
31+
});
32+
}
33+
34+
@Test
35+
void verifyConnectHook() {
36+
ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build();
37+
StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER);
38+
39+
AtomicReference<Throwable> hookError = new AtomicReference<>();
40+
41+
McpSyncClient client = McpClient.sync(transport)
42+
.requestTimeout(Duration.ofSeconds(2))
43+
.connectHook(mono -> mono.doOnError(hookError::set))
44+
.build();
45+
46+
try {
47+
client.initialize();
48+
}
49+
catch (Exception e) {
50+
// ignore
51+
}
52+
53+
assertThat(hookError.get()).isNotNull();
54+
assertThat(hookError.get().getMessage()).contains("Failed to start process");
55+
}
56+
57+
}

0 commit comments

Comments
 (0)