diff --git a/ui/src/components/ChatView/ChatView.tsx b/ui/src/components/ChatView/ChatView.tsx index e40a662..3a81e04 100644 --- a/ui/src/components/ChatView/ChatView.tsx +++ b/ui/src/components/ChatView/ChatView.tsx @@ -4,7 +4,7 @@ import { ChatHeader } from "@/components/ChatHeader/ChatHeader"; import { ChatInput } from "@/components/ChatInput/ChatInput"; import { ChatMessageList } from "@/components/ChatMessageList/ChatMessageList"; import { ConversationSettingsModal } from "@/components/ConversationSettingsModal/ConversationSettingsModal"; -import { MCPConfigModal } from "@/components/MCPConfigModal"; +import { MCPConfigModal, type MCPServerPrefill } from "@/components/MCPConfigModal"; import type { ModelInfo } from "@/components/ModelSelector/ModelSelector"; import { useChatUIStore, @@ -30,7 +30,7 @@ import { useTotalUsage, useCurrentConversationForExport, } from "@/stores/conversationStore"; -import { useMemo, useCallback } from "react"; +import { useMemo, useCallback, useState, useEffect } from "react"; export interface ChatFile { id: string; @@ -115,6 +115,25 @@ export function ChatView({ const widescreenMode = useWidescreenMode(); const subAgentModel = useSubAgentModel(); const mcpConfigModalOpen = useMCPConfigModalOpen(); + const [mcpPrefill, setMcpPrefill] = useState(null); + + // Check for ?mcp_server_url= query param to auto-open the MCP config modal + useEffect(() => { + const params = new URLSearchParams(window.location.search); + const serverUrl = params.get("mcp_server_url"); + if (serverUrl) { + const serverName = params.get("mcp_server_name") ?? undefined; + setMcpPrefill({ url: serverUrl, name: serverName }); + setMCPConfigModalOpen(true); + // Clean the URL to prevent re-triggering + const cleanUrl = new URL(window.location.href); + cleanUrl.searchParams.delete("mcp_server_url"); + cleanUrl.searchParams.delete("mcp_server_name"); + window.history.replaceState({}, "", cleanUrl.toString()); + } + // eslint-disable-next-line react-hooks/exhaustive-deps -- only run on mount + }, []); + const { setSelectedInstances, updateInstance } = useConversationStore(); const { settingsModalOpen, @@ -268,7 +287,14 @@ export function ChatView({ /> {/* MCP Config Modal */} - setMCPConfigModalOpen(false)} /> + { + setMCPConfigModalOpen(false); + setMcpPrefill(null); + }} + prefill={mcpPrefill} + /> ); } diff --git a/ui/src/components/MCPConfigModal/MCPConfigModal.tsx b/ui/src/components/MCPConfigModal/MCPConfigModal.tsx index c64c8d0..2d8a4e7 100644 --- a/ui/src/components/MCPConfigModal/MCPConfigModal.tsx +++ b/ui/src/components/MCPConfigModal/MCPConfigModal.tsx @@ -14,15 +14,18 @@ import { zodResolver } from "@hookform/resolvers/zod"; import { z } from "zod"; import { AlertCircle, + AlertTriangle, CheckCircle2, ChevronDown, ChevronRight, Eye, EyeOff, + KeyRound, Loader2, Pencil, Plug, Plus, + ShieldCheck, Trash2, Wifi, Wrench, @@ -43,16 +46,35 @@ import { import { Switch } from "@/components/Switch/Switch"; import { cn } from "@/utils/cn"; import { useMCPStore, useMCPServers } from "@/stores/mcpStore"; -import { MCPClient, type MCPServerState, type MCPConnectionStatus } from "@/services/mcp"; +import { + MCPClient, + type MCPServerState, + type MCPConnectionStatus, + type MCPAuthType, + type MCPOAuthConfig, + startOAuthFlow, + getValidAccessToken, + hasValidTokens, + clearOAuthData, + detectServerAuth, +} from "@/services/mcp"; import type { MCPToolDefinition, JSONSchema } from "@/services/mcp"; // ============================================================================= // Types // ============================================================================= +/** Pre-fill data for adding a new server (e.g., from URL query params) */ +export interface MCPServerPrefill { + url: string; + name?: string; +} + export interface MCPConfigModalProps { open: boolean; onClose: () => void; + /** Pre-fill a new server (e.g., from ?mcp_server_url= query param) */ + prefill?: MCPServerPrefill | null; } // ============================================================================= @@ -62,7 +84,10 @@ export interface MCPConfigModalProps { const serverFormSchema = z.object({ name: z.string().min(1, "Name is required"), url: z.string().url("Must be a valid URL"), + authType: z.enum(["none", "bearer", "oauth"]), bearerToken: z.string(), + oauthClientId: z.string(), + oauthScopes: z.string(), headers: z.string(), timeout: z.number().int().min(1, "Must be at least 1 second"), }); @@ -213,6 +238,27 @@ function ServerCard({ server, onEdit, onDelete }: ServerCardProps) { const [expanded, setExpanded] = useState(false); const { connectServer, disconnectServer, setToolEnabled } = useMCPStore(); const [isToggling, setIsToggling] = useState(false); + const [isAuthorizing, setIsAuthorizing] = useState(false); + const [authError, setAuthError] = useState(); + const oauthAuthorized = server.authType === "oauth" && hasValidTokens(server.url); + + const handleAuthorize = useCallback(async () => { + setIsAuthorizing(true); + setAuthError(undefined); + try { + await startOAuthFlow(server.url, server.oauth); + // Tokens obtained — now connect + try { + await connectServer(server.id); + } catch { + // Connection error stored in server state + } + } catch (err) { + setAuthError(err instanceof Error ? err.message : String(err)); + } finally { + setIsAuthorizing(false); + } + }, [server.url, server.oauth, server.id, connectServer]); // Unified toggle: switch ON = enable + connect, switch OFF = disconnect + disable const handleToggle = useCallback(async () => { @@ -246,6 +292,7 @@ function ServerCard({ server, onEdit, onDelete }: ServerCardProps) { const isConnectingStatus = server.status === "connecting" || isToggling; const hasTools = server.tools.length > 0; + const needsAuthorization = server.authType === "oauth" && !oauthAuthorized; return (
@@ -299,7 +346,7 @@ function ServerCard({ server, onEdit, onDelete }: ServerCardProps) { {server.error}
)} + + {/* OAuth status & authorize button */} + {server.authType === "oauth" && ( +
+ {oauthAuthorized ? ( + + + Authorized + + ) : ( + + + Authorization required to connect + + )} + + {authError && ( +
+ + {authError} +
+ )} +
+ )} {/* Tools list (expandable) */} @@ -370,12 +457,17 @@ interface ServerFormProps { editingServer?: MCPServerState | null; onSubmit: (values: ServerFormValues) => void; onCancel: () => void; + /** Pre-fill data (e.g., from URL query params) */ + prefill?: MCPServerPrefill | null; } type TestStatus = "idle" | "testing" | "success" | "error"; -function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { +type OAuthStatus = "idle" | "authorizing" | "authorized" | "error"; + +function ServerForm({ editingServer, onSubmit, onCancel, prefill }: ServerFormProps) { const [showToken, setShowToken] = useState(false); + const isNewServer = !editingServer; // Extract bearer token from existing headers, pass the rest as extra headers const existingHeaders = editingServer?.headers ?? {}; @@ -387,12 +479,19 @@ function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { Object.entries(existingHeaders).filter(([k]) => k.toLowerCase() !== "authorization") ); + // Infer initial auth type from existing config + const initialAuthType: MCPAuthType = + editingServer?.authType ?? (existingBearer ? "bearer" : "none"); + const form = useForm({ resolver: zodResolver(serverFormSchema), defaultValues: { - name: editingServer?.name ?? "", - url: editingServer?.url ?? "", + name: editingServer?.name ?? prefill?.name ?? "", + url: editingServer?.url ?? prefill?.url ?? "", + authType: initialAuthType, bearerToken: existingBearer, + oauthClientId: editingServer?.oauth?.clientId ?? "", + oauthScopes: editingServer?.oauth?.scopes ?? "", headers: Object.keys(extraHeaders).length > 0 ? JSON.stringify(extraHeaders, null, 2) : "", timeout: Math.round((editingServer?.timeout ?? 300000) / 1000), }, @@ -402,9 +501,27 @@ function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { const [testMessage, setTestMessage] = useState(); const [testLatency, setTestLatency] = useState(); - // Reset test results when URL or headers change + // OAuth state + const [oauthStatus, setOauthStatus] = useState(() => + initialAuthType === "oauth" && editingServer?.url && hasValidTokens(editingServer.url) + ? "authorized" + : "idle" + ); + const [oauthError, setOauthError] = useState(); + + // Auth detection state + type DetectionStatus = "idle" | "detecting" | "detected"; + const [detectionStatus, setDetectionStatus] = useState("idle"); + const [detectionMessage, setDetectionMessage] = useState(""); + // Track whether user manually changed auth type (disables auto-select) + const [userOverrodeAuth, setUserOverrodeAuth] = useState(false); + + // Watched form values const watchedUrl = form.watch("url"); const watchedHeaders = form.watch("headers"); + const watchedAuthType = form.watch("authType") as MCPAuthType; + + // Reset test results when URL or headers change useEffect(() => { if (testStatus !== "idle" && testStatus !== "testing") { setTestStatus("idle"); @@ -414,13 +531,92 @@ function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { // eslint-disable-next-line react-hooks/exhaustive-deps -- only reset on field changes }, [watchedUrl, watchedHeaders]); + // Reset OAuth status when auth type or URL changes + useEffect(() => { + if (watchedAuthType === "oauth" && watchedUrl) { + setOauthStatus(hasValidTokens(watchedUrl) ? "authorized" : "idle"); + setOauthError(undefined); + } else { + setOauthStatus("idle"); + setOauthError(undefined); + } + }, [watchedAuthType, watchedUrl]); + + // Auto-detect auth requirements when URL changes (new servers only) + useEffect(() => { + if (!isNewServer || !watchedUrl || userOverrodeAuth) { + setDetectionStatus("idle"); + setDetectionMessage(""); + return; + } + + // Validate URL before probing + if (!z.string().url().safeParse(watchedUrl).success) { + setDetectionStatus("idle"); + setDetectionMessage(""); + return; + } + + setDetectionStatus("detecting"); + setDetectionMessage(""); + + let cancelled = false; + const timer = setTimeout(() => { + detectServerAuth(watchedUrl).then((result) => { + if (cancelled) return; + setDetectionStatus("detected"); + setDetectionMessage(result.message); + if (result.authType !== watchedAuthType) { + form.setValue("authType", result.authType); + } + // Pre-fill server name from resource metadata if the field is still empty + if (result.serverName && !form.getValues("name")) { + form.setValue("name", result.serverName); + } + }); + }, 600); + + return () => { + clearTimeout(timer); + cancelled = true; + }; + // eslint-disable-next-line react-hooks/exhaustive-deps -- only re-run on URL change + }, [watchedUrl, isNewServer, userOverrodeAuth]); + + const handleAuthorize = useCallback(async () => { + const valid = await form.trigger("url"); + if (!valid) return; + + const url = form.getValues("url"); + const clientId = form.getValues("oauthClientId") || undefined; + const scopes = form.getValues("oauthScopes") || undefined; + + setOauthStatus("authorizing"); + setOauthError(undefined); + + try { + await startOAuthFlow(url, { clientId, scopes }); + setOauthStatus("authorized"); + } catch (err) { + setOauthStatus("error"); + setOauthError(err instanceof Error ? err.message : String(err)); + } + }, [form]); + + const handleRevoke = useCallback(() => { + const url = form.getValues("url"); + if (url) clearOAuthData(url); + setOauthStatus("idle"); + }, [form]); + const handleTestConnection = useCallback(async () => { // Validate URL field first const valid = await form.trigger("url"); if (!valid) return; const values = form.getValues(); - let headers: Record | undefined; + + // Parse extra headers const extra: Record = {}; if (values.headers) { try { @@ -431,18 +627,31 @@ function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { return; } } - if (values.bearerToken || Object.keys(extra).length > 0) { - headers = { ...extra }; - if (values.bearerToken) { - headers["Authorization"] = `Bearer ${values.bearerToken}`; - } + + // Build client config based on auth type + const headers: Record = { ...extra }; + let getAccessTokenFn: (() => Promise) | undefined; + + if (values.authType === "bearer" && values.bearerToken) { + headers["Authorization"] = `Bearer ${values.bearerToken}`; + } else if (values.authType === "oauth") { + const oauthCfg: MCPOAuthConfig = { + clientId: values.oauthClientId || undefined, + scopes: values.oauthScopes || undefined, + }; + getAccessTokenFn = () => getValidAccessToken(values.url, oauthCfg); } setTestStatus("testing"); setTestMessage(undefined); setTestLatency(undefined); - const client = new MCPClient({ url: values.url, headers, timeout: 10000 }); + const client = new MCPClient({ + url: values.url, + headers: Object.keys(headers).length > 0 ? headers : undefined, + timeout: 10000, + getAccessToken: getAccessTokenFn, + }); const start = performance.now(); try { @@ -468,16 +677,21 @@ function ServerForm({ editingServer, onSubmit, onCancel }: ServerFormProps) { onSubmit(values); }); + const authTypeOptions = [ + { value: "none" as const, label: "None" }, + { value: "bearer" as const, label: "Bearer Token" }, + { value: "oauth" as const, label: "OAuth (PKCE)" }, + ]; + return (
- - - + {/* Warning banner when pre-filled from a URL param */} + {prefill && ( +
+ + Server URL provided via link. Only add servers you trust. +
+ )} -
- - + + + + {/* Auth detection indicator */} + {detectionStatus === "detecting" && ( +
+ + Checking authentication requirements... +
+ )} + {detectionStatus === "detected" && detectionMessage && ( +
+ + {detectionMessage} +
+ )} + + {/* Auth type selector */} + +
+ {authTypeOptions.map((opt) => ( + + ))}
+ {/* Bearer Token fields */} + {watchedAuthType === "bearer" && ( + +
+ + +
+
+ )} + + {/* OAuth fields */} + {watchedAuthType === "oauth" && ( +
+ {/* Prominent authorize CTA */} + {oauthStatus === "authorized" ? ( +
+ + + Authorized + + +
+ ) : ( +
+
+ +
+
Authorization required
+
+ Click Authorize to sign in and grant access. You won't be able to test or + add the server until this step is completed. +
+
+
+
+ + {oauthStatus === "error" && oauthError && ( +
+ + {oauthError} +
+ )} +
+
+ )} + + {/* Advanced OAuth fields */} +
+ + Advanced options + +
+ + + + + + + +
+
+
+ )} +