diff --git a/example-docker-compose.yaml b/example-docker-compose.yaml index 0f81e06..1a0f5e9 100644 --- a/example-docker-compose.yaml +++ b/example-docker-compose.yaml @@ -306,6 +306,8 @@ services: env_file: - /root/.env environment: + MCP_OAUTH_ISSUER_URL: ${MCP_OAUTH_ISSUER_URL:-} + MCP_RESOURCE_SERVER_URL: ${MCP_RESOURCE_SERVER_URL:-} BRAINAPI_PLUGINS: ${BRAINAPI_PLUGINS:-} PLUGIN_REGISTRY_URL: ${PLUGIN_REGISTRY_URL:-} PLUGIN_PUBLISHER_ID: ${PLUGIN_PUBLISHER_ID:-} diff --git a/src/services/mcp/app.py b/src/services/mcp/app.py index 01f19e2..1f04f82 100644 --- a/src/services/mcp/app.py +++ b/src/services/mcp/app.py @@ -1,5 +1,6 @@ import logging import os +from contextlib import asynccontextmanager from pathlib import Path import dotenv @@ -11,7 +12,7 @@ _project_root = Path(__file__).resolve().parent.parent.parent.parent dotenv.load_dotenv(_project_root / ".env") -from src.services.mcp.main import auth_token_var, mcp +from src.services.mcp.main import auth_token_var, mcp, oauth_provider PLUGINS_DIR = Path(os.getenv("PLUGINS_DIR", str(_project_root / "plugins"))) @@ -92,10 +93,17 @@ async def __call__(self, scope, receive, send): token = brainpat.decode() else: raw = (headers.get(b"authorization") or b"").decode() + bearer = None if raw.startswith("Bearer: "): - token = raw.removeprefix("Bearer: ").strip() or None + bearer = raw.removeprefix("Bearer: ").strip() or None elif raw.startswith("Bearer "): - token = raw.removeprefix("Bearer ").strip() or None + bearer = raw.removeprefix("Bearer ").strip() or None + if bearer: + if oauth_provider: + pat = oauth_provider.get_pat_for_access_token(bearer) + token = pat if pat else bearer + else: + token = bearer auth_token_var.set(token) await self.app(scope, receive, send) @@ -105,10 +113,21 @@ async def _health(_request): async def _mcp_info(_request): - return JSONResponse( - {"service": "brainapi-mcp", "streamable_http": True, "path": "/mcp"}, - status_code=200, - ) + body = { + "service": "brainapi-mcp", + "streamable_http": True, + "path": "/mcp", + } + if oauth_provider: + body["oauth"] = True + body["oauth_consent_path"] = "/mcp-oauth/consent" + return JSONResponse(body, status_code=200) + + +@asynccontextmanager +async def _lifespan(app): + async with _mcp_app.router.lifespan_context(_mcp_app): + yield _custom_routes = [ @@ -119,5 +138,5 @@ async def _mcp_info(_request): app = Starlette( routes=_custom_routes + list(_mcp_app.routes), middleware=[Middleware(AuthContextMiddleware)], - lifespan=_mcp_app.router.lifespan_context, + lifespan=_lifespan, ) diff --git a/src/services/mcp/main.py b/src/services/mcp/main.py index 724d844..e694251 100644 --- a/src/services/mcp/main.py +++ b/src/services/mcp/main.py @@ -10,10 +10,16 @@ """ import asyncio +import html +import os from contextvars import ContextVar from typing import Any from mcp.server import FastMCP +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from pydantic import AnyHttpUrl, AnyUrl +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse from src.core.instances import ( data_adapter, @@ -22,13 +28,135 @@ vector_store_adapter, ) from src.lib.neo4j.client import _neo4j_client +from src.services.mcp.oauth_provider import BrainapiMcpOAuthProvider from src.services.mcp.utils import guard_brainpat from src.utils.vector_search import VectorSearchFacade auth_token_var: ContextVar[str | None] = ContextVar("auth_token", default=None) vector_search = VectorSearchFacade(vector_store_adapter) -mcp = FastMCP("brainapi-mcp", stateless_http=True, host="0.0.0.0") +_oauth_issuer = os.getenv("MCP_OAUTH_ISSUER_URL", "").strip() +_oauth_resource = os.getenv("MCP_RESOURCE_SERVER_URL", "").strip() +if _oauth_issuer and not _oauth_resource: + _oauth_resource = _oauth_issuer.rstrip("/") + "/mcp" + +_oauth_scopes = [ + s for s in os.getenv("MCP_OAUTH_SCOPES", "brainapi").strip().split() if s +] +if not _oauth_scopes: + _oauth_scopes = ["brainapi"] + +_access_ttl = int(os.getenv("MCP_OAUTH_ACCESS_TOKEN_TTL", "3600")) +_refresh_ttl_raw = os.getenv("MCP_OAUTH_REFRESH_TOKEN_TTL", "").strip() +_refresh_ttl = int(_refresh_ttl_raw) if _refresh_ttl_raw else None +_code_ttl = int(os.getenv("MCP_OAUTH_AUTH_CODE_TTL", "600")) + +oauth_provider: BrainapiMcpOAuthProvider | None = None +if _oauth_issuer: + oauth_provider = BrainapiMcpOAuthProvider( + issuer_url=_oauth_issuer, + resource_server_url=_oauth_resource, + valid_scopes=_oauth_scopes, + access_token_ttl_seconds=_access_ttl, + refresh_token_ttl_seconds=_refresh_ttl, + auth_code_ttl_seconds=_code_ttl, + ) + _doc_url = os.getenv("MCP_OAUTH_SERVICE_DOCUMENTATION_URL", "").strip() + mcp = FastMCP( + "brainapi-mcp", + stateless_http=True, + host="0.0.0.0", + auth_server_provider=oauth_provider, + auth=AuthSettings( + issuer_url=AnyHttpUrl(_oauth_issuer), + resource_server_url=AnyHttpUrl(_oauth_resource), + service_documentation_url=AnyHttpUrl(_doc_url) if _doc_url else None, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=_oauth_scopes, + default_scopes=_oauth_scopes, + ), + ), + ) +else: + mcp = FastMCP("brainapi-mcp", stateless_http=True, host="0.0.0.0") + + +if oauth_provider: + + async def _mcp_oauth_consent(request: Request) -> HTMLResponse | RedirectResponse: + if request.method == "GET": + q = request.query_params + client_id = q.get("client_id") or "" + redirect_uri = q.get("redirect_uri") or "" + code_challenge = q.get("code_challenge") or "" + scope = q.get("scope") or " ".join(_oauth_scopes) + resource = q.get("resource") or "" + state = q.get("state") or "" + if not (client_id and redirect_uri and code_challenge): + return HTMLResponse("Missing OAuth parameters", status_code=400) + client = await oauth_provider.get_client(client_id) + if not client: + return HTMLResponse("Unknown client_id", status_code=400) + esc = html.escape + form = f""" +
Enter your BrainPAT (personal access token for your brain). This authorizes the MCP client to act with that token.
+ +""" + return HTMLResponse(form) + + form = await request.form() + client_id = str(form.get("client_id") or "") + redirect_uri = str(form.get("redirect_uri") or "") + code_challenge = str(form.get("code_challenge") or "") + scope_str = str(form.get("scope") or "") + resource = str(form.get("resource") or "") or None + state = str(form.get("state") or "") or None + brainpat = str(form.get("brainpat") or "").strip() + if not (client_id and redirect_uri and code_challenge and brainpat): + return HTMLResponse("Missing fields", status_code=400) + client = await oauth_provider.get_client(client_id) + if not client: + return HTMLResponse("Unknown client", status_code=400) + if guard_brainpat(brainpat) is False: + return HTMLResponse("Invalid BrainPAT", status_code=401) + scopes = scope_str.split() if scope_str.strip() else list(_oauth_scopes) + for s in scopes: + if s not in _oauth_scopes: + return HTMLResponse("Invalid scope", status_code=400) + try: + ru = AnyUrl(redirect_uri) + ru = client.validate_redirect_uri(ru) + except Exception: + return HTMLResponse("Invalid redirect_uri", status_code=400) + code = oauth_provider.issue_auth_code( + client_id=client_id, + redirect_uri=ru, + code_challenge=code_challenge, + scopes=scopes, + resource=resource, + state=state, + brainpat=brainpat, + ) + url = oauth_provider.redirect_after_consent( + redirect_uri=redirect_uri, code=code, state=state + ) + return RedirectResponse(url, status_code=302) + + mcp.custom_route("/mcp-oauth/consent", methods=["GET", "POST"])(_mcp_oauth_consent) @mcp.tool() diff --git a/src/services/mcp/oauth_provider.py b/src/services/mcp/oauth_provider.py new file mode 100644 index 0000000..ccfd221 --- /dev/null +++ b/src/services/mcp/oauth_provider.py @@ -0,0 +1,236 @@ +import secrets +import time +from urllib.parse import urlencode + +from pydantic import AnyUrl + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + AuthorizeError, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + + +class BrainapiMcpOAuthProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + def __init__( + self, + *, + issuer_url: str, + resource_server_url: str, + valid_scopes: list[str], + access_token_ttl_seconds: int, + refresh_token_ttl_seconds: int | None, + auth_code_ttl_seconds: int, + ): + self._issuer_url = issuer_url.rstrip("/") + self._resource_server_url = resource_server_url.rstrip("/") + self._valid_scopes = valid_scopes + self._access_token_ttl = access_token_ttl_seconds + self._refresh_token_ttl = refresh_token_ttl_seconds + self._auth_code_ttl = auth_code_ttl_seconds + self._clients: dict[str, OAuthClientInformationFull] = {} + self._auth_codes: dict[str, AuthorizationCode] = {} + self._code_to_pat: dict[str, str] = {} + self._refresh_tokens: dict[str, RefreshToken] = {} + self._refresh_to_pat: dict[str, str] = {} + self._access_tokens: dict[str, AccessToken] = {} + self._pat_by_access: dict[str, str] = {} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self._clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + cid = client_info.client_id + if not cid: + raise ValueError("client_id required") + self._clients[cid] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + if params.resource and params.resource.rstrip("/") != self._resource_server_url: + raise AuthorizeError( + error="invalid_request", + error_description="resource must match this MCP server URL", + ) + scopes = params.scopes + if scopes is None: + scopes = list(self._valid_scopes) + for s in scopes: + if s not in self._valid_scopes: + raise AuthorizeError(error="invalid_scope", error_description=f"unknown scope: {s}") + + consent_base = f"{self._issuer_url}/mcp-oauth/consent" + q = { + "client_id": client.client_id or "", + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "scope": " ".join(scopes), + } + if params.resource: + q["resource"] = params.resource + if params.state: + q["state"] = params.state + return f"{consent_base}?{urlencode(q)}" + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self._auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + self._auth_codes.pop(authorization_code.code, None) + brainpat = self._code_to_pat.pop(authorization_code.code, None) + if not brainpat: + raise TokenError(error="invalid_grant", error_description="missing credentials") + + access = secrets.token_urlsafe(32) + refresh = secrets.token_urlsafe(48) + now = int(time.time()) + exp_access = now + self._access_token_ttl + exp_refresh = now + self._refresh_token_ttl if self._refresh_token_ttl else None + + self._pat_by_access[access] = brainpat + at = AccessToken( + token=access, + client_id=client.client_id or "", + scopes=authorization_code.scopes, + expires_at=exp_access, + resource=self._resource_server_url, + ) + self._access_tokens[access] = at + rt = RefreshToken( + token=refresh, + client_id=client.client_id or "", + scopes=authorization_code.scopes, + expires_at=exp_refresh, + ) + self._refresh_tokens[refresh] = rt + self._refresh_to_pat[refresh] = brainpat + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._access_token_ttl, + refresh_token=refresh, + scope=" ".join(authorization_code.scopes), + ) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + rt = self._refresh_tokens.get(refresh_token) + if rt is None or rt.client_id != (client.client_id or ""): + return None + return rt + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + brainpat = self._refresh_to_pat.pop(refresh_token.token, None) + self._refresh_tokens.pop(refresh_token.token, None) + if not brainpat: + raise TokenError(error="invalid_grant", error_description="refresh token not found") + + for t, v in list(self._access_tokens.items()): + if v.client_id == (client.client_id or ""): + self._access_tokens.pop(t, None) + self._pat_by_access.pop(t, None) + + access = secrets.token_urlsafe(32) + new_refresh = secrets.token_urlsafe(48) + now = int(time.time()) + exp_access = now + self._access_token_ttl + exp_refresh = now + self._refresh_token_ttl if self._refresh_token_ttl else None + + self._pat_by_access[access] = brainpat + at = AccessToken( + token=access, + client_id=client.client_id or "", + scopes=scopes, + expires_at=exp_access, + resource=self._resource_server_url, + ) + self._access_tokens[access] = at + rt = RefreshToken( + token=new_refresh, + client_id=client.client_id or "", + scopes=scopes, + expires_at=exp_refresh, + ) + self._refresh_tokens[new_refresh] = rt + self._refresh_to_pat[new_refresh] = brainpat + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._access_token_ttl, + refresh_token=new_refresh, + scope=" ".join(scopes), + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + at = self._access_tokens.get(token) + if at is None: + return None + if at.expires_at and at.expires_at < int(time.time()): + self._access_tokens.pop(token, None) + self._pat_by_access.pop(token, None) + return None + return at + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + if isinstance(token, AccessToken): + self._access_tokens.pop(token.token, None) + self._pat_by_access.pop(token.token, None) + else: + self._refresh_tokens.pop(token.token, None) + self._refresh_to_pat.pop(token.token, None) + + def get_pat_for_access_token(self, token: str) -> str | None: + return self._pat_by_access.get(token) + + def issue_auth_code( + self, + *, + client_id: str, + redirect_uri: AnyUrl, + code_challenge: str, + scopes: list[str], + resource: str | None, + state: str | None, + brainpat: str, + ) -> str: + code = secrets.token_urlsafe(48) + ac = AuthorizationCode( + code=code, + scopes=scopes, + expires_at=time.time() + self._auth_code_ttl, + client_id=client_id, + code_challenge=code_challenge, + redirect_uri=redirect_uri, + redirect_uri_provided_explicitly=True, + resource=resource, + ) + self._auth_codes[code] = ac + self._code_to_pat[code] = brainpat + return code + + def redirect_after_consent( + self, + *, + redirect_uri: str, + code: str, + state: str | None, + ) -> str: + params: dict[str, str] = {"code": code} + if state: + params["state"] = state + return construct_redirect_uri(redirect_uri, **params)