Skip to content

Commit 4a6e2fe

Browse files
temporal-spring-ai: discover MCP clients by type, not by bean name
McpPlugin.getMcpClients() now calls ApplicationContext.getBeansOfType( McpSyncClient.class) instead of looking up the hard-coded bean name "mcpSyncClients". The unchecked cast is removed; discovered bean names are logged for easier debugging. Tests: McpPluginTest exercises three cases via Mockito — - two McpSyncClient beans discovered by type both reach McpClientActivityImpl; - zero beans leaves the worker queued for deferred registration, and afterSingletonsInstantiated handles "still no beans" cleanly; - beans that appear between initializeWorker and afterSingletonsInstantiated get picked up on the deferred attempt. build.gradle: spring-ai-mcp added as a testImplementation so tests can mock McpSyncClient directly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7107046 commit 4a6e2fe

3 files changed

Lines changed: 113 additions & 10 deletions

File tree

temporal-spring-ai/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ dependencies {
4646
testImplementation "org.mockito:mockito-core:${mockitoVersion}"
4747
testImplementation 'org.springframework.boot:spring-boot-starter-test'
4848
testImplementation 'org.springframework.ai:spring-ai-rag'
49+
// Needed only so McpPluginTest can mock/reference McpSyncClient directly.
50+
testImplementation 'org.springframework.ai:spring-ai-mcp'
4951

5052
testRuntimeOnly group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}"
5153
testRuntimeOnly "org.junit.platform:junit-platform-launcher"

temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.temporal.worker.Worker;
77
import java.util.ArrayList;
88
import java.util.List;
9+
import java.util.Map;
910
import javax.annotation.Nonnull;
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
@@ -39,22 +40,22 @@ public void setApplicationContext(ApplicationContext applicationContext) throws
3940
this.applicationContext = applicationContext;
4041
}
4142

42-
@SuppressWarnings("unchecked")
4343
private List<McpSyncClient> getMcpClients() {
4444
if (!mcpClients.isEmpty()) {
4545
return mcpClients;
4646
}
47+
if (applicationContext == null) {
48+
return mcpClients;
49+
}
4750

48-
if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) {
49-
try {
50-
Object bean = applicationContext.getBean("mcpSyncClients");
51-
if (bean instanceof List<?> clientList && !clientList.isEmpty()) {
52-
mcpClients = (List<McpSyncClient>) clientList;
53-
log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size());
54-
}
55-
} catch (Exception e) {
56-
log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage());
51+
try {
52+
Map<String, McpSyncClient> beans = applicationContext.getBeansOfType(McpSyncClient.class);
53+
if (!beans.isEmpty()) {
54+
mcpClients = List.copyOf(beans.values());
55+
log.info("Discovered {} MCP client bean(s): {}", beans.size(), beans.keySet());
5756
}
57+
} catch (Exception e) {
58+
log.debug("Failed to look up McpSyncClient beans: {}", e.getMessage());
5859
}
5960

6061
return mcpClients;
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package io.temporal.springai.plugin;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.mockito.ArgumentMatchers.any;
5+
import static org.mockito.Mockito.atLeastOnce;
6+
import static org.mockito.Mockito.mock;
7+
import static org.mockito.Mockito.verify;
8+
import static org.mockito.Mockito.verifyNoInteractions;
9+
import static org.mockito.Mockito.when;
10+
11+
import io.modelcontextprotocol.client.McpSyncClient;
12+
import io.modelcontextprotocol.spec.McpSchema;
13+
import io.temporal.springai.mcp.McpClientActivityImpl;
14+
import io.temporal.worker.Worker;
15+
import java.util.LinkedHashMap;
16+
import java.util.Map;
17+
import org.junit.jupiter.api.Test;
18+
import org.mockito.ArgumentCaptor;
19+
import org.springframework.context.ApplicationContext;
20+
21+
class McpPluginTest {
22+
23+
@Test
24+
void discoversMcpClientBeansByType() {
25+
McpSyncClient clientA = mockClientNamed("alpha");
26+
McpSyncClient clientB = mockClientNamed("beta");
27+
28+
// Spring's getBeansOfType keeps insertion order via LinkedHashMap; use that for determinism.
29+
Map<String, McpSyncClient> beans = new LinkedHashMap<>();
30+
beans.put("mcpClientAlpha", clientA);
31+
beans.put("mcpClientBeta", clientB);
32+
33+
ApplicationContext ctx = mock(ApplicationContext.class);
34+
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans);
35+
36+
McpPlugin plugin = new McpPlugin();
37+
plugin.setApplicationContext(ctx);
38+
39+
Worker worker = mock(Worker.class);
40+
plugin.initializeWorker("mcp-tq", worker);
41+
42+
ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
43+
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
44+
Object registered = captor.getValue();
45+
assertEquals(McpClientActivityImpl.class, registered.getClass());
46+
47+
// Duplicate-name protection in McpClientActivityImpl still fires if two clients share a
48+
// clientInfo().name(); here they differ ("alpha" vs "beta") so construction succeeds.
49+
}
50+
51+
@Test
52+
void noMcpBeans_defersWorker_thenClearsAfterSingletonsInstantiated() {
53+
ApplicationContext ctx = mock(ApplicationContext.class);
54+
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(Map.of());
55+
56+
McpPlugin plugin = new McpPlugin();
57+
plugin.setApplicationContext(ctx);
58+
59+
Worker worker = mock(Worker.class);
60+
plugin.initializeWorker("mcp-tq", worker);
61+
62+
// No beans → nothing registered yet, worker queued for deferred attempt.
63+
verifyNoInteractions(worker);
64+
65+
plugin.afterSingletonsInstantiated();
66+
67+
// Still no beans — the deferred attempt also finds nothing and doesn't crash.
68+
verify(worker, org.mockito.Mockito.never()).registerActivitiesImplementations((Object[]) any());
69+
}
70+
71+
@Test
72+
void beansAppearLate_registeredViaAfterSingletonsInstantiated() {
73+
ApplicationContext ctx = mock(ApplicationContext.class);
74+
// First lookup returns empty (Spring AI MCP bean hasn't been created yet when
75+
// initializeWorker runs).
76+
when(ctx.getBeansOfType(McpSyncClient.class))
77+
.thenReturn(Map.of())
78+
.thenReturn(Map.of("mcpClient", mockClientNamed("late")));
79+
80+
McpPlugin plugin = new McpPlugin();
81+
plugin.setApplicationContext(ctx);
82+
83+
Worker worker = mock(Worker.class);
84+
plugin.initializeWorker("mcp-tq", worker);
85+
verifyNoInteractions(worker);
86+
87+
plugin.afterSingletonsInstantiated();
88+
89+
ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
90+
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
91+
assertEquals(McpClientActivityImpl.class, captor.getValue().getClass());
92+
}
93+
94+
private static McpSyncClient mockClientNamed(String name) {
95+
McpSyncClient client = mock(McpSyncClient.class);
96+
McpSchema.Implementation info = new McpSchema.Implementation(name, "1.0.0");
97+
when(client.getClientInfo()).thenReturn(info);
98+
return client;
99+
}
100+
}

0 commit comments

Comments
 (0)