Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 14 additions & 88 deletions src/mcp_github/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,32 +59,23 @@ async def verify_token(self, token: str) -> AccessToken | None:
return AccessToken(
token=token,
client_id="github_token",
expires_at=None, # API keys don't expire
expires_at=None,
scopes=["api:read", "api:write"],
claims={"authenticated": True},
)
return None


class _PermissiveGitHubProvider(GitHubProvider):
"""GitHubProvider that accepts the upstream client_id without prior DCR.

Claude.ai and similar MCP clients skip Dynamic Client Registration and
send the GitHub OAuth App's client_id directly in /authorize requests.
On first use, this subclass auto-registers that client_id in the proxy's
client store so the full OAuth flow can proceed normally.
"""
"""GitHubProvider that auto-registers the upstream client_id when MCP clients skip DCR."""

async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
client = await super().get_client(client_id)
if client is not None:
return client

if client_id == self._upstream_client_id:
logging.info(
"Auto-registering upstream client_id %s (MCP client skipped DCR)",
client_id,
)
logging.info("Auto-registering upstream client_id %s (MCP client skipped DCR)", client_id)
await self.register_client(
OAuthClientInformationFull(
client_id=client_id,
Expand All @@ -99,48 +90,25 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
return None


def _parse_redis_db(path: str) -> int:
"""Parse the database index from a Redis URI path component."""
db_path = path.lstrip("/")
if db_path and not db_path.isdigit():
raise ValueError(f"Invalid Redis database in URI: {db_path!r} (must be a non-negative integer)")
return int(db_path) if db_path else 0


def _build_redis_client(host_port: str) -> AsyncRedis:
"""Build an AsyncRedis client from a host:port string or Redis URI."""
uri = host_port if "://" in host_port else f"redis://{host_port}"
parsed = urlparse(uri)
db_path = parsed.path.lstrip("/")
if db_path and not db_path.isdigit():
raise ValueError(f"Invalid Redis database in URI: {db_path!r} (must be a non-negative integer)")
return AsyncRedis(
host=parsed.hostname or "localhost",
port=parsed.port or 6379,
db=_parse_redis_db(parsed.path),
db=int(db_path) if db_path else 0,
password=parsed.password or REDIS_PASSWORD or None,
ssl=parsed.scheme == "rediss",
decode_responses=True,
)


def build_token_store() -> AsyncKeyValue:
"""
Return a token store for OAuth state.

When REDIS_HOST_PORT is set, returns a RedisStore whose collection names are
prefixed with a 12-char SHA-256 hash of GITHUB_OAUTH_BASE_URL. Two server
instances sharing the same Redis instance will have fully isolated keyspaces
provided their base URLs differ.

When REDIS_HOST_PORT is unset, returns an in-process MemoryStore. No tokens
are written to disk in either case. Sessions are lost on server restart in
MemoryStore mode.

REDIS_HOST_PORT accepts either a bare host:port or a full URI:
redis://[:<password>@]<host>:<port>[/<db>] — plaintext
rediss://[:<password>@]<host>:<port>[/<db>] — TLS
REDIS_PASSWORD is used as a fallback when not embedded in the URI.
The database defaults to 0 when not specified in the URI.

"""
"""Return a token store for OAuth state. MemoryStore by default; RedisStore when REDIS_HOST_PORT is set."""
if REDIS_HOST_PORT:
store: AsyncKeyValue = RedisStore(client=_build_redis_client(REDIS_HOST_PORT))
if GITHUB_OAUTH_BASE_URL:
Expand All @@ -151,48 +119,19 @@ def build_token_store() -> AsyncKeyValue:


def _derive_jwt_signing_key() -> bytes:
"""
Return a stable JWT signing key.

Priority:
1. ``JWT_SIGNING_KEY`` env var (explicit override).
2. Deterministic derivation from ``GITHUB_OAUTH_CLIENT_SECRET``
(automatic — all pods with the same secret share the same key).

When the automatic path is used, rotating the GitHub OAuth App
secret invalidates all stored sessions and forces clients to
re-authenticate.

"""
"""Return a stable JWT signing key from JWT_SIGNING_KEY or GITHUB_OAUTH_CLIENT_SECRET."""
if JWT_SIGNING_KEY:
return derive_jwt_key(
low_entropy_material=JWT_SIGNING_KEY,
salt="fastmcp-jwt-signing-key",
)
return derive_jwt_key(
high_entropy_material=GITHUB_OAUTH_CLIENT_SECRET, # type: ignore[arg-type]
salt="fastmcp-jwt-signing-key",
)
return derive_jwt_key(low_entropy_material=JWT_SIGNING_KEY, salt="fastmcp-jwt-signing-key")
return derive_jwt_key(high_entropy_material=GITHUB_OAUTH_CLIENT_SECRET, salt="fastmcp-jwt-signing-key") # type: ignore[arg-type]


def get_oauth_verifier() -> _PermissiveGitHubProvider:
"""Return a PermissiveGitHubProvider instance for OAuth2 authentication.

Requires GITHUB_OAUTH_CLIENT_ID, GITHUB_OAUTH_CLIENT_SECRET, and
GITHUB_OAUTH_BASE_URL to be set.

JWT_SIGNING_KEY is optional. When omitted, a stable key is derived
automatically from GITHUB_OAUTH_CLIENT_SECRET so all pods generate
the same signing key without requiring an additional env var.

"""
"""Return a PermissiveGitHubProvider instance for OAuth2 authentication."""
if not all((GITHUB_OAUTH_CLIENT_ID, GITHUB_OAUTH_CLIENT_SECRET, GITHUB_OAUTH_BASE_URL)):
raise ValueError(
"GITHUB_OAUTH_CLIENT_ID, GITHUB_OAUTH_CLIENT_SECRET, and "
"GITHUB_OAUTH_BASE_URL must all be set to use OAuth2 auth"
"GITHUB_OAUTH_CLIENT_ID, GITHUB_OAUTH_CLIENT_SECRET, and GITHUB_OAUTH_BASE_URL must all be set"
)

# Validate types after check (pyright doesn't narrow through all() check)
return _PermissiveGitHubProvider(
client_id=GITHUB_OAUTH_CLIENT_ID, # type: ignore[arg-type]
client_secret=GITHUB_OAUTH_CLIENT_SECRET, # type: ignore[arg-type]
Expand All @@ -204,20 +143,7 @@ def get_oauth_verifier() -> _PermissiveGitHubProvider:


def resolve_token(github_token: str | None, oauth_mode: bool) -> str:
"""
Return the token to use for the current request.

In OAuth2 mode, reads the authenticated user's token from FastMCP's
per-request context. Falls back to the static github_token in all other
cases (stdio mode or API-key mode).

Raises
------
RuntimeError
In OAuth2 mode when no access token is available in
the request context and no GITHUB_TOKEN fallback is configured.

"""
"""Return the token for the current request."""
if oauth_mode:
access_token = get_access_token()
if access_token is not None:
Expand Down
57 changes: 6 additions & 51 deletions src/mcp_github/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,14 @@
# * limitations under the License.
# */

"""
Custom exceptions for MCP GitHub integration.

This module defines a hierarchy of exceptions for handling errors from
GitHub API and IP info services in a structured way.
"""
"""Custom exceptions for MCP GitHub integration."""

from __future__ import annotations


class MCPGitHubError(Exception):
"""Base exception for MCP GitHub integration."""

def __init__(self, message: str, code: str = "ERROR"):
"""Initialize MCPGitHubError."""
super().__init__(message)
self.message = message
self.code = code

def __str__(self) -> str:
return f"[{self.code}] {self.message}"


class GitHubAPIError(MCPGitHubError):
"""GitHub API returned an error."""
Expand All @@ -49,67 +35,36 @@ def __init__(
response_body: dict | None = None,
code: str = "GITHUB_API_ERROR",
):
super().__init__(message, code)
prefix = f"[{code}] HTTP {status_code}: " if status_code else f"[{code}] "
super().__init__(f"{prefix}{message}")
self.status_code = status_code
self.response_body = response_body

def __str__(self) -> str:
if self.status_code:
return f"[{self.code}] HTTP {self.status_code}: {self.message}"
return super().__str__()
self.code = code


class GitHubAuthError(GitHubAPIError):
# fmt: off

"""
Authentication failed (401).

Inherits from GitHubAPIError because a 401 response is still an HTTP API
response -- it follows the same status_code + response_body pattern.
"""

# fmt: on

def __init__(
self,
message: str = "Authentication failed. Check your GitHub token.",
response_body: dict | None = None,
self, message: str = "Authentication failed. Check your GitHub token.", response_body: dict | None = None
):
"""Initialize GitHubAuthError."""
super().__init__(message, status_code=401, response_body=response_body, code="AUTH_FAILED")


class GitHubRateLimitError(GitHubAPIError):
"""Rate limit exceeded (403)."""

def __init__(
self,
message: str = "GitHub API rate limit exceeded.",
response_body: dict | None = None,
reset_timestamp: int | None = None,
):
"""Initialize GitHubRateLimitError."""
super().__init__(message, status_code=403, response_body=response_body, code="RATE_LIMITED")
self.reset_timestamp = reset_timestamp


class GitHubNotFoundError(GitHubAPIError):
"""Resource not found (404)."""

def __init__(self, message: str, response_body: dict | None = None):
"""Initialize GitHubNotFoundError."""
super().__init__(message, status_code=404, response_body=response_body, code="NOT_FOUND")


class GitHubValidationError(GitHubAPIError):
"""Validation failed (422)."""

def __init__(self, message: str = "Validation failed.", response_body: dict | None = None):
"""Initialize GitHubValidationError."""
super().__init__(
message,
status_code=422,
response_body=response_body,
code="VALIDATION_ERROR",
)
super().__init__(message, status_code=422, response_body=response_body, code="VALIDATION_ERROR")
Loading