Skip to content

Commit 59f174b

Browse files
committed
Per-conversation MCP sessions
1 parent 4758698 commit 59f174b

5 files changed

Lines changed: 178 additions & 23 deletions

File tree

ui/src/components/ConversationsProvider/ConversationsProvider.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {
1919
import type { ConversationWithProject, Message } from "@/api/generated/types.gen";
2020
import { useAuth } from "@/auth";
2121
import { useIndexedDB } from "@/hooks/useIndexedDB";
22+
import { useMCPStore } from "@/stores/mcpStore";
2223

2324
import type { ChatMessage, Conversation } from "@/components/chat-types";
2425
import { usePreferences } from "@/preferences/PreferencesProvider";
@@ -579,6 +580,9 @@ export function ConversationsProvider({ children }: ConversationsProviderProps)
579580

580581
const deleteConversation = useCallback(
581582
(id: string) => {
583+
// Clean up any per-conversation MCP sessions
584+
useMCPStore.getState().disconnectConversation(id);
585+
582586
// Find the conversation to get remoteId before deleting
583587
const conv = storedConversations.find((c) => c.id === id);
584588
if (conv?.remoteId && canSync) {

ui/src/pages/chat/ChatPage.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ export default function ChatPage() {
170170
captureRawSSEEvents,
171171
subAgentModel,
172172
projectId: currentConversation?.projectId ?? pendingProject.id ?? undefined,
173+
conversationId,
173174
});
174175

175176
const { moveToProject } = useConversationsContext();

ui/src/pages/chat/useChat.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ interface UseChatOptions {
130130
subAgentModel?: string | null;
131131
/** Project ID for usage attribution (sent as X-Hadrian-Project header) */
132132
projectId?: string;
133+
/** Conversation ID for per-conversation MCP sessions */
134+
conversationId?: string;
133135
}
134136

135137
/**
@@ -255,6 +257,7 @@ export function useChat({
255257
captureRawSSEEvents = false,
256258
subAgentModel,
257259
projectId,
260+
conversationId,
258261
}: UseChatOptions): UseChatReturn {
259262
const { token } = useAuth();
260263
const abortControllersRef = useRef<AbortController[]>([]);
@@ -267,6 +270,8 @@ export function useChat({
267270
// and used as a ref to ensure the latest value is available at fetch time.
268271
const projectIdRef = useRef(projectId);
269272
projectIdRef.current = projectId;
273+
const conversationIdRef = useRef(conversationId);
274+
conversationIdRef.current = conversationId;
270275
const streamingStore = useStreamingStore();
271276
const debugStore = useDebugStore();
272277
const modelResponses = useAllStreams();
@@ -1482,6 +1487,7 @@ export function useChat({
14821487
},
14831488
// Use configured sub-agent model, fall back to current streaming model
14841489
defaultModel: subAgentModel || model,
1490+
conversationId: conversationIdRef.current,
14851491
};
14861492

14871493
const toolResults = await executeToolCalls(result.toolCalls, toolContext);

ui/src/pages/chat/utils/toolExecutors.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ export interface ToolExecutorContext {
6161
onStatusMessage?: (toolCallId: string, message: string) => void;
6262
/** Default model for sub-agent tool when no model is specified in arguments */
6363
defaultModel?: string;
64+
/** Conversation ID for per-conversation MCP sessions */
65+
conversationId?: string;
6466
}
6567

6668
/**
@@ -2423,7 +2425,7 @@ const mcpToolExecutor: ToolExecutor = async (toolCall, context) => {
24232425
}
24242426

24252427
try {
2426-
const result = await callMCPTool(serverId, toolName, args);
2428+
const result = await callMCPTool(serverId, toolName, args, context.conversationId);
24272429

24282430
// Clear status message
24292431
context.onStatusMessage?.(toolCall.id, "");

ui/src/stores/mcpStore.ts

Lines changed: 164 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ interface MCPActions {
7676
ensureConnected: () => Promise<void>;
7777
/** Disconnect all servers */
7878
disconnectAll: () => void;
79+
/** Disconnect all per-conversation MCP sessions for a conversation */
80+
disconnectConversation: (conversationId: string) => void;
7981
}
8082

8183
export type MCPStore = MCPState & MCPActions;
@@ -84,43 +86,130 @@ export type MCPStore = MCPState & MCPActions;
8486
// Client Management (outside Zustand for reference stability)
8587
// =============================================================================
8688

87-
/** Map of server ID to MCPClient instance */
88-
const clients = new Map<string, MCPClient>();
89+
/** Global clients: Map of server ID to MCPClient instance (for discovery/status) */
90+
const globalClients = new Map<string, MCPClient>();
8991

90-
/** Map of server ID to cleanup functions for listeners */
91-
const listenerCleanups = new Map<string, Array<() => void>>();
92+
/** Global listener cleanups: Map of server ID to cleanup functions */
93+
const globalListenerCleanups = new Map<string, Array<() => void>>();
9294

93-
/** Get or create a client for a server */
95+
/** Per-conversation clients: Map of `${serverId}::${conversationId}` to MCPClient */
96+
const conversationClients = new Map<string, MCPClient>();
97+
98+
/** Dedup map to prevent concurrent lazy-connect calls for the same key */
99+
const connectingPromises = new Map<string, Promise<MCPClient>>();
100+
101+
/** Composite key for conversation clients */
102+
function clientKey(serverId: string, conversationId: string): string {
103+
return `${serverId}::${conversationId}`;
104+
}
105+
106+
/** Get or create a global client for a server */
94107
function getClient(server: MCPServerConfig): MCPClient {
95-
let client = clients.get(server.id);
108+
let client = globalClients.get(server.id);
109+
if (!client) {
110+
client = new MCPClient({
111+
url: server.url,
112+
name: server.name,
113+
headers: server.headers,
114+
timeout: server.timeout,
115+
});
116+
globalClients.set(server.id, client);
117+
}
118+
return client;
119+
}
120+
121+
/** Create a new client for a per-conversation session */
122+
function getConversationClient(server: MCPServerConfig, conversationId: string): MCPClient {
123+
const key = clientKey(server.id, conversationId);
124+
let client = conversationClients.get(key);
96125
if (!client) {
97126
client = new MCPClient({
98127
url: server.url,
99128
name: server.name,
100129
headers: server.headers,
101130
timeout: server.timeout,
102131
});
103-
clients.set(server.id, client);
132+
conversationClients.set(key, client);
104133
}
105134
return client;
106135
}
107136

137+
/** Get or connect a per-conversation client, deduplicating concurrent calls */
138+
async function ensureConversationClient(
139+
serverId: string,
140+
conversationId: string
141+
): Promise<MCPClient> {
142+
const key = clientKey(serverId, conversationId);
143+
const existing = conversationClients.get(key);
144+
if (existing?.isConnected()) return existing;
145+
146+
const pending = connectingPromises.get(key);
147+
if (pending) return pending;
148+
149+
const promise = (async () => {
150+
const server = useMCPStore.getState().servers.find((s) => s.id === serverId);
151+
if (!server) throw new Error(`Server not found: ${serverId}`);
152+
const client = getConversationClient(server, conversationId);
153+
await client.connect();
154+
return client;
155+
})();
156+
157+
connectingPromises.set(key, promise);
158+
try {
159+
return await promise;
160+
} finally {
161+
connectingPromises.delete(key);
162+
}
163+
}
164+
165+
/** Disconnect and remove a single per-conversation client */
166+
function removeConversationClient(serverId: string, conversationId: string): void {
167+
const key = clientKey(serverId, conversationId);
168+
const client = conversationClients.get(key);
169+
if (client) {
170+
client.disconnect();
171+
conversationClients.delete(key);
172+
}
173+
}
174+
175+
/** Disconnect all per-conversation clients for a given conversation */
176+
function removeAllClientsForConversation(conversationId: string): void {
177+
const suffix = `::${conversationId}`;
178+
for (const [key, client] of conversationClients) {
179+
if (key.endsWith(suffix)) {
180+
client.disconnect();
181+
conversationClients.delete(key);
182+
}
183+
}
184+
}
185+
186+
/** Disconnect all per-conversation clients for a given server */
187+
function removeAllConversationClientsForServer(serverId: string): void {
188+
const prefix = `${serverId}::`;
189+
for (const [key, client] of conversationClients) {
190+
if (key.startsWith(prefix)) {
191+
client.disconnect();
192+
conversationClients.delete(key);
193+
}
194+
}
195+
}
196+
108197
/** Remove listener subscriptions for a server */
109198
function cleanupListeners(serverId: string): void {
110-
const cleanups = listenerCleanups.get(serverId);
199+
const cleanups = globalListenerCleanups.get(serverId);
111200
if (cleanups) {
112201
cleanups.forEach((fn) => fn());
113-
listenerCleanups.delete(serverId);
202+
globalListenerCleanups.delete(serverId);
114203
}
115204
}
116205

117-
/** Remove and disconnect a client */
206+
/** Remove and disconnect a global client */
118207
function removeClient(serverId: string): void {
119208
cleanupListeners(serverId);
120-
const client = clients.get(serverId);
209+
const client = globalClients.get(serverId);
121210
if (client) {
122211
client.disconnect();
123-
clients.delete(serverId);
212+
globalClients.delete(serverId);
124213
}
125214
}
126215

@@ -172,6 +261,7 @@ export const useMCPStore = create<MCPStore>()(
172261

173262
removeServer: (serverId) => {
174263
removeClient(serverId);
264+
removeAllConversationClientsForServer(serverId);
175265
set((state) => ({
176266
servers: state.servers.filter((s) => s.id !== serverId),
177267
}));
@@ -182,6 +272,7 @@ export const useMCPStore = create<MCPStore>()(
182272
const server = get().servers.find((s) => s.id === serverId);
183273
if (server && (updates.url || updates.headers || updates.timeout)) {
184274
removeClient(serverId);
275+
removeAllConversationClientsForServer(serverId);
185276
}
186277

187278
set((state) => ({
@@ -217,7 +308,7 @@ export const useMCPStore = create<MCPStore>()(
217308
// are automatically re-discovered
218309
cleanups.push(
219310
client.onNotification(async (method) => {
220-
const c = clients.get(serverId);
311+
const c = globalClients.get(serverId);
221312
if (!c?.isConnected()) return;
222313

223314
try {
@@ -232,7 +323,7 @@ export const useMCPStore = create<MCPStore>()(
232323
})
233324
);
234325

235-
listenerCleanups.set(serverId, cleanups);
326+
globalListenerCleanups.set(serverId, cleanups);
236327

237328
// Connect to server
238329
await client.connect();
@@ -288,9 +379,10 @@ export const useMCPStore = create<MCPStore>()(
288379

289380
const newEnabled = !server.enabled;
290381

291-
// If disabling, disconnect
382+
// If disabling, disconnect global + all conversation clients
292383
if (!newEnabled && server.status === "connected") {
293384
get().disconnectServer(serverId);
385+
removeAllConversationClientsForServer(serverId);
294386
}
295387

296388
set((state) => ({
@@ -377,6 +469,11 @@ export const useMCPStore = create<MCPStore>()(
377469
removeClient(server.id);
378470
}
379471
}
472+
// Also disconnect all per-conversation clients
473+
for (const [, client] of conversationClients) {
474+
client.disconnect();
475+
}
476+
conversationClients.clear();
380477
set((state) => ({
381478
servers: state.servers.map((s) => ({
382479
...s,
@@ -390,6 +487,10 @@ export const useMCPStore = create<MCPStore>()(
390487
})),
391488
}));
392489
},
490+
491+
disconnectConversation: (conversationId) => {
492+
removeAllClientsForConversation(conversationId);
493+
},
393494
}),
394495
{
395496
name: "hadrian-mcp-servers",
@@ -488,18 +589,38 @@ export const useMCPErrors = () =>
488589
// Utility Functions
489590
// =============================================================================
490591

491-
/** Get MCPClient instance for a server (for making tool calls) */
492-
export function getMCPClient(serverId: string): MCPClient | undefined {
493-
return clients.get(serverId);
592+
/** Get MCPClient instance for a server (global or per-conversation) */
593+
export function getMCPClient(serverId: string, conversationId?: string): MCPClient | undefined {
594+
if (conversationId) {
595+
return conversationClients.get(clientKey(serverId, conversationId));
596+
}
597+
return globalClients.get(serverId);
494598
}
495599

496-
/** Call a tool on an MCP server, auto-reconnecting on session expiry */
600+
/**
601+
* Call a tool on an MCP server, auto-reconnecting on session expiry.
602+
* If conversationId is provided, uses a per-conversation session (lazily created).
603+
* Otherwise falls back to the global client.
604+
*/
497605
export async function callMCPTool(
606+
serverId: string,
607+
toolName: string,
608+
args?: Record<string, unknown>,
609+
conversationId?: string
610+
) {
611+
if (conversationId) {
612+
return callMCPToolWithConversationClient(serverId, toolName, args, conversationId);
613+
}
614+
return callMCPToolWithGlobalClient(serverId, toolName, args);
615+
}
616+
617+
/** Call a tool using the global client (existing behavior) */
618+
async function callMCPToolWithGlobalClient(
498619
serverId: string,
499620
toolName: string,
500621
args?: Record<string, unknown>
501622
) {
502-
const client = clients.get(serverId);
623+
const client = globalClients.get(serverId);
503624
if (!client) {
504625
throw new Error(`No client for server: ${serverId}`);
505626
}
@@ -515,8 +636,7 @@ export async function callMCPTool(
515636
console.debug("MCP session expired during tool call, reconnecting…");
516637
const store = useMCPStore.getState();
517638
await store.connectServer(serverId);
518-
// Retry with the (possibly new) client
519-
const newClient = clients.get(serverId);
639+
const newClient = globalClients.get(serverId);
520640
if (!newClient?.isConnected()) {
521641
throw new Error(`Reconnection failed for server: ${serverId}`);
522642
}
@@ -525,3 +645,25 @@ export async function callMCPTool(
525645
throw err;
526646
}
527647
}
648+
649+
/** Call a tool using a per-conversation client (lazy init, auto-reconnect) */
650+
async function callMCPToolWithConversationClient(
651+
serverId: string,
652+
toolName: string,
653+
args: Record<string, unknown> | undefined,
654+
conversationId: string
655+
) {
656+
let client = await ensureConversationClient(serverId, conversationId);
657+
658+
try {
659+
return await client.callTool(toolName, args);
660+
} catch (err) {
661+
if (err instanceof Error && err.message.includes("session expired")) {
662+
console.debug("MCP conversation session expired, reconnecting…");
663+
removeConversationClient(serverId, conversationId);
664+
client = await ensureConversationClient(serverId, conversationId);
665+
return client.callTool(toolName, args);
666+
}
667+
throw err;
668+
}
669+
}

0 commit comments

Comments
 (0)