Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
import type { ConversationWithProject, Message } from "@/api/generated/types.gen";
import { useAuth } from "@/auth";
import { useIndexedDB } from "@/hooks/useIndexedDB";
import { useMCPStore } from "@/stores/mcpStore";

import type { ChatMessage, Conversation } from "@/components/chat-types";
import { usePreferences } from "@/preferences/PreferencesProvider";
Expand Down Expand Up @@ -579,6 +580,9 @@ export function ConversationsProvider({ children }: ConversationsProviderProps)

const deleteConversation = useCallback(
(id: string) => {
// Clean up any per-conversation MCP sessions
useMCPStore.getState().disconnectConversation(id);

// Find the conversation to get remoteId before deleting
const conv = storedConversations.find((c) => c.id === id);
if (conv?.remoteId && canSync) {
Expand Down
1 change: 1 addition & 0 deletions ui/src/pages/chat/ChatPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export default function ChatPage() {
captureRawSSEEvents,
subAgentModel,
projectId: currentConversation?.projectId ?? pendingProject.id ?? undefined,
conversationId,
});

const { moveToProject } = useConversationsContext();
Expand Down
6 changes: 6 additions & 0 deletions ui/src/pages/chat/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ interface UseChatOptions {
subAgentModel?: string | null;
/** Project ID for usage attribution (sent as X-Hadrian-Project header) */
projectId?: string;
/** Conversation ID for per-conversation MCP sessions */
conversationId?: string;
}

/**
Expand Down Expand Up @@ -255,6 +257,7 @@ export function useChat({
captureRawSSEEvents = false,
subAgentModel,
projectId,
conversationId,
}: UseChatOptions): UseChatReturn {
const { token } = useAuth();
const abortControllersRef = useRef<AbortController[]>([]);
Expand All @@ -267,6 +270,8 @@ export function useChat({
// and used as a ref to ensure the latest value is available at fetch time.
const projectIdRef = useRef(projectId);
projectIdRef.current = projectId;
const conversationIdRef = useRef(conversationId);
conversationIdRef.current = conversationId;
const streamingStore = useStreamingStore();
const debugStore = useDebugStore();
const modelResponses = useAllStreams();
Expand Down Expand Up @@ -1482,6 +1487,7 @@ export function useChat({
},
// Use configured sub-agent model, fall back to current streaming model
defaultModel: subAgentModel || model,
conversationId: conversationIdRef.current,
};

const toolResults = await executeToolCalls(result.toolCalls, toolContext);
Expand Down
4 changes: 3 additions & 1 deletion ui/src/pages/chat/utils/toolExecutors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ export interface ToolExecutorContext {
onStatusMessage?: (toolCallId: string, message: string) => void;
/** Default model for sub-agent tool when no model is specified in arguments */
defaultModel?: string;
/** Conversation ID for per-conversation MCP sessions */
conversationId?: string;
}

/**
Expand Down Expand Up @@ -2423,7 +2425,7 @@ const mcpToolExecutor: ToolExecutor = async (toolCall, context) => {
}

try {
const result = await callMCPTool(serverId, toolName, args);
const result = await callMCPTool(serverId, toolName, args, context.conversationId);

// Clear status message
context.onStatusMessage?.(toolCall.id, "");
Expand Down
199 changes: 177 additions & 22 deletions ui/src/stores/mcpStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ interface MCPActions {
ensureConnected: () => Promise<void>;
/** Disconnect all servers */
disconnectAll: () => void;
/** Disconnect all per-conversation MCP sessions for a conversation */
disconnectConversation: (conversationId: string) => void;
}

export type MCPStore = MCPState & MCPActions;
Expand All @@ -84,43 +86,143 @@ export type MCPStore = MCPState & MCPActions;
// Client Management (outside Zustand for reference stability)
// =============================================================================

/** Map of server ID to MCPClient instance */
const clients = new Map<string, MCPClient>();
/** Global clients: Map of server ID to MCPClient instance (for discovery/status) */
const globalClients = new Map<string, MCPClient>();

/** Map of server ID to cleanup functions for listeners */
const listenerCleanups = new Map<string, Array<() => void>>();
/** Global listener cleanups: Map of server ID to cleanup functions */
const globalListenerCleanups = new Map<string, Array<() => void>>();

/** Get or create a client for a server */
/** Per-conversation clients: Map of `${serverId}::${conversationId}` to MCPClient */
const conversationClients = new Map<string, MCPClient>();

/** Dedup map to prevent concurrent lazy-connect calls for the same key */
const connectingPromises = new Map<string, Promise<MCPClient>>();

/** Composite key for conversation clients */
function clientKey(serverId: string, conversationId: string): string {
return `${serverId}::${conversationId}`;
}

/** Get or create a global client for a server */
function getClient(server: MCPServerConfig): MCPClient {
let client = clients.get(server.id);
let client = globalClients.get(server.id);
if (!client) {
client = new MCPClient({
url: server.url,
name: server.name,
headers: server.headers,
timeout: server.timeout,
});
globalClients.set(server.id, client);
}
return client;
}

/** Get or create a client for a per-conversation session */
function getConversationClient(server: MCPServerConfig, conversationId: string): MCPClient {
const key = clientKey(server.id, conversationId);
let client = conversationClients.get(key);
if (!client) {
client = new MCPClient({
url: server.url,
name: server.name,
headers: server.headers,
timeout: server.timeout,
});
clients.set(server.id, client);
conversationClients.set(key, client);
}
return client;
}

/** Get or connect a per-conversation client, deduplicating concurrent calls */
async function ensureConversationClient(
serverId: string,
conversationId: string
): Promise<MCPClient> {
const key = clientKey(serverId, conversationId);
const existing = conversationClients.get(key);
if (existing?.isConnected()) return existing;

const pending = connectingPromises.get(key);
if (pending) return pending;

const promise = (async () => {
const server = useMCPStore.getState().servers.find((s) => s.id === serverId);
if (!server) throw new Error(`Server not found: ${serverId}`);
const client = getConversationClient(server, conversationId);
await client.connect();
return client;
})();

connectingPromises.set(key, promise);
try {
return await promise;
} finally {
connectingPromises.delete(key);
}
}

/** Disconnect and remove a single per-conversation client */
function removeConversationClient(serverId: string, conversationId: string): void {
const key = clientKey(serverId, conversationId);
const client = conversationClients.get(key);
if (client) {
client.disconnect();
conversationClients.delete(key);
}
connectingPromises.delete(key);
}

/** Disconnect all per-conversation clients for a given conversation */
function removeAllClientsForConversation(conversationId: string): void {
const suffix = `::${conversationId}`;
// Deleting from a Map during iteration is safe in ES6+
for (const [key, client] of conversationClients) {
if (key.endsWith(suffix)) {
client.disconnect();
conversationClients.delete(key);
}
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.
for (const key of connectingPromises.keys()) {
if (key.endsWith(suffix)) {
connectingPromises.delete(key);
}
}
}

/** Disconnect all per-conversation clients for a given server */
function removeAllConversationClientsForServer(serverId: string): void {
const prefix = `${serverId}::`;
// Deleting from a Map during iteration is safe in ES6+
for (const [key, client] of conversationClients) {
if (key.startsWith(prefix)) {
client.disconnect();
conversationClients.delete(key);
}
}
for (const key of connectingPromises.keys()) {
if (key.startsWith(prefix)) {
connectingPromises.delete(key);
}
}
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.

/** Remove listener subscriptions for a server */
function cleanupListeners(serverId: string): void {
const cleanups = listenerCleanups.get(serverId);
const cleanups = globalListenerCleanups.get(serverId);
if (cleanups) {
cleanups.forEach((fn) => fn());
listenerCleanups.delete(serverId);
globalListenerCleanups.delete(serverId);
}
}

/** Remove and disconnect a client */
/** Remove and disconnect a global client */
function removeClient(serverId: string): void {
cleanupListeners(serverId);
const client = clients.get(serverId);
const client = globalClients.get(serverId);
if (client) {
client.disconnect();
clients.delete(serverId);
globalClients.delete(serverId);
}
}

Expand Down Expand Up @@ -172,6 +274,7 @@ export const useMCPStore = create<MCPStore>()(

removeServer: (serverId) => {
removeClient(serverId);
removeAllConversationClientsForServer(serverId);
set((state) => ({
servers: state.servers.filter((s) => s.id !== serverId),
}));
Expand All @@ -182,6 +285,7 @@ export const useMCPStore = create<MCPStore>()(
const server = get().servers.find((s) => s.id === serverId);
if (server && (updates.url || updates.headers || updates.timeout)) {
removeClient(serverId);
removeAllConversationClientsForServer(serverId);
}

set((state) => ({
Expand Down Expand Up @@ -217,7 +321,7 @@ export const useMCPStore = create<MCPStore>()(
// are automatically re-discovered
cleanups.push(
client.onNotification(async (method) => {
const c = clients.get(serverId);
const c = globalClients.get(serverId);
if (!c?.isConnected()) return;

try {
Expand All @@ -232,7 +336,7 @@ export const useMCPStore = create<MCPStore>()(
})
);

listenerCleanups.set(serverId, cleanups);
globalListenerCleanups.set(serverId, cleanups);

// Connect to server
await client.connect();
Expand Down Expand Up @@ -288,9 +392,10 @@ export const useMCPStore = create<MCPStore>()(

const newEnabled = !server.enabled;

// If disabling, disconnect
// If disabling, disconnect global + all conversation clients
if (!newEnabled && server.status === "connected") {
get().disconnectServer(serverId);
removeAllConversationClientsForServer(serverId);
}

set((state) => ({
Expand Down Expand Up @@ -377,6 +482,11 @@ export const useMCPStore = create<MCPStore>()(
removeClient(server.id);
}
}
// Also disconnect all per-conversation clients
for (const [, client] of conversationClients) {
client.disconnect();
}
conversationClients.clear();
Comment thread
greptile-apps[bot] marked this conversation as resolved.
set((state) => ({
servers: state.servers.map((s) => ({
...s,
Expand All @@ -390,6 +500,10 @@ export const useMCPStore = create<MCPStore>()(
})),
}));
},

disconnectConversation: (conversationId) => {
removeAllClientsForConversation(conversationId);
},
}),
{
name: "hadrian-mcp-servers",
Expand Down Expand Up @@ -488,18 +602,38 @@ export const useMCPErrors = () =>
// Utility Functions
// =============================================================================

/** Get MCPClient instance for a server (for making tool calls) */
export function getMCPClient(serverId: string): MCPClient | undefined {
return clients.get(serverId);
/** Get MCPClient instance for a server (global or per-conversation) */
export function getMCPClient(serverId: string, conversationId?: string): MCPClient | undefined {
if (conversationId) {
return conversationClients.get(clientKey(serverId, conversationId));
}
return globalClients.get(serverId);
}

/** Call a tool on an MCP server, auto-reconnecting on session expiry */
/**
* Call a tool on an MCP server, auto-reconnecting on session expiry.
* If conversationId is provided, uses a per-conversation session (lazily created).
* Otherwise falls back to the global client.
*/
export async function callMCPTool(
serverId: string,
toolName: string,
args?: Record<string, unknown>,
conversationId?: string
) {
if (conversationId) {
return callMCPToolWithConversationClient(serverId, toolName, args, conversationId);
}
return callMCPToolWithGlobalClient(serverId, toolName, args);
}

/** Call a tool using the global client (existing behavior) */
async function callMCPToolWithGlobalClient(
serverId: string,
toolName: string,
args?: Record<string, unknown>
) {
const client = clients.get(serverId);
const client = globalClients.get(serverId);
if (!client) {
throw new Error(`No client for server: ${serverId}`);
}
Expand All @@ -515,8 +649,7 @@ export async function callMCPTool(
console.debug("MCP session expired during tool call, reconnecting…");
const store = useMCPStore.getState();
await store.connectServer(serverId);
// Retry with the (possibly new) client
const newClient = clients.get(serverId);
const newClient = globalClients.get(serverId);
if (!newClient?.isConnected()) {
throw new Error(`Reconnection failed for server: ${serverId}`);
}
Expand All @@ -525,3 +658,25 @@ export async function callMCPTool(
throw err;
}
}

/** Call a tool using a per-conversation client (lazy init, auto-reconnect) */
async function callMCPToolWithConversationClient(
serverId: string,
toolName: string,
args: Record<string, unknown> | undefined,
conversationId: string
) {
let client = await ensureConversationClient(serverId, conversationId);

try {
return await client.callTool(toolName, args);
} catch (err) {
if (err instanceof Error && err.message.includes("session expired")) {
console.debug("MCP conversation session expired, reconnecting…");
removeConversationClient(serverId, conversationId);
client = await ensureConversationClient(serverId, conversationId);
return client.callTool(toolName, args);
}
throw err;
}
}
Loading