diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index b296cdfa3f..74024fa3c0 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -1410,6 +1410,50 @@ def validate_password_strength(password: str) -> tuple[bool, str]: ADMIN_CSRF_FORM_FIELD = "csrf_token" +def _resolve_root_path(request: Request) -> str: + """Resolve the application root path from the request scope with fallback. + + Some embedded/proxy deployments do not populate ``scope["root_path"]`` + consistently. This helper checks the ASGI scope first and falls back + to ``settings.app_root_path`` when the scope value is empty. + + Args: + request: Incoming request used to read ASGI ``root_path``. + + Returns: + Normalized root path (leading ``/``, no trailing ``/``), or empty + string when no root path is configured. + """ + root_path = request.scope.get("root_path", "") or "" + if not root_path or not str(root_path).strip(): + root_path = settings.app_root_path or "" + root_path = str(root_path).strip() + if root_path: + root_path = "/" + root_path.lstrip("/") + return root_path.rstrip("/") + + +def _is_safe_local_path(path: str) -> bool: + """Validate that a path is a safe local redirect target (no open redirect). + + Args: + path: The path to validate. + + Returns: + True if the path is a safe relative path starting with ``/``. + """ + if not path or not isinstance(path, str): + return False + if not path.startswith("/"): + return False + # Block protocol-relative URLs (//evil.com), authority injection (@), backslash tricks + if path.startswith("//") or "@" in path or "\\" in path: + return False + parsed = urllib.parse.urlparse(path) + if parsed.scheme or parsed.netloc: + return False + return True + def _admin_cookie_path(request: Request) -> str: """Build admin cookie path honoring ASGI root_path. 
@@ -2928,6 +2972,8 @@ async def admin_add_server(request: Request, db: Session = Depends(get_db), user visibility=visibility, oauth_enabled=oauth_enabled, oauth_config=oauth_config, + server_type="meta" if form.get("meta_server_enabled") else str(form.get("server_type", "standard")), + hide_underlying_tools=form.get("hide_underlying_tools") == "true" or form.get("hide_underlying_tools") == "on", ) except KeyError as e: # Convert KeyError to ValidationError-like response @@ -3091,6 +3137,8 @@ async def admin_edit_server( owner_email=user_email, oauth_enabled=oauth_enabled, oauth_config=oauth_config, + server_type="meta" if form.get("meta_server_enabled") else str(form.get("server_type", "standard")), + hide_underlying_tools=form.get("hide_underlying_tools") == "true" or form.get("hide_underlying_tools") == "on", ) await server_service.update_server( @@ -4161,6 +4209,21 @@ async def admin_login_page(request: Request) -> Response: response.delete_cookie("jwt_token", path="/") response.delete_cookie("access_token", path="/") + # Preserve ?next= parameter as a short-lived cookie so SSO callback can redirect + # back to the original URL (e.g. /oauth/authorize/{gateway_id}) after login. 
+ next_url = request.query_params.get("next", "") + if next_url and _is_safe_local_path(next_url): + use_secure = (settings.environment == "production") or settings.secure_cookies + response.set_cookie( + key="post_login_next", + value=next_url, + max_age=300, # 5 minutes — enough for SSO round-trip + httponly=True, + secure=use_secure, + samesite=settings.cookie_samesite, + path=settings.app_root_path or "/", + ) + return response diff --git a/mcpgateway/admin_ui/formSubmitHandlers.js b/mcpgateway/admin_ui/formSubmitHandlers.js index 31ee784ec9..30098870ba 100644 --- a/mcpgateway/admin_ui/formSubmitHandlers.js +++ b/mcpgateway/admin_ui/formSubmitHandlers.js @@ -438,6 +438,17 @@ export const handleServerFormSubmit = async function (e) { } } + // Handle Meta-Server configuration + const metaEnabledCheckbox = safeGetElement("server-meta-enabled"); + if (metaEnabledCheckbox && metaEnabledCheckbox.checked) { + formData.set("server_type", "meta"); + const hideToolsCheckbox = safeGetElement("server-hide-underlying-tools"); + formData.set("hide_underlying_tools", hideToolsCheckbox && hideToolsCheckbox.checked ? "true" : "false"); + } else { + formData.set("server_type", "standard"); + formData.delete("hide_underlying_tools"); + } + const response = await fetch(`${window.ROOT_PATH}/admin/servers`, { method: "POST", body: formData, @@ -1008,6 +1019,17 @@ export const handleEditServerFormSubmit = async function (e) { } }); + // Handle Meta-Server configuration + const metaEnabledCheckbox = safeGetElement("edit-server-meta-enabled"); + if (metaEnabledCheckbox && metaEnabledCheckbox.checked) { + formData.set("server_type", "meta"); + const hideToolsCheckbox = safeGetElement("edit-server-hide-underlying-tools"); + formData.set("hide_underlying_tools", hideToolsCheckbox && hideToolsCheckbox.checked ? 
"true" : "false"); + } else { + formData.set("server_type", "standard"); + formData.delete("hide_underlying_tools"); + } + // Submit via fetch const response = await fetch(form.action, { method: "POST", diff --git a/mcpgateway/admin_ui/servers.js b/mcpgateway/admin_ui/servers.js index da68d440e8..d99354d50c 100644 --- a/mcpgateway/admin_ui/servers.js +++ b/mcpgateway/admin_ui/servers.js @@ -864,6 +864,47 @@ export const editServer = async function (serverId) { if (oauthTokenEndpointField) oauthTokenEndpointField.value = ""; } + // Set Meta-Server configuration fields + const metaEnabledCheckbox = safeGetElement("edit-server-meta-enabled"); + const metaConfigSection = safeGetElement("edit-server-meta-config-section"); + const hideUnderlyingToolsCheckbox = safeGetElement("edit-server-hide-underlying-tools"); + const isMeta = server.serverType === "meta" || server.server_type === "meta"; + + if (metaEnabledCheckbox) { + metaEnabledCheckbox.checked = isMeta; + } + if (metaConfigSection) { + if (isMeta) { + metaConfigSection.classList.remove("hidden"); + } else { + metaConfigSection.classList.add("hidden"); + } + } + if (hideUnderlyingToolsCheckbox) { + const hideTools = server.hideUnderlyingTools !== undefined + ? server.hideUnderlyingTools + : (server.hide_underlying_tools !== undefined ? server.hide_underlying_tools : true); + hideUnderlyingToolsCheckbox.checked = isMeta ? 
hideTools : true; + } + + // Toggle gateways+tools wrapper and info banner based on meta-server mode + const editGatewaysAndTools = safeGetElement("edit-server-gateways-and-tools"); + const editMetaInfoBanner = safeGetElement("edit-meta-info-banner"); + if (editGatewaysAndTools) { + if (isMeta) { + editGatewaysAndTools.classList.add("hidden"); + } else { + editGatewaysAndTools.classList.remove("hidden"); + } + } + if (editMetaInfoBanner) { + if (isMeta) { + editMetaInfoBanner.classList.remove("hidden"); + } else { + editMetaInfoBanner.classList.add("hidden"); + } + } + // Store server data for modal population window.Admin.currentEditingServer = server; diff --git a/mcpgateway/alembic/versions/5126ced48fd0_add_meta_server_fields.py b/mcpgateway/alembic/versions/5126ced48fd0_add_meta_server_fields.py new file mode 100644 index 0000000000..7c36414d34 --- /dev/null +++ b/mcpgateway/alembic/versions/5126ced48fd0_add_meta_server_fields.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +"""Add meta-server fields to servers table + +Revision ID: 5126ced48fd0 +Revises: 64acf94cb7f2 +Create Date: 2026-02-12 10:00:00.000000 + +""" + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "5126ced48fd0" +down_revision = "64acf94cb7f2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add server_type, hide_underlying_tools, meta_config, and meta_scope columns to servers.""" + inspector = sa.inspect(op.get_bind()) + + # Skip if table doesn't exist (fresh DB uses db.py models directly) + if "servers" not in inspector.get_table_names(): + return + + columns = [col["name"] for col in inspector.get_columns("servers")] + + if "server_type" not in columns: + op.add_column("servers", sa.Column("server_type", sa.String(20), nullable=False, server_default="standard")) + + if "hide_underlying_tools" not in columns: + op.add_column("servers", sa.Column("hide_underlying_tools", sa.Boolean(), nullable=False, server_default=sa.text("true"))) + + if "meta_config" not in columns: + op.add_column("servers", sa.Column("meta_config", sa.JSON(), nullable=True)) + + if "meta_scope" not in columns: + op.add_column("servers", sa.Column("meta_scope", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + """Remove meta-server fields from servers table.""" + inspector = sa.inspect(op.get_bind()) + + if "servers" not in inspector.get_table_names(): + return + + columns = [col["name"] for col in inspector.get_columns("servers")] + + if "meta_scope" in columns: + op.drop_column("servers", "meta_scope") + if "meta_config" in columns: + op.drop_column("servers", "meta_config") + if "hide_underlying_tools" in columns: + op.drop_column("servers", "hide_underlying_tools") + if "server_type" in columns: + op.drop_column("servers", "server_type") diff --git a/mcpgateway/common/validators.py b/mcpgateway/common/validators.py index a4eadcd388..3e954d7482 100644 --- a/mcpgateway/common/validators.py +++ b/mcpgateway/common/validators.py @@ -51,7 +51,6 @@ from functools import lru_cache from html.parser import HTMLParser import ipaddress -import json import logging from pathlib import Path import re @@ -78,7 +77,9 @@ _HTML_SPECIAL_CHARS_RE: 
Pattern[str] = re.compile(r'[<>"\']') # / removed per SEP-986 _DANGEROUS_TEMPLATE_TAGS_RE: Pattern[str] = re.compile(r"<(script|iframe|object|embed|link|meta|base|form)\b", re.IGNORECASE) _EVENT_HANDLER_RE: Pattern[str] = re.compile(r"on\w+\s*=", re.IGNORECASE) -_MIME_TYPE_RE: Pattern[str] = re.compile(r'^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*(?:\s*;\s*[a-zA-Z0-9!#$&\-\^_+\.]+=(?:[a-zA-Z0-9!#$&\-\^_+\.]+|"[^"\r\n]*"))*$') +_MIME_TYPE_RE: Pattern[str] = re.compile( + r'^[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*\/[a-zA-Z0-9][a-zA-Z0-9!#$&\-\^_+\.]*(?:\s*;\s*[a-zA-Z0-9!#$&\-\^_+\.]+=(?:[a-zA-Z0-9!#$&\-\^_+\.]+|"[^"\r\n]*"))*$' +) _URI_SCHEME_RE: Pattern[str] = re.compile(r"^[a-zA-Z][a-zA-Z0-9+\-.]*://") _SHELL_DANGEROUS_CHARS_RE: Pattern[str] = re.compile(r"[;&|`$(){}\[\]<>]") _ANSI_ESCAPE_RE: Pattern[str] = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") @@ -1949,60 +1950,3 @@ def validate_core_url(value: str, field_name: str = "URL") -> str: The validated URL string. """ return SecurityValidator.validate_url(value, field_name) - - -# CWE-400: Limits for user-supplied meta_data forwarded to upstream MCP servers. -# Keeps arbitrarily large dicts from amplifying into downstream network/DB load. -# These are now read from config (settings.meta_max_keys, etc.) but kept as -# module-level aliases for backward-compatible imports. -META_MAX_KEYS: int = settings.meta_max_keys -META_MAX_DEPTH: int = settings.meta_max_depth -META_MAX_BYTES: int = settings.meta_max_bytes - - -def validate_meta_data(meta_data: Optional[Dict[str, Any]]) -> None: - """Enforce size, key-count, and depth limits on user-supplied meta_data (CWE-400). - - Args: - meta_data: The metadata dictionary to validate. ``None`` is always accepted. - - Raises: - ValueError: if any limit is exceeded. 
- """ - max_keys = settings.meta_max_keys - max_depth = settings.meta_max_depth - max_bytes = settings.meta_max_bytes - - if not meta_data: - return - if len(meta_data) > max_keys: - raise ValueError(f"meta_data exceeds maximum key count ({max_keys}): got {len(meta_data)}") - - def _check_depth(obj: Any, depth: int) -> None: - """Recursively enforce nesting depth, traversing both dicts and lists (CWE-400). - - Lists are traversed without incrementing the depth counter so that a - list-of-dicts does not hide an extra level of dict nesting — e.g. - ``{"k": [{"l2": {"l3": "x"}}]}`` is correctly caught as depth 3. - """ - if depth > max_depth: - raise ValueError(f"meta_data exceeds maximum nesting depth ({max_depth})") - if isinstance(obj, dict): - for v in obj.values(): - _check_depth(v, depth + 1) - elif isinstance(obj, list): - for item in obj: - _check_depth(item, depth) - - for v in meta_data.values(): - _check_depth(v, 1) - - try: - # CWE-20: Use strict json.dumps (no default=str) so non-serializable objects - # raise TypeError rather than being silently coerced — keeps the byte limit - # meaningful and matches the strict rejection behaviour used in prompt_service. 
- size = len(json.dumps(meta_data)) - if size > max_bytes: - raise ValueError(f"meta_data exceeds maximum size ({max_bytes} bytes): got {size}") - except (TypeError, ValueError) as exc: - raise ValueError(f"meta_data is not serializable: {exc}") from exc diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 47f0d76803..aba1e3e9c2 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -410,9 +410,6 @@ class Settings(BaseSettings): allowed_roots: List[str] = Field(default_factory=list, description="Allowed root paths for resource access") max_path_depth: int = Field(default=10, description="Maximum allowed path depth") max_param_length: int = Field(default=10000, description="Maximum parameter length") - meta_max_keys: int = Field(default=16, description="Maximum number of keys in user-supplied meta_data forwarded to upstream MCP servers (CWE-400)") - meta_max_depth: int = Field(default=2, description="Maximum nesting depth for user-supplied meta_data forwarded to upstream MCP servers (CWE-400)") - meta_max_bytes: int = Field(default=4096, description="Maximum JSON-encoded byte size for user-supplied meta_data forwarded to upstream MCP servers (CWE-400)") dangerous_patterns: List[str] = Field( default_factory=lambda: [ r"[;&|`$(){}\[\]<>]", # Shell metacharacters @@ -1725,6 +1722,7 @@ def parse_issuers(cls, v: Any) -> list[str]: "Longer responses are truncated to prevent exposing excessive sensitive data. " "Default: 5000 characters. 
Range: 1000-100000.", ) + semantic_search_rate_limit: int = 30 # requests per minute for semantic search # Content Security - Size Limits content_max_resource_size: int = Field(default=102400, ge=1024, le=10485760, description="Maximum size in bytes for resource content (default: 100KB)") # 100KB # Minimum 1KB # Maximum 10MB diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 7a1f3471f5..9276c33da7 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -44,6 +44,7 @@ from sqlalchemy.types import TypeDecorator # First-Party +from mcpgateway.utils.pgvector import HAS_PGVECTOR, Vector from mcpgateway.common.validators import SecurityValidator from mcpgateway.config import settings from mcpgateway.utils.create_slug import slugify @@ -3256,6 +3257,9 @@ class Tool(Base): viewonly=True, ) + # Embeddings relationship + embeddings: Mapped[List["ToolEmbedding"]] = relationship("ToolEmbedding", back_populates="tool", cascade="all, delete-orphan") + # Team scoping fields for resource organization team_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("email_teams.id", ondelete="SET NULL"), nullable=True) owner_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) @@ -3552,6 +3556,124 @@ def metrics_summary(self) -> Dict[str, Any]: ) +class ToolEmbedding(Base): + """ORM model for tool embedding vectors used in semantic search.""" + + __tablename__ = "tool_embeddings" + + id: Mapped[str] = mapped_column( + String(36), + primary_key=True, + default=lambda: str(uuid.uuid4()), + nullable=False, + ) + tool_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("tools.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + if HAS_PGVECTOR: + embedding: Mapped[list[float]] = mapped_column( + Vector(getattr(settings, "embedding_dim", 1536)), + nullable=False, + ) + else: + embedding: Mapped[list[float]] = mapped_column( + JSON, + nullable=False, + ) + + model_name: Mapped[str] = mapped_column( + String(255), + nullable=False, + 
default="text-embedding-3-small", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + tool: Mapped["Tool"] = relationship( + "Tool", + back_populates="embeddings", + passive_deletes=True, + ) + + if HAS_PGVECTOR: + __table_args__ = ( + Index( + "idx_tool_embeddings_hnsw", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": getattr(settings, "hnsw_m", 16), "ef_construction": getattr(settings, "hnsw_ef_construction", 64)}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + Index("idx_tool_embeddings_toolid_model", "tool_id", "model_name"), + Index("idx_tool_embeddings_toolid_created", "tool_id", "created_at"), + ) + else: + __table_args__ = ( + Index("idx_tool_embeddings_toolid_model", "tool_id", "model_name"), + Index("idx_tool_embeddings_toolid_created", "tool_id", "created_at"), + ) + + def __repr__(self) -> str: + return f"" + + def similar_to( + self, + db: "Session", + limit: int = 10, + threshold: Optional[float] = None, + ) -> "List[tuple[ToolEmbedding, float]]": + """Find other ToolEmbeddings similar to this one. + + Uses pgvector cosine distance on PostgreSQL, numpy fallback on SQLite. + + Args: + db: Active database session. + limit: Maximum number of results. + threshold: Optional minimum similarity (0-1). + + Returns: + List of (ToolEmbedding, similarity_score) tuples, ordered by + descending similarity. Does not include self. 
+ """ + dialect_name = db.get_bind().dialect.name + + if dialect_name == "postgresql" and HAS_PGVECTOR: + distance_expr = ToolEmbedding.embedding.cosine_distance(self.embedding) + query = select(ToolEmbedding, (1 - distance_expr).label("similarity")).filter(ToolEmbedding.id != self.id) + if threshold is not None: + query = query.filter(distance_expr <= (1 - threshold)) + query = query.order_by(distance_expr.asc()).limit(limit) + rows = db.execute(query).all() + return [(te, max(0.0, min(1.0, float(sim)))) for te, sim in rows] + else: + # SQLite/non-pgvector: compute cosine similarity in Python + # First-Party + from mcpgateway.services.vector_search_service import _cosine_similarity_numpy + + all_embeddings = db.query(ToolEmbedding).filter(ToolEmbedding.id != self.id).all() + scored = [] + for te in all_embeddings: + sim = _cosine_similarity_numpy(self.embedding, te.embedding) + if threshold is not None and sim < threshold: + continue + scored.append((te, sim)) + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:limit] + + class Resource(Base): """ ORM model for a registered Resource. 
@@ -4533,6 +4655,16 @@ def metrics_summary(self) -> Dict[str, Any]: oauth_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) oauth_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + # Meta-server fields + # server_type: 'standard' (default) or 'meta' (exposes meta-tools instead of real tools) + server_type: Mapped[str] = mapped_column(String(20), nullable=False, default="standard") + # When True, underlying tools are hidden from tool listing endpoints + hide_underlying_tools: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + # JSON configuration for meta-server behavior (MetaConfig schema) + meta_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + # JSON scope rules for filtering which tools are visible (MetaToolScope schema) + meta_scope: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + # Relationship for loading team names (only active teams) # Uses default lazy loading - team name is only loaded when accessed # For list/admin views, use explicit joinedload(DbServer.email_team) for single-query loading diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 30596a4740..756620aac7 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -120,6 +120,7 @@ stop_plugin_invalidation_listener, ) from mcpgateway.plugins.framework.constants import PLUGIN_VIOLATION_CODE_MAPPING, PluginViolationCode, VALID_HTTP_STATUS_CODES +from mcpgateway.routers.meta_router import router as meta_router from mcpgateway.routers.server_well_known import router as server_well_known_router from mcpgateway.routers.well_known import router as well_known_router from mcpgateway.schemas import ( @@ -11523,6 +11524,7 @@ async def cleanup_import_statuses(max_age_hours: int = 24, user=Depends(get_curr app.include_router(metrics_router) app.include_router(tag_router) app.include_router(export_import_router) +app.include_router(meta_router) # Tool plugin bindings router 
try: diff --git a/mcpgateway/meta_server/__init__.py b/mcpgateway/meta_server/__init__.py new file mode 100644 index 0000000000..e17f4d923d --- /dev/null +++ b/mcpgateway/meta_server/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/meta_server/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Meta-Server package for MCP Gateway. + +This package implements the Virtual Meta-Server feature, which exposes a fixed set +of meta-tools (search_tools, list_tools, describe_tool, execute_tool, +get_tool_categories, get_similar_tools) instead of the underlying real tools. + +The meta-server provides: +- Unified tool discovery across federated servers +- Scope-based filtering configuration +- Configurable meta-tool behavior via MetaConfig +""" diff --git a/mcpgateway/meta_server/schemas.py b/mcpgateway/meta_server/schemas.py new file mode 100644 index 0000000000..456aedeeba --- /dev/null +++ b/mcpgateway/meta_server/schemas.py @@ -0,0 +1,881 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/meta_server/schemas.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Meta-Server Schema Definitions. + +This module defines Pydantic models for the Meta-Server feature including: +- Scope configuration for filtering tools across servers +- Meta-server configuration options +- Request/response contracts for all six meta-tools +- Server type enumeration supporting 'meta' type + +These are contract-only definitions. Business logic is NOT implemented here. 
+ +Examples: + >>> from mcpgateway.meta_server.schemas import MetaToolScope, MetaConfig + >>> scope = MetaToolScope(include_tags=["production"], exclude_servers=["legacy-server"]) + >>> scope.include_tags + ['production'] + >>> config = MetaConfig(enable_semantic_search=True, default_search_limit=25) + >>> config.default_search_limit + 25 +""" + +# Standard +from enum import Enum +from typing import Any, Dict, List, Optional + +# Third-Party +from pydantic import BaseModel, ConfigDict, Field, field_validator + +# First-Party +from mcpgateway.utils.base_models import BaseModelWithConfigDict + +# Server Type Enum + + +class ServerType(str, Enum): + """Enumeration of supported virtual server types. + + Attributes: + STANDARD: A standard virtual server that directly exposes associated tools. + META: A meta-server that exposes meta-tools for tool discovery and execution + instead of exposing underlying tools directly. + + Examples: + >>> ServerType.STANDARD.value + 'standard' + >>> ServerType.META.value + 'meta' + >>> ServerType("meta") == ServerType.META + True + """ + + STANDARD = "standard" + META = "meta" + + +# Scope Configuration + + +class MetaToolScope(BaseModelWithConfigDict): + """Scope configuration for filtering which tools are visible through a meta-server. + + This model defines the filtering rules that determine which underlying tools + are accessible via meta-tool operations. Multiple filter fields combine with + AND semantics (all conditions must match). + + Attributes: + include_tags: Only include tools with at least one of these tags. + exclude_tags: Exclude tools that have any of these tags. + include_servers: Only include tools from these server IDs. + exclude_servers: Exclude tools from these server IDs. + include_visibility: Only include tools with these visibility levels. + include_teams: Only include tools belonging to these team IDs. + name_patterns: Only include tools whose names match one of these glob patterns. 
+ + Examples: + >>> scope = MetaToolScope( + ... include_tags=["production", "stable"], + ... exclude_tags=["deprecated"], + ... include_servers=["server-1", "server-2"], + ... ) + >>> scope.include_tags + ['production', 'stable'] + >>> scope.exclude_tags + ['deprecated'] + >>> empty_scope = MetaToolScope() + >>> empty_scope.include_tags + [] + """ + + include_tags: List[str] = Field(default_factory=list, description="Only include tools with at least one of these tags") + exclude_tags: List[str] = Field(default_factory=list, description="Exclude tools that have any of these tags") + include_servers: List[str] = Field(default_factory=list, description="Only include tools from these server IDs") + exclude_servers: List[str] = Field(default_factory=list, description="Exclude tools from these server IDs") + include_visibility: List[str] = Field(default_factory=list, description="Only include tools with these visibility levels (private, team, public)") + include_teams: List[str] = Field(default_factory=list, description="Only include tools belonging to these team IDs") + name_patterns: List[str] = Field(default_factory=list, description="Only include tools whose names match one of these glob patterns") + + @field_validator("include_visibility") + @classmethod + def validate_visibility_values(cls, v: List[str]) -> List[str]: + """Validate that visibility values are valid. + + Args: + v: List of visibility values to validate. + + Returns: + Validated list of visibility values. + + Raises: + ValueError: If any visibility value is invalid. + + Examples: + >>> MetaToolScope.validate_visibility_values(["public", "team"]) + ['public', 'team'] + """ + valid_values = {"private", "team", "public"} + for value in v: + if value not in valid_values: + raise ValueError(f"Invalid visibility value '{value}'. Must be one of: {valid_values}") + return v + + +# Meta Configuration + + +class MetaConfig(BaseModelWithConfigDict): + """Configuration options for meta-server behavior. 
+ + Controls which meta-tool features are enabled and sets operational limits. + + Attributes: + enable_semantic_search: Whether semantic search is available via search_tools. + enable_categories: Whether tool categorization is available via get_tool_categories. + enable_similar_tools: Whether similar tool discovery is available via get_similar_tools. + default_search_limit: Default maximum number of results returned by search operations. + max_search_limit: Hard upper limit for search results regardless of request parameters. + include_metrics_in_search: Whether to include execution metrics in search results. + + Examples: + >>> config = MetaConfig() + >>> config.enable_semantic_search + False + >>> config.default_search_limit + 50 + >>> config.max_search_limit + 200 + >>> custom = MetaConfig(default_search_limit=10, max_search_limit=100) + >>> custom.default_search_limit + 10 + """ + + enable_semantic_search: bool = Field(False, description="Whether semantic search is available via search_tools") + enable_categories: bool = Field(False, description="Whether tool categorization is available via get_tool_categories") + enable_similar_tools: bool = Field(False, description="Whether similar tool discovery is available via get_similar_tools") + default_search_limit: int = Field(50, ge=1, le=1000, description="Default maximum number of results returned by search operations") + max_search_limit: int = Field(200, ge=1, le=10000, description="Hard upper limit for search results regardless of request parameters") + include_metrics_in_search: bool = Field(False, description="Whether to include execution metrics in search results") + + @field_validator("max_search_limit") + @classmethod + def validate_max_gte_default(cls, v: int, info: Any) -> int: + """Ensure max_search_limit is >= default_search_limit. + + Args: + v: The max_search_limit value. + info: Validation info containing other field values. + + Returns: + Validated max_search_limit value. 
+ + Raises: + ValueError: If max_search_limit is less than default_search_limit. + + Examples: + >>> MetaConfig(default_search_limit=50, max_search_limit=200) # Valid + MetaConfig(enable_semantic_search=False, enable_categories=False, enable_similar_tools=False, default_search_limit=50, max_search_limit=200, include_metrics_in_search=False) + """ + default_limit = info.data.get("default_search_limit", 50) + if v < default_limit: + raise ValueError(f"max_search_limit ({v}) must be >= default_search_limit ({default_limit})") + return v + + +# Meta-Tool Request/Response Schemas (Contracts Only) + + +class SearchToolsRequest(BaseModelWithConfigDict): + """Request schema for the search_tools meta-tool. + + Attributes: + query: Search query string for finding tools. + limit: Maximum number of results to return. + offset: Number of results to skip for pagination. + tags: Optional tag filter to narrow results. + include_metrics: Whether to include execution metrics in results. + + Examples: + >>> req = SearchToolsRequest(query="database") + >>> req.query + 'database' + >>> req.limit + 50 + """ + + query: str = Field(..., min_length=1, max_length=500, description="Search query string for finding tools") + limit: int = Field(50, ge=1, le=1000, description="Maximum number of results to return") + offset: int = Field(0, ge=0, description="Number of results to skip for pagination") + tags: List[str] = Field(default_factory=list, description="Optional tag filter to narrow results") + include_metrics: bool = Field(False, description="Whether to include execution metrics in results") + + +class ToolSummary(BaseModelWithConfigDict): + """Summary representation of a tool in meta-tool responses. + + Attributes: + name: Tool name identifier. + description: Human-readable description of the tool. + server_id: ID of the server hosting this tool. + server_name: Name of the server hosting this tool. + tags: Tags associated with the tool. 
+ input_schema: JSON Schema for the tool's input parameters. + metrics: Optional execution metrics for the tool. + + Examples: + >>> summary = ToolSummary(name="query_db", description="Run a DB query", server_id="s1", server_name="DB Server") + >>> summary.name + 'query_db' + """ + + name: str = Field(..., description="Tool name identifier") + description: Optional[str] = Field(None, description="Human-readable description of the tool") + server_id: Optional[str] = Field(None, description="ID of the server hosting this tool") + server_name: Optional[str] = Field(None, description="Name of the server hosting this tool") + tags: List[str] = Field(default_factory=list, description="Tags associated with the tool") + input_schema: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the tool's input parameters") + metrics: Optional[Dict[str, Any]] = Field(None, description="Optional execution metrics for the tool") + + +class SearchToolsResponse(BaseModelWithConfigDict): + """Response schema for the search_tools meta-tool. + + Attributes: + tools: List of matching tool summaries. + total_count: Total number of matching tools (before pagination). + query: The original query string. + has_more: Whether more results are available. + + Examples: + >>> resp = SearchToolsResponse(tools=[], total_count=0, query="test", has_more=False) + >>> resp.total_count + 0 + """ + + tools: List[ToolSummary] = Field(default_factory=list, description="List of matching tool summaries") + total_count: int = Field(0, ge=0, description="Total number of matching tools (before pagination)") + query: str = Field(..., description="The original query string") + has_more: bool = Field(False, description="Whether more results are available") + + +class ListToolsRequest(BaseModelWithConfigDict): + """Request schema for the list_tools meta-tool. + + Attributes: + limit: Maximum number of tools to return. + offset: Number of tools to skip for pagination. + tags: Optional tag filter. 
+ server_id: Optional server ID filter. + include_metrics: Whether to include execution metrics. + sort_by: Field to sort by (name, created_at, execution_count). + sort_order: Sort order (asc, desc). + include_schema: Whether to include full input/output schemas. + + Examples: + >>> req = ListToolsRequest() + >>> req.limit + 50 + >>> req.offset + 0 + >>> req.sort_by + 'created_at' + """ + + limit: int = Field(50, ge=1, le=1000, description="Maximum number of tools to return") + offset: int = Field(0, ge=0, description="Number of tools to skip for pagination") + tags: List[str] = Field(default_factory=list, description="Optional tag filter") + server_id: Optional[str] = Field(None, description="Optional server ID filter") + include_metrics: bool = Field(False, description="Whether to include execution metrics") + sort_by: str = Field("created_at", description="Field to sort by (name, created_at, execution_count)") + sort_order: str = Field("desc", description="Sort order (asc, desc)") + include_schema: bool = Field(False, description="Whether to include full input/output schemas") + + @field_validator("sort_by") + @classmethod + def validate_sort_by(cls, v: str) -> str: + """Validate sort_by field. + + Args: + v: Sort by value. + + Returns: + Validated sort_by value. + + Raises: + ValueError: If sort_by value is invalid. + """ + valid_values = {"name", "created_at", "execution_count"} + if v not in valid_values: + raise ValueError(f"Invalid sort_by value '{v}'. Must be one of: {valid_values}") + return v + + @field_validator("sort_order") + @classmethod + def validate_sort_order(cls, v: str) -> str: + """Validate sort_order field. + + Args: + v: Sort order value. + + Returns: + Validated sort_order value. + + Raises: + ValueError: If sort_order value is invalid. + """ + valid_values = {"asc", "desc"} + if v not in valid_values: + raise ValueError(f"Invalid sort_order value '{v}'. 
Must be one of: {valid_values}") + return v + + +class ListToolsResponse(BaseModelWithConfigDict): + """Response schema for the list_tools meta-tool. + + Attributes: + tools: List of tool summaries. + total_count: Total number of tools matching the filter. + has_more: Whether more results are available. + + Examples: + >>> resp = ListToolsResponse(tools=[], total_count=0, has_more=False) + >>> resp.total_count + 0 + """ + + tools: List[ToolSummary] = Field(default_factory=list, description="List of tool summaries") + total_count: int = Field(0, ge=0, description="Total number of tools matching the filter") + has_more: bool = Field(False, description="Whether more results are available") + + +class DescribeToolRequest(BaseModelWithConfigDict): + """Request schema for the describe_tool meta-tool. + + Attributes: + tool_name: The name of the tool to describe. + include_metrics: Whether to include execution metrics. + + Examples: + >>> req = DescribeToolRequest(tool_name="query_db") + >>> req.tool_name + 'query_db' + """ + + tool_name: str = Field(..., min_length=1, max_length=255, description="The name of the tool to describe") + include_metrics: bool = Field(False, description="Whether to include execution metrics") + + +class DescribeToolResponse(BaseModelWithConfigDict): + """Response schema for the describe_tool meta-tool. + + Attributes: + name: Tool name identifier. + description: Human-readable description. + input_schema: JSON Schema for the tool's input. + output_schema: JSON Schema for the tool's output. + server_id: ID of the hosting server. + server_name: Name of the hosting server. + tags: Tags associated with the tool. + metrics: Optional execution metrics. + annotations: Optional tool annotations/metadata. 
+ + Examples: + >>> resp = DescribeToolResponse(name="query_db", description="Run a DB query") + >>> resp.name + 'query_db' + """ + + name: str = Field(..., description="Tool name identifier") + description: Optional[str] = Field(None, description="Human-readable description") + input_schema: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the tool's input") + output_schema: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the tool's output") + server_id: Optional[str] = Field(None, description="ID of the hosting server") + server_name: Optional[str] = Field(None, description="Name of the hosting server") + tags: List[str] = Field(default_factory=list, description="Tags associated with the tool") + metrics: Optional[Dict[str, Any]] = Field(None, description="Optional execution metrics") + annotations: Optional[Dict[str, Any]] = Field(None, description="Optional tool annotations/metadata") + + +class ExecuteToolRequest(BaseModelWithConfigDict): + """Request schema for the execute_tool meta-tool. + + Attributes: + tool_name: The name of the tool to execute. + arguments: Arguments to pass to the tool. + + Examples: + >>> req = ExecuteToolRequest(tool_name="query_db", arguments={"sql": "SELECT 1"}) + >>> req.tool_name + 'query_db' + >>> req.arguments + {'sql': 'SELECT 1'} + """ + + tool_name: str = Field(..., min_length=1, max_length=255, description="The name of the tool to execute") + arguments: Dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the tool") + + +class ExecuteToolResponse(BaseModelWithConfigDict): + """Response schema for the execute_tool meta-tool. + + Attributes: + tool_name: Name of the tool that was executed. + success: Whether the execution was successful. + result: The execution result data. + error: Error message if execution failed. + execution_time_ms: Execution time in milliseconds. 
+ + Examples: + >>> resp = ExecuteToolResponse(tool_name="query_db", success=True, result={"rows": []}) + >>> resp.success + True + """ + + tool_name: str = Field(..., description="Name of the tool that was executed") + success: bool = Field(..., description="Whether the execution was successful") + result: Optional[Any] = Field(None, description="The execution result data") + error: Optional[str] = Field(None, description="Error message if execution failed") + execution_time_ms: Optional[float] = Field(None, ge=0, description="Execution time in milliseconds") + + +class GetToolCategoriesRequest(BaseModelWithConfigDict): + """Request schema for the get_tool_categories meta-tool. + + Attributes: + include_counts: Whether to include tool counts per category. + + Examples: + >>> req = GetToolCategoriesRequest() + >>> req.include_counts + True + """ + + include_counts: bool = Field(True, description="Whether to include tool counts per category") + + +class ToolCategory(BaseModelWithConfigDict): + """Representation of a tool category. + + Attributes: + name: Category name. + description: Category description. + tool_count: Number of tools in this category. + + Examples: + >>> cat = ToolCategory(name="database", description="Database tools", tool_count=5) + >>> cat.name + 'database' + """ + + name: str = Field(..., description="Category name") + description: Optional[str] = Field(None, description="Category description") + tool_count: int = Field(0, ge=0, description="Number of tools in this category") + + +class GetToolCategoriesResponse(BaseModelWithConfigDict): + """Response schema for the get_tool_categories meta-tool. + + Attributes: + categories: List of tool categories. + total_categories: Total number of categories. 
+ + Examples: + >>> resp = GetToolCategoriesResponse(categories=[], total_categories=0) + >>> resp.total_categories + 0 + """ + + categories: List[ToolCategory] = Field(default_factory=list, description="List of tool categories") + total_categories: int = Field(0, ge=0, description="Total number of categories") + + +class GetSimilarToolsRequest(BaseModelWithConfigDict): + """Request schema for the get_similar_tools meta-tool. + + Attributes: + tool_name: The name of the reference tool. + limit: Maximum number of similar tools to return. + + Examples: + >>> req = GetSimilarToolsRequest(tool_name="query_db", limit=5) + >>> req.tool_name + 'query_db' + """ + + tool_name: str = Field(..., min_length=1, max_length=255, description="The name of the reference tool") + limit: int = Field(10, ge=1, le=100, description="Maximum number of similar tools to return") + + +class GetSimilarToolsResponse(BaseModelWithConfigDict): + """Response schema for the get_similar_tools meta-tool. + + Attributes: + reference_tool: Name of the reference tool. + similar_tools: List of similar tool summaries with similarity scores. + total_found: Total number of similar tools found. + + Examples: + >>> resp = GetSimilarToolsResponse(reference_tool="query_db", similar_tools=[], total_found=0) + >>> resp.reference_tool + 'query_db' + """ + + reference_tool: str = Field(..., description="Name of the reference tool") + similar_tools: List[ToolSummary] = Field(default_factory=list, description="List of similar tool summaries") + total_found: int = Field(0, ge=0, description="Total number of similar tools found") + + +class ListResourcesRequest(BaseModelWithConfigDict): + """Request schema for the list_resources meta-tool. + + Attributes: + limit: Maximum number of resources to return. + offset: Number of resources to skip for pagination. + tags: Optional tag filter to narrow results. + mime_type: Optional MIME type filter (e.g. 'text/markdown'). 
+ + Examples: + >>> req = ListResourcesRequest() + >>> req.limit + 50 + """ + + limit: int = Field(50, ge=1, le=1000, description="Maximum number of resources to return") + offset: int = Field(0, ge=0, description="Number of resources to skip for pagination") + tags: List[str] = Field(default_factory=list, description="Optional tag filter to narrow results") + mime_type: Optional[str] = Field(None, description="Optional MIME type filter (e.g. 'text/markdown')") + + +class ResourceSummary(BaseModelWithConfigDict): + """Summary representation of a resource in meta-tool responses. + + Attributes: + uri: Resource URI identifier. + name: Human-readable resource name. + description: Description of the resource. + mime_type: MIME type of the resource content. + size: Size of the resource content in bytes. + tags: Tags associated with the resource. + + Examples: + >>> summary = ResourceSummary(uri="resource://example", name="Example") + >>> summary.uri + 'resource://example' + """ + + uri: str = Field(..., description="Resource URI identifier") + name: str = Field(..., description="Human-readable resource name") + description: Optional[str] = Field(None, description="Description of the resource") + mime_type: Optional[str] = Field(None, description="MIME type of the resource content") + size: Optional[int] = Field(None, description="Size of the resource content in bytes") + tags: List[str] = Field(default_factory=list, description="Tags associated with the resource") + + +class ListResourcesResponse(BaseModelWithConfigDict): + """Response schema for the list_resources meta-tool. + + Attributes: + resources: List of resource summaries. + total_count: Total number of resources matching the filter. + has_more: Whether more results are available. 
+ + Examples: + >>> resp = ListResourcesResponse(resources=[], total_count=0, has_more=False) + >>> resp.total_count + 0 + """ + + resources: List[ResourceSummary] = Field(default_factory=list, description="List of resource summaries") + total_count: int = Field(0, ge=0, description="Total number of resources matching the filter") + has_more: bool = Field(False, description="Whether more results are available") + + +class ReadResourceRequest(BaseModelWithConfigDict): + """Request schema for the read_resource meta-tool. + + Attributes: + uri: The URI of the resource to read. + + Examples: + >>> req = ReadResourceRequest(uri="resource://example/guide") + >>> req.uri + 'resource://example/guide' + """ + + uri: str = Field(..., min_length=1, max_length=767, description="The URI of the resource to read") + + +class ReadResourceResponse(BaseModelWithConfigDict): + """Response schema for the read_resource meta-tool. + + Attributes: + uri: Resource URI. + name: Resource name. + mime_type: MIME type of the content. + text: Text content of the resource (if text-based). + size: Size of the content in bytes. + + Examples: + >>> resp = ReadResourceResponse(uri="resource://ex", name="ex", text="Hello") + >>> resp.text + 'Hello' + """ + + uri: str = Field(..., description="Resource URI") + name: str = Field(..., description="Resource name") + mime_type: Optional[str] = Field(None, description="MIME type of the content") + text: Optional[str] = Field(None, description="Text content of the resource") + size: Optional[int] = Field(None, description="Size of the content in bytes") + + +class ListPromptsRequest(BaseModelWithConfigDict): + """Request schema for the list_prompts meta-tool. + + Attributes: + limit: Maximum number of prompts to return. + offset: Number of prompts to skip for pagination. + tags: Optional tag filter to narrow results. 
+ + Examples: + >>> req = ListPromptsRequest() + >>> req.limit + 50 + """ + + limit: int = Field(50, ge=1, le=1000, description="Maximum number of prompts to return") + offset: int = Field(0, ge=0, description="Number of prompts to skip for pagination") + tags: List[str] = Field(default_factory=list, description="Optional tag filter to narrow results") + + +class PromptSummary(BaseModelWithConfigDict): + """Summary representation of a prompt in meta-tool responses. + + Attributes: + name: Prompt name identifier. + description: Human-readable description of the prompt. + tags: Tags associated with the prompt. + argument_schema: JSON Schema for the prompt's arguments. + + Examples: + >>> summary = PromptSummary(name="summarize", description="Summarize text") + >>> summary.name + 'summarize' + """ + + name: str = Field(..., description="Prompt name identifier") + description: Optional[str] = Field(None, description="Human-readable description of the prompt") + tags: List[str] = Field(default_factory=list, description="Tags associated with the prompt") + argument_schema: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the prompt's arguments") + + +class ListPromptsResponse(BaseModelWithConfigDict): + """Response schema for the list_prompts meta-tool. + + Attributes: + prompts: List of prompt summaries. + total_count: Total number of prompts matching the filter. + has_more: Whether more results are available. + + Examples: + >>> resp = ListPromptsResponse(prompts=[], total_count=0, has_more=False) + >>> resp.total_count + 0 + """ + + prompts: List[PromptSummary] = Field(default_factory=list, description="List of prompt summaries") + total_count: int = Field(0, ge=0, description="Total number of prompts matching the filter") + has_more: bool = Field(False, description="Whether more results are available") + + +class GetPromptRequest(BaseModelWithConfigDict): + """Request schema for the get_prompt meta-tool. 
+ + Attributes: + name: The name of the prompt to retrieve. + arguments: Optional arguments to render the prompt template. + + Examples: + >>> req = GetPromptRequest(name="summarize") + >>> req.name + 'summarize' + """ + + name: str = Field(..., min_length=1, max_length=255, description="The name of the prompt to retrieve") + arguments: Dict[str, str] = Field(default_factory=dict, description="Optional arguments to render the prompt template") + + +class GetPromptResponse(BaseModelWithConfigDict): + """Response schema for the get_prompt meta-tool. + + Attributes: + name: Prompt name. + description: Prompt description. + template: The raw prompt template. + rendered: The rendered prompt with arguments applied (if arguments were provided). + argument_schema: JSON Schema for the prompt's arguments. + tags: Tags associated with the prompt. + + Examples: + >>> resp = GetPromptResponse(name="summarize", template="Summarize: {text}") + >>> resp.name + 'summarize' + """ + + name: str = Field(..., description="Prompt name") + description: Optional[str] = Field(None, description="Prompt description") + template: str = Field(..., description="The raw prompt template") + rendered: Optional[str] = Field(None, description="The rendered prompt with arguments applied") + argument_schema: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the prompt's arguments") + tags: List[str] = Field(default_factory=list, description="Tags associated with the prompt") + + +class AuthorizeGatewayRequest(BaseModelWithConfigDict): + """Request schema for the authorize_gateway meta-tool. + + Attributes: + gateway_name: Name or ID of the gateway to authorize. 
+ + Examples: + >>> req = AuthorizeGatewayRequest(gateway_name="github-enterprise") + >>> req.gateway_name + 'github-enterprise' + """ + + gateway_name: str = Field(..., description="Name or ID of the gateway to authorize") + + +class AuthorizeGatewayResponse(BaseModelWithConfigDict): + """Response schema for the authorize_gateway meta-tool. + + Attributes: + gateway_id: ID of the gateway. + gateway_name: Name of the gateway. + status: Authorization status (authorized, authorization_required, not_found, error). + authorize_url: URL to open in browser if authorization is required. + message: Human-readable status message. + + Examples: + >>> resp = AuthorizeGatewayResponse(gateway_id="abc", gateway_name="gh", status="authorized", message="ok") + >>> resp.status + 'authorized' + """ + + gateway_id: str = Field(..., description="ID of the gateway") + gateway_name: str = Field(..., description="Name of the gateway") + status: str = Field(..., description="Authorization status: authorized, authorization_required, not_found, error") + authorize_url: Optional[str] = Field(None, description="URL to open in browser for OAuth authorization") + message: str = Field(..., description="Human-readable status message") + + +class AuthorizeAllGatewaysRequest(BaseModelWithConfigDict): + """Request schema for the authorize_all_gateways meta-tool. + + No parameters required — the tool checks all OAuth gateways + the current user has access to. + + Examples: + >>> req = AuthorizeAllGatewaysRequest() + >>> isinstance(req, AuthorizeAllGatewaysRequest) + True + """ + + pass + + +class GatewayAuthStatus(BaseModelWithConfigDict): + """Authorization status for a single gateway. + + Attributes: + gateway_id: ID of the gateway. + gateway_name: Name of the gateway. + status: Token status (authorized or authorization_required). 
+ + Examples: + >>> gs = GatewayAuthStatus(gateway_id="abc", gateway_name="M365", status="authorized") + >>> gs.status + 'authorized' + """ + + gateway_id: str = Field(..., description="ID of the gateway") + gateway_name: str = Field(..., description="Name of the gateway") + status: str = Field(..., description="Token status: authorized or authorization_required") + + +class AuthorizeAllGatewaysResponse(BaseModelWithConfigDict): + """Response schema for the authorize_all_gateways meta-tool. + + Attributes: + status: Overall status (all_authorized, authorization_required, error). + authorize_url: Single URL to authorize all pending gateways at once. + gateways: Per-gateway authorization status. + message: Human-readable summary. + + Examples: + >>> resp = AuthorizeAllGatewaysResponse(status="all_authorized", gateways=[], message="ok") + >>> resp.status + 'all_authorized' + """ + + status: str = Field(..., description="Overall status: all_authorized, authorization_required, error") + authorize_url: Optional[str] = Field(None, description="URL to open in browser to authorize all pending gateways") + gateways: List[GatewayAuthStatus] = Field(default_factory=list, description="Per-gateway authorization status") + message: str = Field(..., description="Human-readable summary") + + +# Meta-Tool Definition Constants + +#: Registry of meta-tool names and their input schemas. +#: Used by the meta-server to register stubs and validate tool calls. 
+META_TOOL_DEFINITIONS: Dict[str, Dict[str, Any]] = { + "search_tools": { + "description": "Search for tools across all servers in scope using text or semantic matching.", + "input_schema": SearchToolsRequest.model_json_schema(), + }, + "list_tools": { + "description": "List all tools available in scope with optional filtering by tags or server.", + "input_schema": ListToolsRequest.model_json_schema(), + }, + "describe_tool": { + "description": "Get detailed information about a specific tool including its schema and metadata.", + "input_schema": DescribeToolRequest.model_json_schema(), + }, + "execute_tool": { + "description": "Execute a tool by name with the provided arguments, routing to the correct server.", + "input_schema": ExecuteToolRequest.model_json_schema(), + }, + "get_tool_categories": { + "description": "Get a list of tool categories derived from tags and server groupings.", + "input_schema": GetToolCategoriesRequest.model_json_schema(), + }, + "get_similar_tools": { + "description": "Find tools that are similar to a given tool based on description and schema similarity.", + "input_schema": GetSimilarToolsRequest.model_json_schema(), + }, + "authorize_gateway": { + "description": "Check OAuth authorization status for a gateway and provide an authorization URL if needed. Use this when a tool call fails with 'User authentication required for OAuth-protected gateway'.", + "input_schema": AuthorizeGatewayRequest.model_json_schema(), + }, + "authorize_all_gateways": { + "description": "Check OAuth authorization status for ALL gateways at once and provide a single URL to authorize all pending gateways. 
Use this proactively at the start of a session to ensure all tools are available, or when multiple gateways need authorization.", + "input_schema": AuthorizeAllGatewaysRequest.model_json_schema(), + }, + "list_resources": { + "description": "List all MCP resources (documents, guides, knowledge bases) available in scope with optional filtering by tags or MIME type.", + "input_schema": ListResourcesRequest.model_json_schema(), + }, + "read_resource": { + "description": "Read the content of an MCP resource by its URI. Returns the full text content of the resource.", + "input_schema": ReadResourceRequest.model_json_schema(), + }, + "list_prompts": { + "description": "List all MCP prompt templates available in scope with optional filtering by tags.", + "input_schema": ListPromptsRequest.model_json_schema(), + }, + "get_prompt": { + "description": "Get a prompt template by name, optionally rendering it with provided arguments.", + "input_schema": GetPromptRequest.model_json_schema(), + }, +} diff --git a/mcpgateway/meta_server/service.py b/mcpgateway/meta_server/service.py new file mode 100644 index 0000000000..f528115d94 --- /dev/null +++ b/mcpgateway/meta_server/service.py @@ -0,0 +1,1951 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/meta_server/service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Meta-Server Service. + +This module provides the service layer for Virtual Meta-Servers. 
It handles:
+- Registration of meta-tools for meta-type servers
+- Extension points for tool listing interception (hide underlying tools)
+- search_tools: semantic + keyword hybrid search with scope filtering
+- Placeholder (stub) responses for meta-tools not yet implemented
+
+Implemented meta-tools:
+- search_tools: natural language search + filters + ranking
+- get_similar_tools: "more like this tool" vector similarity search
+
+Stub meta-tools (not yet implemented):
+- describe_tool, execute_tool
+
+Examples:
+    >>> from mcpgateway.meta_server.service import MetaServerService
+    >>> service = MetaServerService()
+    >>> tools = service.get_meta_tool_definitions()
+    >>> len(tools)
+    12
+    >>> tools[0]["name"]
+    'search_tools'
+"""
+
+# Standard
+import fnmatch
+import logging
+import re
+import time
+from typing import Any, Dict, List, Optional, Set
+
+# First-Party
+from mcpgateway.db import fresh_db_session, Gateway, get_db, Tool
+from mcpgateway.meta_server.schemas import (
+    AuthorizeAllGatewaysResponse,
+    AuthorizeGatewayResponse,
+    DescribeToolResponse,
+    ExecuteToolResponse,
+    GatewayAuthStatus,
+    GetPromptResponse,
+    GetSimilarToolsResponse,
+    GetToolCategoriesResponse,
+    ListPromptsResponse,
+    ListResourcesResponse,
+    ListToolsResponse,
+    META_TOOL_DEFINITIONS,
+    MetaConfig,
+    MetaToolScope,
+    PromptSummary,
+    ReadResourceResponse,
+    ResourceSummary,
+    SearchToolsResponse,
+    ServerType,
+    ToolSummary,
+)
+from mcpgateway.schemas import ToolSearchResult
+from mcpgateway.services.semantic_search_service import get_semantic_search_service
+from mcpgateway.services.vector_search_service import VectorSearchService
+
+logger = logging.getLogger(__name__)
+
+
+class MetaServerService:
+    """Service for managing meta-server tool registration and dispatch. 
+
+    This service provides:
+    - Meta-tool definitions that replace the underlying tool listing
+    - Handlers for each meta-tool (only describe_tool/execute_tool are stubs)
+    - Extension points for future business logic integration
+
+    The service is stateless and can be used as a singleton.
+
+    Examples:
+        >>> service = MetaServerService()
+        >>> defs = service.get_meta_tool_definitions()
+        >>> isinstance(defs, list)
+        True
+        >>> len(defs) == 12
+        True
+    """
+
+    def get_meta_tool_definitions(self) -> List[Dict[str, Any]]:
+        """Return the list of meta-tool definitions for MCP tool listing.
+
+        Each definition contains the tool name, description, and input schema
+        in a format compatible with the MCP SDK's types.Tool structure.
+
+        Returns:
+            List of meta-tool definition dicts with keys: name, description, inputSchema.
+        """
+        return [{"name": name, "description": defn["description"], "inputSchema": defn["input_schema"]} for name, defn in META_TOOL_DEFINITIONS.items()]
+
+    def is_meta_server(self, server_type: Optional[str]) -> bool:
+        """Check if the given server type is a meta-server.
+
+        Args:
+            server_type: The server type string to check.
+
+        Returns:
+            True if the server type is 'meta', False otherwise.
+
+        Examples:
+            >>> service = MetaServerService()
+            >>> service.is_meta_server("meta")
+            True
+            >>> service.is_meta_server("standard")
+            False
+            >>> service.is_meta_server(None)
+            False
+        """
+        return server_type == ServerType.META.value
+
+    def should_hide_underlying_tools(self, server_type: Optional[str], hide_underlying_tools: bool = True) -> bool:
+        """Determine if underlying tools should be hidden for this server.
+
+        Extension point for future filtering logic. Currently returns True
+        when the server is a meta-server and hide_underlying_tools is enabled.
+
+        Args:
+            server_type: The server type string.
+            hide_underlying_tools: The hide_underlying_tools flag value.
+
+        Returns:
+            True if underlying tools should be hidden. 
+
+        Examples:
+            >>> service = MetaServerService()
+            >>> service.should_hide_underlying_tools("meta", True)
+            True
+            >>> service.should_hide_underlying_tools("meta", False)
+            False
+            >>> service.should_hide_underlying_tools("standard", True)
+            False
+        """
+        return self.is_meta_server(server_type) and hide_underlying_tools
+
+    def is_meta_tool(self, tool_name: str) -> bool:
+        """Check if a tool name is a registered meta-tool.
+
+        Args:
+            tool_name: The tool name to check.
+
+        Returns:
+            True if the tool name matches a meta-tool.
+
+        Examples:
+            >>> service = MetaServerService()
+            >>> service.is_meta_tool("search_tools")
+            True
+            >>> service.is_meta_tool("some_real_tool")
+            False
+        """
+        return tool_name in META_TOOL_DEFINITIONS
+
+    async def handle_meta_tool_call(
+        self,
+        tool_name: str,
+        arguments: Dict[str, Any],
+        user_email: Optional[str] = None,
+        token_teams: Optional[List[str]] = None,
+        request_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """Dispatch a meta-tool call to the appropriate handler.
+
+        This is the main entry point for meta-tool invocations. Implemented
+        meta-tools run their real logic; only the remaining stub tools
+        (describe_tool, execute_tool) return placeholder responses.
+
+        MCP clients send arguments using camelCase keys (from the JSON schema),
+        but handlers expect snake_case keys. This method normalizes keys before
+        dispatching to ensure both conventions work.
+
+        Args:
+            tool_name: Name of the meta-tool to invoke.
+            arguments: Arguments for the tool call (may use camelCase or snake_case keys).
+            user_email: Email of the authenticated user (for OAuth token retrieval).
+            token_teams: Team IDs from JWT token.
+            request_headers: Headers from the original request.
+
+        Returns:
+            Dict containing the handler's response.
+
+        Raises:
+            ValueError: If the tool_name is not a recognized meta-tool.
+ + Examples: + >>> import asyncio + >>> service = MetaServerService() + >>> result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "test"})) + >>> result["query"] + 'test' + """ + handlers = { + "search_tools": self._search_tools, + "list_tools": self._list_tools, + "describe_tool": self._stub_describe_tool, + "execute_tool": self._stub_execute_tool, + "get_tool_categories": self._get_tool_categories, + "get_similar_tools": self._get_similar_tools, + "authorize_gateway": self._authorize_gateway, + "authorize_all_gateways": self._authorize_all_gateways, + "list_resources": self._list_resources, + "read_resource": self._read_resource, + "list_prompts": self._list_prompts, + "get_prompt": self._get_prompt, + } + + handler = handlers.get(tool_name) + if handler is None: + raise ValueError(f"Unknown meta-tool: {tool_name}") + + # Normalize camelCase keys to snake_case so handlers can use consistent key names. + # MCP JSON schemas expose camelCase (e.g. "toolName") but handlers use snake_case + # (e.g. "tool_name"). Accept both conventions by normalizing before dispatch. + normalized_args = self._normalize_arguments(arguments) + + logger.info(f"Handling meta-tool call: {tool_name}") + return await handler( + normalized_args, + user_email=user_email, + token_teams=token_teams, + request_headers=request_headers, + ) + + @staticmethod + def _normalize_arguments(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Normalize argument keys from camelCase to snake_case. + + MCP JSON schemas use camelCase aliases (e.g. ``toolName``), but + internal handlers expect snake_case (e.g. ``tool_name``). This + method converts all top-level keys so that both conventions are + supported transparently. Keys that are already snake_case pass + through unchanged. + + Args: + arguments: Raw arguments dict from the MCP request. + + Returns: + A new dict with all top-level keys converted to snake_case. 
+ + Examples: + >>> MetaServerService._normalize_arguments({"toolName": "x", "limit": 5}) + {'tool_name': 'x', 'limit': 5} + """ + _camel_re = re.compile(r"(?<=[a-z0-9])([A-Z])") + + def _to_snake(key: str) -> str: + return _camel_re.sub(r"_\1", key).lower() + + return {_to_snake(k): v for k, v in arguments.items()} + + @staticmethod + def _extract_user_context(kwargs: Dict[str, Any]) -> tuple: + """Extract access-control parameters from handler kwargs. + + ``handle_meta_tool_call`` passes ``user_email``, ``token_teams``, + and ``request_headers`` through to every handler. This helper + provides a single extraction point so handlers don't repeat the + pattern. + + Returns: + Tuple of (user_email, token_teams, request_headers). + """ + return ( + kwargs.get("user_email"), + kwargs.get("token_teams"), + kwargs.get("request_headers"), + ) + + @staticmethod + def _resolve_effective_email( + user_email: Optional[str], + request_headers: Optional[Dict[str, str]], + ) -> Optional[str]: + """Resolve effective user email from explicit param or JWT in Authorization header. + + Prefers the explicit ``user_email`` parameter. Falls back to extracting + the email claim from the Bearer JWT in the Authorization header. + The JWT signature is NOT verified here — upstream middleware is responsible + for authentication. This is only used for user identity resolution. 
+ """ + if user_email: + return user_email.strip().lower() if isinstance(user_email, str) else None + if not request_headers: + return None + auth_header = request_headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + return None + try: + import jwt as pyjwt # pylint: disable=import-outside-toplevel + token = auth_header[7:] + payload = pyjwt.decode(token, options={"verify_signature": False}) + email = payload.get("email") or payload.get("sub") + return email.strip().lower() if email else None + except Exception: + return None + + # ------------------------------------------------------------------ + # Implemented handlers + # ------------------------------------------------------------------ + + async def _search_tools(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Search for tools using hybrid semantic + keyword search with scope filtering. + + Performs a hybrid search: + 1. Semantic search via embedding service + vector search + 2. Keyword fallback via basic name/description matching + 3. Merges and deduplicates results + 4. Normalizes ranking scores into a stable 0-1 range + 5. Applies scope filtering (last gate) + 6. Returns paginated response + + Args: + arguments: Search parameters dict with keys: + - query (str): Natural language search query (required) + - limit (int): Max results to return (default 50) + - offset (int): Pagination offset (default 0) + - tags (List[str]): Optional tag filter + - include_metrics (bool): Whether to include execution metrics + + Returns: + SearchToolsResponse as dict with ranked, scoped results. 
+ """ + # -- Parse request params -- + query = arguments.get("query", "") + limit = arguments.get("limit", 50) + offset = arguments.get("offset", 0) + tags = arguments.get("tags", []) + include_metrics = arguments.get("include_metrics", False) + + # -- Extract user context for access control -- + user_email, token_teams, _ = self._extract_user_context(kwargs) + + # -- Step 1: Semantic search -- + semantic_results = [] + try: + semantic_service = get_semantic_search_service() + semantic_results = await semantic_service.search_tools(query=query, limit=limit) + except Exception as e: + logger.error(f"Semantic search failed: {e}") + # Proceed with empty semantic results as fallback + + # -- Step 2: Keyword fallback search -- + keyword_results = [] + try: + from mcpgateway.services.tool_service import ToolService as _KwToolService # pylint: disable=import-outside-toplevel + _kw_ts = _KwToolService() + + db_gen = get_db() + db = next(db_gen) + try: + # Use ToolService.list_tools for consistent access control. + # Fetch ALL tools (limit=0) so keyword matching covers the + # full catalog; pagination is applied after scoring. 
+ kw_result = await _kw_ts.list_tools( + db=db, + include_inactive=False, + limit=0, + user_email=user_email, + token_teams=token_teams, + ) + kw_tools_list, _ = kw_result if isinstance(kw_result, tuple) else (kw_result, None) + + query_lower = query.lower() + # Tokenize query: split on whitespace, hyphens, underscores + import re as _re # pylint: disable=import-outside-toplevel + tokens = [t for t in _re.split(r'[\s\-_]+', query_lower) if len(t) >= 2] + if not tokens: + tokens = [query_lower] + + for tool in kw_tools_list: + tool_name = getattr(tool, "name", "") + tool_desc = getattr(tool, "description", "") or "" + name_lower = tool_name.lower() + desc_lower = tool_desc.lower() + # Also tokenize tool name for token-level matching + name_tokens = set(_re.split(r'[\s\-_/.]+', name_lower)) + + # Count how many query tokens match (name or description) + name_hits = 0 + desc_hits = 0 + for token in tokens: + # Exact token match in name tokens or substring in full name + if token in name_tokens or token in name_lower: + name_hits += 1 + elif desc_lower and token in desc_lower: + desc_hits += 1 + + total_hits = name_hits + desc_hits + if total_hits == 0: + continue + + # Score: ratio of matched tokens, with name hits weighted higher + hit_ratio = total_hits / len(tokens) + name_ratio = name_hits / len(tokens) + + if name_lower == query_lower: + score = 1.0 + elif hit_ratio == 1.0 and name_ratio >= 0.5: + # All tokens match, majority in name + score = 0.95 + elif hit_ratio == 1.0: + # All tokens match but mostly in description + score = 0.85 + elif name_ratio >= 0.5: + # At least half the tokens match in name + score = 0.6 + (hit_ratio * 0.2) + else: + # Partial match + score = 0.3 + (hit_ratio * 0.3) + + keyword_results.append( + ToolSearchResult( + tool_name=tool_name, + description=tool_desc, + server_id=getattr(tool, "gateway_id", None), + server_name=None, + similarity_score=round(score, 3), + ) + ) + finally: + try: + next(db_gen) + except StopIteration: + pass + 
except Exception as e: + logger.warning(f"Keyword search failed: {e}") + + # -- Step 3: Merge and deduplicate results -- + # Combine semantic + keyword results, dedupe by tool_name, + # keeping the higher score when duplicates are found. + merged: Dict[str, ToolSearchResult] = {} + for result in semantic_results + keyword_results: + existing = merged.get(result.tool_name) + if existing is None or result.similarity_score > existing.similarity_score: + merged[result.tool_name] = result + + # -- Step 4: Normalize ranking scores and sort -- + # Scores from both sources are already in 0-1 range. + # Sort descending by score. + ranked_results = sorted(merged.values(), key=lambda r: r.similarity_score, reverse=True) + + # -- Step 5: Apply scope filtering (must be last gate) -- + # Enrich results with tool metadata from DB for scope fields + # that aren't available on ToolSearchResult (tags, visibility, team_id). + filtered_results = self._apply_scope_filtering(ranked_results, arguments.get("scope")) + + # -- Step 6: Apply tag filter from request args -- + if tags: + filtered_results = [r for r in filtered_results if r.tool_name in self._get_tools_matching_tags(tags)] + + # -- Step 7: Paginate -- + total_count = len(filtered_results) + paginated = filtered_results[offset : offset + limit] + has_more = total_count > offset + limit + + # -- Step 8: Map results to ToolSummary objects -- + tool_summaries = self._map_to_tool_summaries(paginated, include_metrics) + + return SearchToolsResponse( + tools=tool_summaries, + total_count=total_count, + query=query, + has_more=has_more, + ).model_dump(by_alias=True) + + async def _get_similar_tools(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Find tools similar to a given reference tool using vector similarity. + + Performs a "more like this" search: + 1. Resolves the reference tool by name from the database + 2. Retrieves the tool's stored embedding vector + 3. 
Queries the vector search service for nearest neighbors + 4. Filters out the reference tool itself from results + 5. Applies scope filtering (last gate) + 6. Returns similarity scores with optional reason strings + + Args: + arguments: Similarity query parameters dict with keys: + - tool_name (str): Name of the reference tool (required) + - limit (int): Max similar tools to return (default 10) + + Returns: + GetSimilarToolsResponse as dict with similar tools and scores. + """ + # -- Parse request params -- + tool_name = arguments.get("tool_name", "") + limit = arguments.get("limit", 10) + + # -- Extract user context for access control -- + user_email, token_teams, _ = self._extract_user_context(kwargs) + + if not tool_name: + return GetSimilarToolsResponse( + reference_tool=tool_name, + similar_tools=[], + total_found=0, + ).model_dump(by_alias=True) + + # -- Step 1: Resolve reference tool from the database -- + reference_tool = None + try: + db_gen = get_db() + db = next(db_gen) + try: + reference_tool = ( + db.query(Tool) + .filter( + Tool._computed_name == tool_name, + Tool.enabled.is_(True), + ) + .first() + ) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Failed to look up reference tool '{tool_name}': {e}") + + if reference_tool is None: + logger.info(f"Reference tool '{tool_name}' not found, returning empty results") + return GetSimilarToolsResponse( + reference_tool=tool_name, + similar_tools=[], + total_found=0, + ).model_dump(by_alias=True) + + # -- Step 2: Retrieve the tool's stored embedding -- + embedding_vector = None + try: + db_gen = get_db() + db = next(db_gen) + try: + vector_service = VectorSearchService(db=db) + tool_embedding = vector_service.get_tool_embedding(db, reference_tool.id) + if tool_embedding is not None: + embedding_vector = tool_embedding.embedding + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Failed to 
retrieve embedding for tool '{tool_name}': {e}") + + if embedding_vector is None: + logger.info(f"No embedding found for tool '{tool_name}', falling back to keyword similarity") + # -- Keyword-based similarity fallback -- + # Use the reference tool's name tokens and description keywords + # to find tools with overlapping vocabulary. + similar_results = await self._keyword_similar_tools( + reference_tool, tool_name, limit, user_email, token_teams, + ) + + # Apply scope filtering + filtered_results = self._apply_scope_filtering(similar_results, arguments.get("scope")) + tool_summaries = self._map_to_tool_summaries(filtered_results) + + return GetSimilarToolsResponse( + reference_tool=tool_name, + similar_tools=tool_summaries, + total_found=len(tool_summaries), + ).model_dump(by_alias=True) + + # -- Step 3: Query vector search for nearest neighbors -- + similar_results: List[ToolSearchResult] = [] + try: + db_gen = get_db() + db = next(db_gen) + try: + vector_service = VectorSearchService(db=db) + # Request extra results so we still have enough after filtering out self + similar_results = await vector_service.search_similar_tools( + embedding=embedding_vector, + limit=limit + 1, + db=db, + ) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Vector search for similar tools failed: {e}") + + # -- Step 4: Filter out the reference tool itself -- + similar_results = [r for r in similar_results if r.tool_name != tool_name][:limit] + + # -- Step 4.5: Apply access-control filtering -- + # Build set of tool names the user can access, then discard the rest. 
+ if user_email is not None or token_teams is not None: + try: + from mcpgateway.services.tool_service import ToolService as _AcToolService # pylint: disable=import-outside-toplevel + + _ac_ts = _AcToolService() + db_gen = get_db() + db = next(db_gen) + try: + ac_result = await _ac_ts.list_tools( + db=db, + include_inactive=False, + limit=0, + user_email=user_email, + token_teams=token_teams, + ) + ac_tools_list, _ = ac_result if isinstance(ac_result, tuple) else (ac_result, None) + accessible_names = {getattr(t, "name", "") for t in ac_tools_list} + similar_results = [r for r in similar_results if r.tool_name in accessible_names] + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Access control filtering failed for similar tools: {e}") + + # -- Step 5: Apply scope filtering -- + filtered_results = self._apply_scope_filtering(similar_results, arguments.get("scope")) + + # -- Step 6: Map results to ToolSummary objects -- + tool_summaries = self._map_to_tool_summaries(filtered_results) + + return GetSimilarToolsResponse( + reference_tool=tool_name, + similar_tools=tool_summaries, + total_found=len(tool_summaries), + ).model_dump(by_alias=True) + + # ------------------------------------------------------------------ + # Helper methods for search and scope filtering + # ------------------------------------------------------------------ + + async def _keyword_similar_tools( + self, + reference_tool: Any, + tool_name: str, + limit: int, + user_email: Optional[str], + token_teams: Optional[List[str]], + ) -> List[ToolSearchResult]: + """Find similar tools using keyword overlap when embeddings are unavailable. + + Tokenizes the reference tool's name and description, then scores all + other accessible tools by token overlap. Tools from the same gateway + get a small boost since they belong to the same MCP server. + + Args: + reference_tool: The DB Tool object used as reference. 
+ tool_name: Computed name of the reference tool. + limit: Max results to return. + user_email: User email for access control. + token_teams: Token team IDs for access control. + + Returns: + List of ToolSearchResult sorted by similarity score (descending). + """ + ref_desc = (getattr(reference_tool, "description", "") or getattr(reference_tool, "original_description", "") or "").lower() + ref_name = tool_name.lower() + ref_gateway_id = getattr(reference_tool, "gateway_id", None) + + # Build token set from name + description + _re = re # module-level import already available + ref_tokens = set(_re.split(r'[\s\-_/.]+', ref_name)) + ref_tokens |= {w for w in _re.split(r'[\s\-_/.,:;()]+', ref_desc) if len(w) >= 3} + ref_tokens.discard("") + + if not ref_tokens: + return [] + + # Fetch all accessible tools + from mcpgateway.services.tool_service import ToolService as _SimToolService # pylint: disable=import-outside-toplevel + + _sim_ts = _SimToolService() + try: + db_gen = get_db() + db = next(db_gen) + try: + result = await _sim_ts.list_tools( + db=db, include_inactive=False, limit=0, + user_email=user_email, token_teams=token_teams, + ) + all_tools, _ = result if isinstance(result, tuple) else (result, None) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Keyword similar tools: failed to list tools: {e}") + return [] + + scored: List[ToolSearchResult] = [] + for tool in all_tools: + t_name = getattr(tool, "name", "") + if t_name == tool_name: + continue # skip self + + t_name_lower = t_name.lower() + t_desc = (getattr(tool, "description", "") or "").lower() + + t_tokens = set(_re.split(r'[\s\-_/.]+', t_name_lower)) + t_tokens |= {w for w in _re.split(r'[\s\-_/.,:;()]+', t_desc) if len(w) >= 3} + t_tokens.discard("") + + if not t_tokens: + continue + + overlap = ref_tokens & t_tokens + if not overlap: + continue + + # Jaccard-like score + score = len(overlap) / len(ref_tokens | t_tokens) + + # Boost tools 
from the same gateway (same MCP server) + t_gateway = getattr(tool, "gateway_id", None) + if ref_gateway_id and t_gateway == ref_gateway_id: + score = min(score + 0.1, 1.0) + + scored.append( + ToolSearchResult( + tool_name=t_name, + description=getattr(tool, "description", "") or "", + server_id=t_gateway, + server_name=None, + similarity_score=round(score, 3), + ) + ) + + # Sort descending by score, take top N + scored.sort(key=lambda r: r.similarity_score, reverse=True) + return scored[:limit] + + def _apply_scope_filtering( + self, + results: List[ToolSearchResult], + scope_dict: Optional[Dict[str, Any]] = None, + ) -> List[ToolSearchResult]: + """Apply MetaToolScope filtering rules to search results. + + Scope filters combine with AND semantics — a tool must pass ALL + active filters to be included. This is the last gate before + pagination and should always be applied. + + Args: + results: Ranked search results to filter. + scope_dict: Optional scope configuration dict (MetaToolScope fields). + + Returns: + Filtered list of ToolSearchResult objects. 
+ """ + if not scope_dict or not results: + return results + + scope = MetaToolScope(**scope_dict) + + # Batch-fetch tool metadata for fields not on ToolSearchResult + tool_names = [r.tool_name for r in results] + metadata = self._get_tool_metadata(tool_names) + + filtered: List[ToolSearchResult] = [] + for result in results: + meta = metadata.get(result.tool_name) + if meta is None: + # Tool not found in DB — exclude from scoped results + continue + + tool_tags = meta.get("tags", []) + tool_visibility = meta.get("visibility", "public") + tool_team_id = meta.get("team_id") + tool_server_id = result.server_id + tool_name = result.tool_name + + # include_tags: tool must have at least one matching tag + if scope.include_tags and not any(t in scope.include_tags for t in tool_tags): + continue + + # exclude_tags: tool must NOT have any excluded tag + if scope.exclude_tags and any(t in scope.exclude_tags for t in tool_tags): + continue + + # include_servers: tool must be from one of these servers + if scope.include_servers and tool_server_id not in scope.include_servers: + continue + + # exclude_servers: tool must NOT be from excluded servers + if scope.exclude_servers and tool_server_id in scope.exclude_servers: + continue + + # include_visibility: tool must have one of these visibility levels + if scope.include_visibility and tool_visibility not in scope.include_visibility: + continue + + # include_teams: tool must belong to one of these teams + if scope.include_teams and tool_team_id not in scope.include_teams: + continue + + # name_patterns: tool name must match at least one glob pattern + if scope.name_patterns and not any(fnmatch.fnmatch(tool_name, pat) for pat in scope.name_patterns): + continue + + filtered.append(result) + + return filtered + + def _get_tool_metadata(self, tool_names: List[str]) -> Dict[str, Dict[str, Any]]: + """Batch-fetch tool metadata from the database for scope filtering. + + Retrieves tags, visibility, and team_id for the given tool names. 
+ When a tool has no tags of its own, it inherits tags from its + parent Gateway (MCP server). + + Args: + tool_names: List of tool names to look up. + + Returns: + Dict mapping tool_name -> {tags, visibility, team_id, input_schema}. + """ + if not tool_names: + return {} + + metadata: Dict[str, Dict[str, Any]] = {} + try: + from sqlalchemy.orm import joinedload as _jl # pylint: disable=import-outside-toplevel + + db_gen = get_db() + db = next(db_gen) + try: + tools = ( + db.query(Tool) + .options(_jl(Tool.gateway)) + .filter(Tool._computed_name.in_(tool_names)) + .all() + ) + for tool in tools: + # Extract tag strings from tag objects + tags_list = tool.tags or [] + if tags_list and isinstance(tags_list[0], dict): + tags_list = [tag.get("id") or tag.get("label") for tag in tags_list if isinstance(tag, dict)] + + # Inherit tags from parent gateway when the tool has none + if not tags_list and tool.gateway_id and tool.gateway: + gw_tags = tool.gateway.tags or [] + if gw_tags and isinstance(gw_tags[0], dict): + tags_list = [t.get("id") or t.get("label") for t in gw_tags if isinstance(t, dict)] + elif gw_tags: + tags_list = list(gw_tags) + + metadata[tool.name] = { + "tags": tags_list, + "visibility": tool.visibility or "public", + "team_id": tool.team_id, + "input_schema": tool.input_schema, + } + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Failed to fetch tool metadata for scope filtering: {e}") + + return metadata + + def _get_tools_matching_tags(self, tags: List[str]) -> Set[str]: + """Get the set of tool names that have at least one of the given tags. + + Args: + tags: Tag values to match against. + + Returns: + Set of tool names that match. 
+ """ + matching: Set[str] = set() + try: + db_gen = get_db() + db = next(db_gen) + try: + tools = db.query(Tool).filter(Tool.enabled.is_(True)).all() + for tool in tools: + tool_tags = tool.tags or [] + if any(t in tags for t in tool_tags): + matching.add(tool.name) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.warning(f"Failed to fetch tools for tag filtering: {e}") + + return matching + + def _map_to_tool_summaries( + self, + results: List[ToolSearchResult], + include_metrics: bool = False, + ) -> List[ToolSummary]: + """Map ToolSearchResult objects to ToolSummary objects. + + Enriches results with additional metadata (tags, input_schema) + from the database. + + Args: + results: Search results to map. + include_metrics: Whether to include execution metrics. + + Returns: + List of ToolSummary objects. + """ + if not results: + return [] + + tool_names = [r.tool_name for r in results] + metadata = self._get_tool_metadata(tool_names) + + summaries: List[ToolSummary] = [] + for result in results: + meta = metadata.get(result.tool_name, {}) + summaries.append( + ToolSummary( + name=result.tool_name, + description=result.description, + server_id=result.server_id, + server_name=result.server_name, + tags=meta.get("tags", []), + input_schema=meta.get("input_schema"), + metrics=None, # TODO: populate from ToolMetric if include_metrics is True + ) + ) + + return summaries + + # ------------------------------------------------------------------ + # Implemented handlers + # ------------------------------------------------------------------ + + async def _list_tools(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """List tools with pagination, sorting, and scope filtering. + + Performs paginated tool listing: + 1. Queries tools from database using ToolService + 2. Applies scope filtering (last gate) + 3. Supports sorting by name, created_at, or execution_count + 4. 
Returns paginated response with metadata + + Args: + arguments: List parameters dict with keys: + - limit (int): Max results to return (default 50) + - offset (int): Pagination offset (default 0) + - tags (List[str]): Optional tag filter + - server_id (str): Optional server ID filter + - include_metrics (bool): Whether to include execution metrics + - sort_by (str): Field to sort by (name, created_at, execution_count) + - sort_order (str): Sort order (asc, desc) + - include_schema (bool): Whether to include input/output schemas + + Returns: + ListToolsResponse as dict with tools, total_count, and pagination metadata. + """ + # -- Parse request params -- + limit = arguments.get("limit", 50) + offset = arguments.get("offset", 0) + tags = arguments.get("tags", []) + server_id = arguments.get("server_id") + include_metrics = arguments.get("include_metrics", False) + sort_by = arguments.get("sort_by", "created_at") + sort_order = arguments.get("sort_order", "desc") + include_schema = arguments.get("include_schema", False) + + # -- Extract user context for access control -- + user_email, token_teams, _ = self._extract_user_context(kwargs) + + # -- Step 1: Query tools from database using ToolService -- + # First-Party + from mcpgateway.services.tool_service import ToolService + + tool_service = ToolService() + all_tools = [] + + try: + db_gen = get_db() + db = next(db_gen) + try: + # Query with offset+limit+1 to determine has_more + query_limit = limit + offset + 1 + + # Call ToolService.list_tools with appropriate parameters + result = await tool_service.list_tools( + db=db, + include_inactive=False, + tags=tags if tags else None, + gateway_id=server_id, + limit=query_limit, + user_email=user_email, + token_teams=token_teams, + ) + + # Extract tools from result (could be tuple or dict) + if isinstance(result, tuple): + all_tools, _ = result + elif isinstance(result, dict): + all_tools = result.get("data", []) + else: + all_tools = result + + finally: + try: + 
next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Failed to query tools from database: {e}") + # Return empty result on error + return ListToolsResponse( + tools=[], + total_count=0, + has_more=False, + ).model_dump(by_alias=True) + + # -- Step 2: Convert to tool search results for scope filtering -- + # (Scope filtering expects ToolSearchResult objects) + # First-Party + from mcpgateway.schemas import ToolSearchResult + + search_results = [] + # Keep a created_at lookup for sorting + _created_at_map: Dict[str, Any] = {} + for tool in all_tools: + # ToolService.list_tools returns ToolRead objects (Pydantic) + # which have gateway_id but not gateway relationship + server_id_val = getattr(tool, "gateway_id", None) + tool_name_val = tool.name + + search_results.append( + ToolSearchResult( + tool_name=tool_name_val, + description=getattr(tool, "description", "") or "", + server_id=server_id_val, + server_name=None, + similarity_score=1.0, # Not relevant for listing + ) + ) + _created_at_map[tool_name_val] = getattr(tool, "created_at", None) + + # -- Step 3: Apply scope filtering (must be last gate) -- + filtered_results = self._apply_scope_filtering(search_results, arguments.get("scope")) + + # -- Step 4: Sort results -- + _reverse = sort_order == "desc" + if sort_by == "name": + filtered_results.sort(key=lambda r: r.tool_name.lower(), reverse=_reverse) + elif sort_by == "created_at": + # Sort by created_at from the original tool objects + _epoch = None # sentinel for tools missing created_at + filtered_results.sort( + key=lambda r: _created_at_map.get(r.tool_name) or _epoch, + reverse=_reverse, + ) + # else: keep original DB order (default) + + # -- Step 5: Paginate -- + total_count = len(filtered_results) + paginated = filtered_results[offset : offset + limit] + has_more = total_count > offset + limit + + # -- Step 5: Map results to ToolSummary objects -- + tool_summaries = self._map_to_tool_summaries(paginated, include_metrics) + 
+ return ListToolsResponse( + tools=tool_summaries, + total_count=total_count, + has_more=has_more, + ).model_dump(by_alias=True) + + # ------------------------------------------------------------------ + # Stub handlers — return placeholder responses + # Business logic will be implemented by other teams. + # ------------------------------------------------------------------ + + async def _stub_describe_tool(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Delegate to MetaToolService for describe_tool implementation. + + Args: + arguments: Describe parameters. + + Returns: + DescribeToolResponse as dict from MetaToolService. + """ + # First-Party + from mcpgateway.services.meta_tool_service import MetaToolService + + user_email, token_teams, _ = self._extract_user_context(kwargs) + + try: + db_gen = get_db() + db = next(db_gen) + try: + service = MetaToolService(db) + result = await service.describe_tool( + tool_name=arguments.get("tool_name", ""), + include_metrics=arguments.get("include_metrics", False), + scope=arguments.get("scope"), + user_email=user_email, + token_teams=token_teams, + ) + return result.model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Error delegating describe_tool to MetaToolService: {e}") + tool_name = arguments.get("tool_name", "unknown") + return DescribeToolResponse( + name=tool_name, + description=f"Error describing tool {tool_name}: {str(e)}", + ).model_dump(by_alias=True) + + async def _stub_execute_tool( + self, + arguments: Dict[str, Any], + user_email: Optional[str] = None, + token_teams: Optional[List[str]] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """Delegate to MetaToolService for execute_tool implementation. + + Args: + arguments: Execution parameters. + user_email: Email of the authenticated user (for OAuth token retrieval). + token_teams: Team IDs from JWT token. 
+ request_headers: Headers from the original request. + + Returns: + ExecuteToolResponse as dict from MetaToolService. + """ + # First-Party + from mcpgateway.services.meta_tool_service import MetaToolService + + try: + db_gen = get_db() + db = next(db_gen) + try: + service = MetaToolService(db) + tool_name = arguments.get("tool_name", "") + tool_arguments = arguments.get("arguments", {}) + + # Tolerate flat argument layout: some clients (e.g. Copilot Studio) + # send tool arguments at the same level as tool_name instead of + # nesting them inside "arguments". Detect this by collecting any + # keys that are not part of the execute_tool schema itself. + if not tool_arguments: + _meta_keys = {"tool_name", "arguments", "scope"} + extra = {k: v for k, v in arguments.items() if k not in _meta_keys} + if extra: + tool_arguments = extra + logger.info( + "execute_tool: restructured flat arguments into nested format " + f"for tool '{tool_name}': {list(extra.keys())}" + ) + + result = await service.execute_tool( + tool_name=tool_name, + arguments=tool_arguments, + scope=arguments.get("scope"), + user_email=user_email, + token_teams=token_teams, + request_headers=request_headers, + ) + return result.model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Error delegating execute_tool to MetaToolService: {e}") + tool_name = arguments.get("tool_name", "unknown") + return ExecuteToolResponse( + tool_name=tool_name, + success=False, + result=None, + error=f"Error executing tool {tool_name}: {str(e)}", + ).model_dump(by_alias=True) + + async def _get_tool_categories(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Get aggregated tool categories with counts. + + Builds categories from tool tags, inheriting tags from parent + gateways when a tool has no tags of its own. + + Args: + arguments: Category query parameters (include_counts). 
+ + Returns: + GetToolCategoriesResponse as dict with categories and counts. + """ + # First-Party + from collections import Counter + + from sqlalchemy.orm import joinedload as _jl # pylint: disable=import-outside-toplevel + + from mcpgateway.meta_server.schemas import ToolCategory + + include_counts = arguments.get("include_counts", True) + + try: + db_gen = get_db() + db = next(db_gen) + try: + tools = ( + db.query(Tool) + .options(_jl(Tool.gateway)) + .filter(Tool.enabled.is_(True)) + .all() + ) + + tag_counter: Counter = Counter() + for tool in tools: + # Resolve tags — same inheritance logic as _get_tool_metadata + tags_list = tool.tags or [] + if tags_list and isinstance(tags_list[0], dict): + tags_list = [ + tag.get("id") or tag.get("label") + for tag in tags_list + if isinstance(tag, dict) + ] + + # Inherit from parent gateway when the tool has no own tags + if not tags_list and tool.gateway_id and tool.gateway: + gw_tags = tool.gateway.tags or [] + if gw_tags and isinstance(gw_tags[0], dict): + tags_list = [ + t.get("id") or t.get("label") + for t in gw_tags + if isinstance(t, dict) + ] + elif gw_tags: + tags_list = list(gw_tags) + + for tag in tags_list: + if tag: + tag_counter[tag] += 1 + + # Build sorted categories + categories_list = [ + ToolCategory( + name=tag_name, + description=None, + tool_count=count if include_counts else 0, + ) + for tag_name, count in sorted(tag_counter.items()) + ] + + return GetToolCategoriesResponse( + categories=categories_list, + total_categories=len(categories_list), + ).model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Error getting tool categories: {e}") + return GetToolCategoriesResponse( + categories=[], + total_categories=0, + ).model_dump(by_alias=True) + + async def _authorize_gateway( + self, + arguments: Dict[str, Any], + user_email: Optional[str] = None, + token_teams: Optional[List[str]] = None, + request_headers: 
Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Check OAuth authorization status for a gateway and return an authorize URL if needed. + + Args: + arguments: Must contain gateway_name (name or ID of the gateway). + user_email: Email of the authenticated user. + token_teams: Team IDs from JWT token. + request_headers: Headers from the original request. + + Returns: + AuthorizeGatewayResponse as dict with status and optional authorize_url. + """ + # First-Party + from mcpgateway.config import get_settings + from mcpgateway.db import Gateway + from mcpgateway.services.token_storage_service import TokenStorageService + + # Resolve user_email: prefer explicit param, fall back to JWT in request_headers + effective_email = self._resolve_effective_email(user_email, request_headers) + + gateway_name = arguments.get("gateway_name", "") + if not gateway_name: + return AuthorizeGatewayResponse( + gateway_id="", + gateway_name="", + status="error", + message="gateway_name is required", + ).model_dump(by_alias=True) + + try: + db_gen = get_db() + db = next(db_gen) + try: + from sqlalchemy import or_, select # pylint: disable=import-outside-toplevel + + # Find gateway by name or ID + gateway = db.execute( + select(Gateway).where( + or_(Gateway.name == gateway_name, Gateway.id == gateway_name) + ) + ).scalar_one_or_none() + + if not gateway: + return AuthorizeGatewayResponse( + gateway_id="", + gateway_name=gateway_name, + status="not_found", + message=f"Gateway '{gateway_name}' not found", + ).model_dump(by_alias=True) + + gateway_id = gateway.id + + # Check if gateway has OAuth config + if not gateway.oauth_config: + return AuthorizeGatewayResponse( + gateway_id=gateway_id, + gateway_name=gateway.name, + status="authorized", + message="Gateway does not require OAuth authorization", + ).model_dump(by_alias=True) + + # Check if user already has a valid token (attempt refresh if expired) + if effective_email: + token_service = TokenStorageService(db) + 
# get_user_token attempts automatic refresh via refresh_token + valid_token = await token_service.get_user_token(gateway_id, effective_email) + if valid_token: + token_info = await token_service.get_token_info(gateway_id, effective_email) + expires_at = token_info.get('expires_at', 'unknown') if token_info else 'unknown' + return AuthorizeGatewayResponse( + gateway_id=gateway_id, + gateway_name=gateway.name, + status="authorized", + message=f"You already have a valid OAuth token for '{gateway.name}' (expires {expires_at})", + ).model_dump(by_alias=True) + + # Build the authorize URL + settings = get_settings() + app_domain = str(settings.app_domain or "").rstrip("/") + root_path = str(settings.app_root_path or "").strip("/") + base = f"{app_domain}/{root_path}" if root_path else app_domain + authorize_url = f"{base}/oauth/authorize/{gateway_id}" + + return AuthorizeGatewayResponse( + gateway_id=gateway_id, + gateway_name=gateway.name, + status="authorization_required", + authorize_url=authorize_url, + message=f"OAuth authorization required for '{gateway.name}'. [Click here to authorize]({authorize_url})", + ).model_dump(by_alias=True) + + finally: + try: + next(db_gen) + except StopIteration: + pass + + except Exception as e: + logger.error(f"Error in authorize_gateway: {e}") + return AuthorizeGatewayResponse( + gateway_id="", + gateway_name=gateway_name, + status="error", + message=f"Error checking gateway authorization: {str(e)}", + ).model_dump(by_alias=True) + + async def _authorize_all_gateways( + self, + arguments: Dict[str, Any], + user_email: Optional[str] = None, + token_teams: Optional[List[str]] = None, + request_headers: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Check OAuth authorization status for all gateways and return a single authorize-all URL. + + Args: + arguments: No required arguments. + user_email: Email of the authenticated user. + token_teams: Team IDs from JWT token. 
+ request_headers: Headers from the original request. + + Returns: + AuthorizeAllGatewaysResponse as dict with status and optional authorize_url. + """ + # First-Party + from mcpgateway.config import get_settings + from mcpgateway.db import Gateway + from mcpgateway.services.token_storage_service import TokenStorageService + + # Resolve user_email from JWT if not provided + effective_email = self._resolve_effective_email(user_email, request_headers) + + try: + db_gen = get_db() + db = next(db_gen) + try: + from sqlalchemy import select # pylint: disable=import-outside-toplevel + + # Find all active OAuth gateways with authorization_code flow + gateways = db.execute( + select(Gateway).where( + Gateway.auth_type == "oauth", + Gateway.enabled.is_(True), + ) + ).scalars().all() + + token_service = TokenStorageService(db) + gateway_statuses = [] + pending_count = 0 + + for gw in gateways: + if not gw.oauth_config or gw.oauth_config.get("grant_type") != "authorization_code": + continue + + gw_status = "authorization_required" + if effective_email: + # get_user_token attempts automatic refresh via refresh_token + valid_token = await token_service.get_user_token(gw.id, effective_email) + if valid_token: + gw_status = "authorized" + + if gw_status == "authorization_required": + pending_count += 1 + + gateway_statuses.append(GatewayAuthStatus( + gateway_id=gw.id, + gateway_name=gw.name, + status=gw_status, + )) + + if not gateway_statuses: + return AuthorizeAllGatewaysResponse( + status="all_authorized", + gateways=[], + message="No OAuth gateways found.", + ).model_dump(by_alias=True) + + if pending_count == 0: + names = ", ".join(gs.gateway_name for gs in gateway_statuses) + return AuthorizeAllGatewaysResponse( + status="all_authorized", + gateways=[gs.model_dump(by_alias=True) for gs in gateway_statuses], + message=f"All {len(gateway_statuses)} OAuth gateways are authorized: {names}", + ).model_dump(by_alias=True) + + # Build authorize-all URL + settings = get_settings() 
                app_domain = str(settings.app_domain or "").rstrip("/")
                root_path = str(settings.app_root_path or "").strip("/")
                # Compose the public base URL, honoring an optional application
                # root path (sub-path deployment behind a proxy).
                base = f"{app_domain}/{root_path}" if root_path else app_domain
                authorize_url = f"{base}/oauth/authorize-all"

                pending_names = ", ".join(
                    gs.gateway_name for gs in gateway_statuses
                    if gs.status == "authorization_required"
                )

                return AuthorizeAllGatewaysResponse(
                    status="authorization_required",
                    authorize_url=authorize_url,
                    gateways=[gs.model_dump(by_alias=True) for gs in gateway_statuses],
                    message=f"{pending_count} gateway(s) need authorization: {pending_names}. [Click here to authorize all at once]({authorize_url})",
                ).model_dump(by_alias=True)

            finally:
                # Advance the get_db() dependency generator to run its cleanup.
                try:
                    next(db_gen)
                except StopIteration:
                    pass

        except Exception as e:
            logger.error(f"Error in authorize_all_gateways: {e}")
            return AuthorizeAllGatewaysResponse(
                status="error",
                gateways=[],
                message=f"Error checking gateway authorization: {str(e)}",
            ).model_dump(by_alias=True)

    # ------------------------------------------------------------------
    # Resource and Prompt handlers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_tags(raw_tags: Any) -> List[str]:
        """Normalize tags from DB format to plain strings.

        Tags may be stored as dicts {'id': ..., 'label': ...} or plain strings.

        Args:
            raw_tags: Raw tag collection from the DB (list of dicts and/or
                strings), or a falsy value when the record has no tags.

        Returns:
            List of plain tag strings; empty list when raw_tags is falsy.
        """
        if not raw_tags:
            return []
        result: List[str] = []
        for tag in raw_tags:
            if isinstance(tag, dict):
                result.append(tag.get("id") or tag.get("label") or str(tag))
            else:
                result.append(str(tag))
        return result

    async def _list_resources(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        """List MCP resources with pagination and optional filtering.
+ + Args: + arguments: List parameters dict with keys: + - limit (int): Max results to return (default 50) + - offset (int): Pagination offset (default 0) + - tags (List[str]): Optional tag filter + - mime_type (str): Optional MIME type filter + + Returns: + ListResourcesResponse as dict. + """ + from mcpgateway.db import Resource # pylint: disable=import-outside-toplevel + from mcpgateway.services.resource_service import ResourceService as _RsService # pylint: disable=import-outside-toplevel + + limit = arguments.get("limit", 50) + offset = arguments.get("offset", 0) + tags = arguments.get("tags", []) + mime_type = arguments.get("mime_type") + + # Extract user context for access control + user_email, token_teams, _ = self._extract_user_context(kwargs) + + try: + db_gen = get_db() + db = next(db_gen) + try: + query = db.query(Resource).filter(Resource.enabled.is_(True)) + + # Apply access control + _rs = _RsService() + query = await _rs._apply_access_control(query, db, user_email, token_teams) + + if mime_type: + query = query.filter(Resource.mime_type == mime_type) + + all_resources = query.order_by(Resource.created_at.desc()).all() + + # Apply tag filtering in Python (tags stored as JSON) + if tags: + all_resources = [ + r for r in all_resources + if r.tags and any(t in self._normalize_tags(r.tags) for t in tags) + ] + + total_count = len(all_resources) + paginated = all_resources[offset: offset + limit] + has_more = total_count > offset + limit + + summaries = [ + ResourceSummary( + uri=r.uri, + name=r.name, + description=r.description, + mime_type=r.mime_type, + size=r.size, + tags=self._normalize_tags(r.tags), + ) + for r in paginated + ] + + return ListResourcesResponse( + resources=summaries, + total_count=total_count, + has_more=has_more, + ).model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Error listing resources: {e}") + return ListResourcesResponse( + resources=[], + 
total_count=0, + has_more=False, + ).model_dump(by_alias=True) + + async def _read_resource(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Read the content of an MCP resource by URI. + + Args: + arguments: Must contain uri (str) of the resource to read. + + Returns: + ReadResourceResponse as dict with content. + """ + from mcpgateway.db import Resource # pylint: disable=import-outside-toplevel + from mcpgateway.services.observability_service import ObservabilityService, current_trace_id # pylint: disable=import-outside-toplevel + from mcpgateway.services.resource_service import ResourceService as _RsService # pylint: disable=import-outside-toplevel + + uri = arguments.get("uri", "") + if not uri: + return ReadResourceResponse( + uri="", + name="", + text="Error: uri is required", + ).model_dump(by_alias=True) + + # Extract user context for access control + user_email, token_teams, _ = self._extract_user_context(kwargs) + + start_time = time.monotonic() + success = False + error_message = None + trace_id = current_trace_id.get() + db_span_id = None + observability_service = ObservabilityService() if trace_id else None + + # Start observability span + if trace_id and observability_service: + try: + with fresh_db_session() as span_db: + db_span_id = observability_service.start_span( + db=span_db, + trace_id=trace_id, + name="resource.read", + attributes={ + "resource.uri": uri, + "user": kwargs.get("user_email", "anonymous"), + }, + commit=False, + ) + logger.debug(f"✓ Created resource.read span: {db_span_id} for resource: {uri}") + except Exception as e: + logger.warning(f"Failed to start observability span for resource read: {e}") + db_span_id = None + + try: + db_gen = get_db() + db = next(db_gen) + try: + resource = ( + db.query(Resource) + .filter(Resource.uri == uri, Resource.enabled.is_(True)) + .first() + ) + + if resource is None: + error_message = f"Resource not found: {uri}" + return ReadResourceResponse( + uri=uri, + name="", + 
text=error_message, + ).model_dump(by_alias=True) + + # Check access control + _rs = _RsService() + if not await _rs._check_resource_access(db, resource, user_email, token_teams): + error_message = f"Resource not found: {uri}" + return ReadResourceResponse( + uri=uri, + name="", + text=error_message, + ).model_dump(by_alias=True) + + text_content = resource.text_content + if text_content is None and resource.binary_content is not None: + text_content = "(binary content — not displayable as text)" + + success = True + return ReadResourceResponse( + uri=resource.uri, + name=resource.name, + mime_type=resource.mime_type, + text=text_content, + size=resource.size, + ).model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + error_message = str(e) + logger.error(f"Error reading resource '{uri}': {e}") + return ReadResourceResponse( + uri=uri, + name="", + text=f"Error reading resource: {str(e)}", + ).model_dump(by_alias=True) + finally: + # End observability span + if db_span_id and observability_service: + try: + with fresh_db_session() as span_db: + observability_service.end_span( + db=span_db, + span_id=db_span_id, + status="ok" if success else "error", + status_message=error_message, + attributes={ + "duration_ms": (time.monotonic() - start_time) * 1000, + }, + commit=False, + ) + logger.debug(f"✓ Ended resource.read span: {db_span_id}") + except Exception as e: + logger.warning(f"Failed to end observability span for resource read: {e}") + + async def _list_prompts(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """List MCP prompts with pagination and optional filtering. + + Args: + arguments: List parameters dict with keys: + - limit (int): Max results to return (default 50) + - offset (int): Pagination offset (default 0) + - tags (List[str]): Optional tag filter + + Returns: + ListPromptsResponse as dict. 
+ """ + from mcpgateway.db import Prompt # pylint: disable=import-outside-toplevel + from mcpgateway.services.prompt_service import PromptService as _PsService # pylint: disable=import-outside-toplevel + + limit = arguments.get("limit", 50) + offset = arguments.get("offset", 0) + tags = arguments.get("tags", []) + + # Extract user context for access control + user_email, token_teams, _ = self._extract_user_context(kwargs) + + try: + db_gen = get_db() + db = next(db_gen) + try: + query = db.query(Prompt).filter(Prompt.enabled.is_(True)) + + # Apply access control + _ps = _PsService() + query = await _ps._apply_access_control(query, db, user_email, token_teams) + + all_prompts = query.order_by(Prompt.created_at.desc()).all() + + # Apply tag filtering in Python (tags stored as JSON) + if tags: + all_prompts = [ + p for p in all_prompts + if p.tags and any(t in self._normalize_tags(p.tags) for t in tags) + ] + + total_count = len(all_prompts) + paginated = all_prompts[offset: offset + limit] + has_more = total_count > offset + limit + + summaries = [ + PromptSummary( + name=p.name, + description=p.description, + tags=self._normalize_tags(p.tags), + argument_schema=p.argument_schema, + ) + for p in paginated + ] + + return ListPromptsResponse( + prompts=summaries, + total_count=total_count, + has_more=has_more, + ).model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + logger.error(f"Error listing prompts: {e}") + return ListPromptsResponse( + prompts=[], + total_count=0, + has_more=False, + ).model_dump(by_alias=True) + + async def _get_prompt(self, arguments: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + """Get a prompt template by name with optional rendering. + + Args: + arguments: Must contain name (str). Optional arguments (dict) for rendering. + + Returns: + GetPromptResponse as dict with template and optionally rendered content. 
+ """ + from mcpgateway.db import Prompt # pylint: disable=import-outside-toplevel + from mcpgateway.services.observability_service import ObservabilityService, current_trace_id # pylint: disable=import-outside-toplevel + from mcpgateway.services.prompt_service import PromptService as _PsService # pylint: disable=import-outside-toplevel + + name = arguments.get("name", "") + prompt_args = arguments.get("arguments", {}) + + if not name: + return GetPromptResponse( + name="", + template="", + description="Error: name is required", + ).model_dump(by_alias=True) + + # Extract user context for access control + user_email, token_teams, _ = self._extract_user_context(kwargs) + + start_time = time.monotonic() + success = False + error_message = None + trace_id = current_trace_id.get() + db_span_id = None + observability_service = ObservabilityService() if trace_id else None + + # Start observability span + if trace_id and observability_service: + try: + with fresh_db_session() as span_db: + db_span_id = observability_service.start_span( + db=span_db, + trace_id=trace_id, + name="prompt.render", + attributes={ + "prompt.id": name, + "arguments_count": len(prompt_args) if prompt_args else 0, + "user": kwargs.get("user_email", "anonymous"), + }, + commit=False, + ) + logger.debug(f"✓ Created prompt.render span: {db_span_id} for prompt: {name}") + except Exception as e: + logger.warning(f"Failed to start observability span for prompt render: {e}") + db_span_id = None + + try: + db_gen = get_db() + db = next(db_gen) + try: + prompt = ( + db.query(Prompt) + .filter(Prompt.name == name, Prompt.enabled.is_(True)) + .first() + ) + + if prompt is None: + error_message = f"Prompt not found: {name}" + return GetPromptResponse( + name=name, + template="", + description=error_message, + ).model_dump(by_alias=True) + + # Check access control + _ps = _PsService() + if not await _ps._check_prompt_access(db, prompt, user_email, token_teams): + error_message = f"Prompt not found: {name}" + 
return GetPromptResponse( + name=name, + template="", + description=error_message, + ).model_dump(by_alias=True) + + rendered = None + if prompt_args: + try: + prompt.validate_arguments(prompt_args) + rendered = prompt.template.format(**prompt_args) + except (ValueError, KeyError) as e: + rendered = f"Error rendering prompt: {str(e)}" + + success = True + return GetPromptResponse( + name=prompt.name, + description=prompt.description, + template=prompt.template, + rendered=rendered, + argument_schema=prompt.argument_schema, + tags=self._normalize_tags(prompt.tags), + ).model_dump(by_alias=True) + finally: + try: + next(db_gen) + except StopIteration: + pass + except Exception as e: + error_message = str(e) + logger.error(f"Error getting prompt '{name}': {e}") + return GetPromptResponse( + name=name, + template="", + description=f"Error getting prompt: {str(e)}", + ).model_dump(by_alias=True) + finally: + # End observability span + if db_span_id and observability_service: + try: + with fresh_db_session() as span_db: + observability_service.end_span( + db=span_db, + span_id=db_span_id, + status="ok" if success else "error", + status_message=error_message, + attributes={ + "duration_ms": (time.monotonic() - start_time) * 1000, + }, + commit=False, + ) + logger.debug(f"✓ Ended prompt.render span: {db_span_id}") + except Exception as e: + logger.warning(f"Failed to end observability span for prompt render: {e}") + + +# Module-level singleton +_meta_server_service: Optional[MetaServerService] = None + + +def get_meta_server_service() -> MetaServerService: + """Get or create the MetaServerService singleton. + + Returns: + MetaServerService instance. 
+ + Examples: + >>> service = get_meta_server_service() + >>> isinstance(service, MetaServerService) + True + """ + global _meta_server_service # pylint: disable=global-statement + if _meta_server_service is None: + _meta_server_service = MetaServerService() + return _meta_server_service diff --git a/mcpgateway/middleware/rbac.py b/mcpgateway/middleware/rbac.py index 5e412f29f0..cca02f517b 100644 --- a/mcpgateway/middleware/rbac.py +++ b/mcpgateway/middleware/rbac.py @@ -17,6 +17,7 @@ from functools import wraps import logging from typing import Any, Callable, Generator, List, Optional +from urllib.parse import quote import uuid import warnings @@ -39,6 +40,26 @@ set_trace_user_email, set_trace_user_is_admin, ) + + +def _login_url_with_next(request: Request) -> str: + """Build login redirect URL, preserving the original request path as ?next= parameter. + + This ensures that after SSO login, the user is redirected back to the page + they were trying to access (e.g. /oauth/authorize/{gateway_id}). + + Args: + request: The incoming request whose path should be preserved. + + Returns: + Login URL with optional ?next= parameter. 
+ """ + login_url = f"{settings.app_root_path}/admin/login" + request_path = request.scope.get("path", "/") + # Don't add ?next= for the login page itself (avoid redirect loops) or root paths + if request_path and request_path != "/" and not request_path.rstrip("/").endswith("/admin/login"): + login_url = f"{login_url}?next={quote(request_path, safe='')}" + return login_url from mcpgateway.utils.verify_credentials import is_proxy_auth_trust_active logger = logging.getLogger(__name__) @@ -268,7 +289,7 @@ def _set_trace_context_for_identity(*, email: Optional[str], is_admin: bool, aut raise HTTPException( status_code=status.HTTP_302_FOUND, detail="Authentication required", - headers={"Location": f"{settings.app_root_path}/admin/login"}, + headers={"Location": _login_url_with_next(request)}, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -302,7 +323,7 @@ def _set_trace_context_for_identity(*, email: Optional[str], is_admin: bool, aut raise HTTPException( status_code=status.HTTP_302_FOUND, detail="Authentication required", - headers={"Location": f"{settings.app_root_path}/admin/login"}, + headers={"Location": _login_url_with_next(request)}, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -365,7 +386,7 @@ def _set_trace_context_for_identity(*, email: Optional[str], is_admin: bool, aut if not token: # For browser requests (HTML Accept header or HTMX), redirect to login if is_browser_request: - raise HTTPException(status_code=status.HTTP_302_FOUND, detail="Authentication required", headers={"Location": f"{settings.app_root_path}/admin/login"}) + raise HTTPException(status_code=status.HTTP_302_FOUND, detail="Authentication required", headers={"Location": _login_url_with_next(request)}) # AUTH_REQUIRED=false no longer implies admin access. # Preserve explicit unsafe override for local-only compatibility. 
@@ -448,7 +469,7 @@ def _set_trace_context_for_identity(*, email: Optional[str], is_admin: bool, aut accept_header = request.headers.get("accept", "") is_htmx = request.headers.get("hx-request") == "true" if "text/html" in accept_header or is_htmx: - raise HTTPException(status_code=status.HTTP_302_FOUND, detail="Authentication required", headers={"Location": f"{settings.app_root_path}/admin/login"}) + raise HTTPException(status_code=status.HTTP_302_FOUND, detail="Authentication required", headers={"Location": _login_url_with_next(request)}) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials") diff --git a/mcpgateway/observability.py b/mcpgateway/observability.py index cd7f3f7d38..9a9095d627 100644 --- a/mcpgateway/observability.py +++ b/mcpgateway/observability.py @@ -113,7 +113,6 @@ class _ConsoleSpanExporterStub: # pragma: no cover - test patch replaces this logging.getLogger(__name__).debug("Skipping OpenTelemetry shim setup: %s", exc) # First-Party -from mcpgateway import __version__ # noqa: E402 # pylint: disable=wrong-import-position from mcpgateway.config import get_settings # noqa: E402 # pylint: disable=wrong-import-position from mcpgateway.utils.correlation_id import get_correlation_id # noqa: E402 # pylint: disable=wrong-import-position from mcpgateway.utils.log_sanitizer import sanitize_for_log # noqa: E402 # pylint: disable=wrong-import-position @@ -844,7 +843,7 @@ def init_telemetry() -> Optional[Any]: # Create resource attributes resource_attributes: Dict[str, Any] = { "service.name": cfg.otel_service_name, - "service.version": __version__, + "service.version": "1.0.0-RC-2", "deployment.environment": _get_deployment_environment(), } @@ -989,7 +988,7 @@ def on_end(self, span): # Get tracer # Obtain a tracer if trace API available; otherwise create a no-op tracer if trace is not None and hasattr(trace, "get_tracer"): - _TRACER = cast(Any, trace).get_tracer("mcp-gateway", __version__, 
schema_url="https://opentelemetry.io/schemas/1.11.0") + _TRACER = cast(Any, trace).get_tracer("mcp-gateway", "1.0.0-RC-2", schema_url="https://opentelemetry.io/schemas/1.11.0") else: class _NoopTracer: diff --git a/mcpgateway/routers/meta_router.py b/mcpgateway/routers/meta_router.py new file mode 100644 index 0000000000..06c256dc79 --- /dev/null +++ b/mcpgateway/routers/meta_router.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/routers/meta_router.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Meta-Tool Router. +This module provides FastAPI routes for meta-tools (describe_tool, execute_tool). + +Examples: + >>> from fastapi import FastAPI + >>> from mcpgateway.routers.meta_router import router + >>> app = FastAPI() + >>> app.include_router(router, prefix="/meta", tags=["Meta Tools"]) + >>> isinstance(router, APIRouter) + True +""" + +# Standard +import json +import time +from typing import Any, Dict, Optional + +# Third-Party +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from sqlalchemy.orm import Session + +# First-Party +from mcpgateway.db import get_db +from mcpgateway.meta_server.schemas import ( + DescribeToolRequest, + DescribeToolResponse, + ExecuteToolRequest, + ExecuteToolResponse, + GetPromptRequest, + GetPromptResponse, + GetSimilarToolsRequest, + GetSimilarToolsResponse, + GetToolCategoriesRequest, + GetToolCategoriesResponse, + ListPromptsRequest, + ListPromptsResponse, + ListResourcesRequest, + ListResourcesResponse, + ListToolsRequest, + ListToolsResponse, + ReadResourceRequest, + ReadResourceResponse, + SearchToolsRequest, + SearchToolsResponse, +) +from mcpgateway.meta_server.service import get_meta_server_service +from mcpgateway.middleware.rbac import get_current_user_with_permissions +from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.meta_tool_service import MetaToolService + +# Initialize logging +logging_service = LoggingService() 
+logger = logging_service.get_logger(__name__) + +# Create router +router = APIRouter(prefix="/meta", tags=["Meta Tools"]) + + +def _apply_scope_header(arguments: Dict[str, Any], x_scope: Optional[str]) -> Dict[str, Any]: + """Apply X-Scope header to arguments dict, parsing JSON if provided.""" + if x_scope: + try: + arguments["scope"] = json.loads(x_scope) + except (json.JSONDecodeError, ValueError): + logger.warning(f"Invalid X-Scope header: {x_scope}") + return arguments + + +@router.post("/describe_tool", response_model=DescribeToolResponse) +async def describe_tool( + req: DescribeToolRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), + x_scope: Optional[str] = Header(None, alias="X-Scope"), +) -> DescribeToolResponse: + """Get detailed information about a specific tool including schema and metadata. + + Args: + req: Describe tool request + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + x_scope: Optional scope header for filtering + + Returns: + DescribeToolResponse: Tool details + + Raises: + HTTPException: If tool is not found or access is denied + """ + try: + service = MetaToolService(db) + user_email = current_user_ctx.get("email") + token_teams = current_user_ctx.get("teams") + is_admin = current_user_ctx.get("is_admin", False) + + response = await service.describe_tool( + tool_name=req.tool_name, + include_metrics=req.include_metrics, + user_email=user_email, + token_teams=token_teams, + is_admin=is_admin, + scope=x_scope, + ) + return response + except ValueError as e: + logger.warning(f"Tool not found or access denied: {req.tool_name} - {e}") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + except Exception as e: + logger.error(f"Error describing tool {req.tool_name}: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server 
error") + + +@router.post("/execute_tool", response_model=ExecuteToolResponse) +async def execute_tool( + req: ExecuteToolRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), + x_scope: Optional[str] = Header(None, alias="X-Scope"), +) -> ExecuteToolResponse: + """Execute a tool by name with the provided arguments. + + Validates input against the tool's JSON schema and routes execution to + the correct backend server. + + Args: + req: Execute tool request + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + x_scope: Optional scope header for filtering + + Returns: + ExecuteToolResponse: Execution result with metadata + + Raises: + HTTPException: If tool is not found, validation fails, or execution fails + """ + start_time = time.time() + + try: + service = MetaToolService(db) + user_email = current_user_ctx.get("email") + token_teams = current_user_ctx.get("teams") + is_admin = current_user_ctx.get("is_admin", False) + + # Extract headers for forwarding + request_headers = dict(request.headers) + + response = await service.execute_tool( + tool_name=req.tool_name, + arguments=req.arguments, + user_email=user_email, + token_teams=token_teams, + is_admin=is_admin, + scope=x_scope, + request_headers=request_headers, + ) + + # Add execution time + execution_time_ms = int((time.time() - start_time) * 1000) + response.execution_time_ms = execution_time_ms + + return response + except ValueError as e: + logger.warning(f"Validation error for tool {req.tool_name}: {e}") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except PermissionError as e: + logger.warning(f"Access denied for tool {req.tool_name}: {e}") + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) + except Exception as e: + logger.error(f"Error executing tool {req.tool_name}: {e}") + raise 
HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/search_tools", response_model=SearchToolsResponse) +async def search_tools( + req: SearchToolsRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), + x_scope: Optional[str] = Header(None, alias="X-Scope"), +) -> SearchToolsResponse: + """Search for tools using hybrid semantic and keyword search. + + Performs semantic search via embeddings + keyword fallback with scope filtering. + + Args: + req: Search request parameters + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + x_scope: Optional scope header for filtering + + Returns: + SearchToolsResponse: Ranked search results + + Raises: + HTTPException: If search fails + """ + try: + meta_service = get_meta_server_service() + + # Build arguments dict with scope from header if provided + arguments = _apply_scope_header(req.model_dump(), x_scope) + + # Call the service handler + result = await meta_service._search_tools(arguments) + return SearchToolsResponse(**result) + except Exception as e: + logger.error(f"Error searching tools: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/list_tools", response_model=ListToolsResponse) +async def list_tools( + req: ListToolsRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), + x_scope: Optional[str] = Header(None, alias="X-Scope"), +) -> ListToolsResponse: + """List tools with pagination, sorting, and filtering. 
+ + Args: + req: List request parameters + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + x_scope: Optional scope header for filtering + + Returns: + ListToolsResponse: Paginated tool list + + Raises: + HTTPException: If listing fails + """ + try: + meta_service = get_meta_server_service() + + # Build arguments dict with scope from header if provided + arguments = _apply_scope_header(req.model_dump(), x_scope) + + # Call the service handler + result = await meta_service._list_tools(arguments) + return ListToolsResponse(**result) + except Exception as e: + logger.error(f"Error listing tools: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/get_similar_tools", response_model=GetSimilarToolsResponse) +async def get_similar_tools( + req: GetSimilarToolsRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), + x_scope: Optional[str] = Header(None, alias="X-Scope"), +) -> GetSimilarToolsResponse: + """Find tools similar to a reference tool using vector similarity. 
+ + Args: + req: Similarity search request + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + x_scope: Optional scope header for filtering + + Returns: + GetSimilarToolsResponse: Similar tools with scores + + Raises: + HTTPException: If similarity search fails + """ + try: + meta_service = get_meta_server_service() + + # Build arguments dict with scope from header if provided + arguments = _apply_scope_header(req.model_dump(), x_scope) + + # Call the service handler + result = await meta_service._get_similar_tools(arguments) + return GetSimilarToolsResponse(**result) + except Exception as e: + logger.error(f"Error finding similar tools: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/get_tool_categories", response_model=GetToolCategoriesResponse) +async def get_tool_categories( + req: GetToolCategoriesRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> GetToolCategoriesResponse: + """Get aggregated tool categories with counts. 
+ + Args: + req: Category request parameters + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + + Returns: + GetToolCategoriesResponse: Categories with tool counts + + Raises: + HTTPException: If category aggregation fails + """ + try: + meta_service = get_meta_server_service() + + # Call the service handler + arguments = req.model_dump() + result = await meta_service._get_tool_categories(arguments) + return GetToolCategoriesResponse(**result) + except Exception as e: + logger.error(f"Error getting tool categories: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/list_resources", response_model=ListResourcesResponse) +async def list_resources( + req: ListResourcesRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> ListResourcesResponse: + """List MCP resources with optional filtering by tags or MIME type. 
+ + Args: + req: List resources request parameters + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + + Returns: + ListResourcesResponse: Paginated resource list + + Raises: + HTTPException: If listing fails + """ + try: + meta_service = get_meta_server_service() + arguments = req.model_dump() + result = await meta_service._list_resources(arguments) + return ListResourcesResponse(**result) + except Exception as e: + logger.error(f"Error listing resources: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/read_resource", response_model=ReadResourceResponse) +async def read_resource( + req: ReadResourceRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> ReadResourceResponse: + """Read the content of an MCP resource by URI. + + Args: + req: Read resource request with URI + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + + Returns: + ReadResourceResponse: Resource content + + Raises: + HTTPException: If reading fails + """ + try: + meta_service = get_meta_server_service() + arguments = req.model_dump() + result = await meta_service._read_resource(arguments) + return ReadResourceResponse(**result) + except Exception as e: + logger.error(f"Error reading resource: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/list_prompts", response_model=ListPromptsResponse) +async def list_prompts( + req: ListPromptsRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> ListPromptsResponse: + """List MCP prompt templates with optional filtering by tags. 
+ + Args: + req: List prompts request parameters + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + + Returns: + ListPromptsResponse: Paginated prompt list + + Raises: + HTTPException: If listing fails + """ + try: + meta_service = get_meta_server_service() + arguments = req.model_dump() + result = await meta_service._list_prompts(arguments) + return ListPromptsResponse(**result) + except Exception as e: + logger.error(f"Error listing prompts: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") + + +@router.post("/get_prompt", response_model=GetPromptResponse) +async def get_prompt( + req: GetPromptRequest, + request: Request, + current_user_ctx: dict = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +) -> GetPromptResponse: + """Get a prompt template by name, optionally rendering it with arguments. + + Args: + req: Get prompt request with name and optional arguments + request: FastAPI request object + current_user_ctx: Current user context with permissions + db: Database session + + Returns: + GetPromptResponse: Prompt template and optionally rendered content + + Raises: + HTTPException: If prompt retrieval fails + """ + try: + meta_service = get_meta_server_service() + arguments = req.model_dump() + result = await meta_service._get_prompt(arguments) + return GetPromptResponse(**result) + except Exception as e: + logger.error(f"Error getting prompt: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error") diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index 95323b152e..5a47745ebe 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -567,6 +567,36 @@ def _invalid_state_response() -> HTMLResponse: logger.info(f"Completed OAuth flow for gateway 
{SecurityValidator.sanitize_log_message(gateway_id)}, user {SecurityValidator.sanitize_log_message(str(result.get('user_id')))}") + # Check for chained authorization flow (from /oauth/authorize-all) + oauth_chain = request.cookies.get("oauth_chain") if request else None + if oauth_chain: + chain_ids = [gid.strip() for gid in oauth_chain.split(",") if gid.strip()] + if chain_ids: + next_gw_id = chain_ids[0] + remaining = chain_ids[1:] + response = RedirectResponse(url=f"{safe_root_path}/oauth/authorize/{next_gw_id}", status_code=302) + use_secure = (settings.environment == "production") or settings.secure_cookies + if remaining: + response.set_cookie( + key="oauth_chain", + value=",".join(remaining), + max_age=600, + httponly=True, + secure=use_secure, + samesite=settings.cookie_samesite, + path=settings.app_root_path or "/", + ) + else: + response.delete_cookie( + "oauth_chain", + path=settings.app_root_path or "/", + secure=use_secure, + httponly=True, + samesite=settings.cookie_samesite, + ) + logger.info(f"OAuth chain: authorized {SecurityValidator.sanitize_log_message(gateway_id)}, continuing to next gateway") + return response + # Return success page with option to return to admin return HTMLResponse(content=f""" @@ -727,6 +757,97 @@ def _invalid_state_response() -> HTMLResponse: ) +@oauth_router.get("/authorize-all", response_model=None) +async def authorize_all_gateways( + request: Request, + current_user: EmailUserResponse = Depends(get_current_user_with_permissions), + db: Session = Depends(get_db), +): + """Authorize all OAuth gateways the user has access to in a single flow. + + This endpoint chains OAuth authorization flows for all gateways that + require authorization_code grant and where the user does not yet have + a valid token. After authorizing the first gateway, the callback handler + continues to the next one automatically via the ``oauth_chain`` cookie. + + Args: + request: The FastAPI request object. + current_user: The authenticated user. 
+ db: The database session. + + Returns: + RedirectResponse to the first pending gateway, or an HTML page if + all gateways are already authorized. + """ + requester_email = _extract_user_email(current_user) + if not requester_email: + raise HTTPException(status_code=401, detail="User authentication required") + + root_path = request.scope.get("root_path", "") if request else "" + safe_root_path = escape(str(root_path), quote=True) + + # Find all OAuth gateways with authorization_code flow + gateways = db.execute( + select(Gateway).where( + Gateway.auth_type == "oauth", + Gateway.enabled.is_(True), + ) + ).scalars().all() + + token_service = TokenStorageService(db) + pending_gateway_ids = [] + already_authorized = [] + + for gw in gateways: + if not gw.oauth_config or gw.oauth_config.get("grant_type") != "authorization_code": + continue + # Check access — skip gateways the user can't reach + try: + await _enforce_gateway_access(gw.id, gw, current_user, db, request=request) + except HTTPException: + continue + # Check existing token + token_info = await token_service.get_token_info(gw.id, requester_email) + if token_info and not token_info.get("is_expired", True): + already_authorized.append(gw.name) + else: + pending_gateway_ids.append(gw.id) + + if not pending_gateway_ids: + gw_list = "".join(f"
<li>{escape(n)}</li>" for n in already_authorized) if already_authorized else "<li>No OAuth gateways found</li>" + return HTMLResponse( + content=f"""<html><head><title>All Gateways Authorized</title></head> + <body style="font-family: sans-serif; max-width: 600px; margin: 2em auto; text-align: center;"> + <h1>✅ All Gateways Already Authorized</h1> + <ul style="list-style: none; padding: 0;">{gw_list}</ul> + <p>You can close this tab and return to your AI agent.</p>
    + Admin Panel + """, + ) + + # Start chain: redirect to first gateway, store remaining in cookie + first_gw_id = pending_gateway_ids[0] + remaining = pending_gateway_ids[1:] + + response = RedirectResponse(url=f"{root_path}/oauth/authorize/{first_gw_id}", status_code=302) + + if remaining: + use_secure = (settings.environment == "production") or settings.secure_cookies + response.set_cookie( + key="oauth_chain", + value=",".join(remaining), + max_age=600, # 10 minutes for the full chain + httponly=True, + secure=use_secure, + samesite=settings.cookie_samesite, + path=settings.app_root_path or "/", + ) + + return response + + @oauth_router.get("/status/{gateway_id}") async def get_oauth_status( gateway_id: str, diff --git a/mcpgateway/routers/sso.py b/mcpgateway/routers/sso.py index f4d93566bb..81d00cad53 100644 --- a/mcpgateway/routers/sso.py +++ b/mcpgateway/routers/sso.py @@ -150,6 +150,27 @@ def _normalize_origin(scheme: str, host: str, port: int | None) -> str: return f"{scheme}://{host}:{port}" +def _is_safe_local_path(path: str) -> bool: + """Validate that a path is a safe local redirect target (no open redirect). + + Args: + path: The path to validate. + + Returns: + True if the path is a safe relative path starting with ``/``. + """ + if not path or not isinstance(path, str): + return False + if not path.startswith("/"): + return False + if path.startswith("//") or "@" in path or "\\" in path: + return False + parsed = urlparse(path) + if parsed.scheme or parsed.netloc: + return False + return True + + def _validate_redirect_uri(redirect_uri: str, request: Request | None = None) -> bool: """Validate redirect_uri to prevent open redirect attacks. 
@@ -396,8 +417,24 @@ async def handle_sso_callback( if not access_token: return RedirectResponse(url=f"{root_path}/admin/login?error=user_creation_failed", status_code=302) - # Create redirect response - redirect_response = RedirectResponse(url=f"{root_path}/admin", status_code=302) + # Create redirect response — check for post-login destination cookie + post_login_next = request.cookies.get("post_login_next") if request else None + if post_login_next and _is_safe_local_path(post_login_next): + redirect_url = f"{root_path}{post_login_next}" + else: + redirect_url = f"{root_path}/admin" + redirect_response = RedirectResponse(url=redirect_url, status_code=302) + + # Clear the post_login_next cookie regardless + if post_login_next: + use_secure = (settings.environment == "production") or settings.secure_cookies + redirect_response.delete_cookie( + "post_login_next", + path=settings.app_root_path or "/", + secure=use_secure, + httponly=True, + samesite=settings.cookie_samesite, + ) # Set secure HTTP-only cookie using the same method as email auth # First-Party diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 821463b26a..6571c2bbd5 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -379,14 +379,6 @@ class MetricsResponse(BaseModelWithConfigDict): @model_serializer(mode="wrap") def _exclude_none_a2a(self, handler): - """Omit the A2A metrics field when that feature is disabled. - - Args: - handler: Pydantic serializer callback for the wrapped model. - - Returns: - Dict[str, Any]: Serialized metrics payload without empty A2A fields. 
- """ result = handler(self) if self.a2a_agents is None: result.pop("a2aAgents", None) @@ -4071,6 +4063,19 @@ def validate_id(cls, v: Optional[str]) -> Optional[str]: oauth_enabled: bool = Field(False, description="Enable OAuth 2.0 for MCP client authentication") oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration (authorization_server, scopes_supported, etc.)") + # Meta-server configuration + server_type: str = Field("standard", description="Server type: 'standard' or 'meta'. Meta servers expose meta-tools instead of real tools.") + hide_underlying_tools: bool = Field(True, description="When True and server_type is 'meta', underlying tools are hidden from tool listing endpoints") + meta_config: Optional[Dict[str, Any]] = Field(None, description="Meta-server configuration (MetaConfig schema). Only applicable when server_type is 'meta'.") + meta_scope: Optional[Dict[str, Any]] = Field(None, description="Scope rules for filtering tools visible to the meta-server (MetaToolScope schema).") + + @field_validator("server_type") + @classmethod + def validate_server_type(cls, v: str) -> str: + if v not in ("standard", "meta"): + raise ValueError("server_type must be one of: standard, meta") + return v + @field_validator("name") @classmethod def validate_name(cls, v: str) -> str: @@ -4232,6 +4237,19 @@ def validate_id(cls, v: Optional[str]) -> Optional[str]: associated_prompts: Optional[List[str]] = Field(None, description="Comma-separated prompt IDs") associated_a2a_agents: Optional[List[str]] = Field(None, description="Comma-separated A2A agent IDs") + # Meta-server configuration (optional update fields) + server_type: Optional[str] = Field(None, description="Server type: 'standard' or 'meta'") + hide_underlying_tools: Optional[bool] = Field(None, description="When True and server_type is 'meta', underlying tools are hidden") + meta_config: Optional[Dict[str, Any]] = Field(None, description="Meta-server configuration (MetaConfig 
schema)") + meta_scope: Optional[Dict[str, Any]] = Field(None, description="Scope rules for filtering tools visible to the meta-server") + + @field_validator("server_type") + @classmethod + def validate_server_type(cls, v: Optional[str]) -> Optional[str]: + if v is not None and v not in ("standard", "meta"): + raise ValueError("server_type must be one of: standard, meta") + return v + @field_validator("name") @classmethod def validate_name(cls, v: str) -> str: @@ -4364,6 +4382,12 @@ class ServerRead(BaseModelWithConfigDict): oauth_enabled: bool = Field(False, description="Whether OAuth 2.0 is enabled for MCP client authentication") oauth_config: Optional[Dict[str, Any]] = Field(None, description="OAuth 2.0 configuration (authorization_server, scopes_supported, etc.)") + # Meta-server configuration + server_type: str = Field("standard", description="Server type: 'standard' or 'meta'") + hide_underlying_tools: bool = Field(True, description="When True and server_type is 'meta', underlying tools are hidden") + meta_config: Optional[Dict[str, Any]] = Field(None, description="Meta-server configuration (MetaConfig schema)") + meta_scope: Optional[Dict[str, Any]] = Field(None, description="Scope rules for filtering tools visible to the meta-server") + _normalize_visibility = field_validator("visibility", mode="before")(classmethod(lambda cls, v: _coerce_visibility(v))) @model_validator(mode="before") @@ -8212,8 +8236,6 @@ class PerformanceHistoryResponse(BaseModel): # --------------------------------------------------------------------------- -# Tool Plugin Binding Schemas -# --------------------------------------------------------------------------- class PluginBindingMode(str, Enum): @@ -8349,3 +8371,22 @@ class ToolPluginBindingListResponse(BaseModelWithConfigDict): bindings: List[ToolPluginBindingResponse] = Field(default_factory=list, description="List of tool plugin bindings") total: int = Field(0, description="Total number of bindings returned") + + +class 
ToolSearchResult(BaseModelWithConfigDict): + """Response schema for a single tool search result with relevance score.""" + + tool_name: str = Field(..., description="Tool name") + description: Optional[str] = Field(None, description="Tool description") + similarity_score: float = Field(..., ge=0.0, le=1.0, description="Similarity score (0-1)") + server_id: Optional[str] = Field(None, description="Server ID the tool belongs to") + server_name: Optional[str] = Field(None, description="Server name the tool belongs to") + tags: List[str] = Field(default_factory=list, description="Tool tags") + + +class SemanticSearchResponse(BaseModelWithConfigDict): + """Response schema for semantic tool search.""" + + results: List[ToolSearchResult] = Field(..., description="Ranked list of matching tools") + query: str = Field(..., description="Original search query") + total_results: int = Field(0, description="Number of results returned") diff --git a/mcpgateway/services/embedding_service.py b/mcpgateway/services/embedding_service.py new file mode 100644 index 0000000000..3f0c2d71ef --- /dev/null +++ b/mcpgateway/services/embedding_service.py @@ -0,0 +1,9 @@ +"""Embedding service stub. + +Full implementation requires an embedding model (e.g. OpenAI text-embedding-3-small) +and pgvector. This stub is a no-op so the application starts without those dependencies. +""" + + +async def index_tool_fire_and_forget(tool_id: str) -> None: + """Index a tool's embedding in the background. No-op stub.""" diff --git a/mcpgateway/services/meta_tool_service.py b/mcpgateway/services/meta_tool_service.py new file mode 100644 index 0000000000..54241f5f13 --- /dev/null +++ b/mcpgateway/services/meta_tool_service.py @@ -0,0 +1,335 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/meta_tool_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Meta-Tool Service Implementation. +This module implements the business logic for meta-tools (describe_tool, execute_tool). 
+""" + +# Standard +import time +from typing import Any, Dict, List, Optional +import uuid + +# Third-Party +import jsonschema +import orjson +from sqlalchemy import select +from sqlalchemy.orm import joinedload, Session + +# First-Party +from mcpgateway.db import Server as DbServer +from mcpgateway.db import Tool as DbTool +from mcpgateway.db import ToolMetric +from mcpgateway.meta_server.schemas import ( + DescribeToolResponse, + ExecuteToolResponse, +) +from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services.tool_service import ToolService + +# Initialize logging +logging_service = LoggingService() +logger = logging_service.get_logger(__name__) + + +class MetaToolService: + """Service for meta-tool operations.""" + + def __init__(self, db: Session): + """Initialize the MetaToolService. + + Args: + db: Database session + """ + self.db = db + self.tool_service = ToolService() + + async def describe_tool( + self, + tool_name: str, + include_metrics: bool = False, + user_email: Optional[str] = None, + token_teams: Optional[List[str]] = None, + is_admin: bool = False, + scope: Optional[str] = None, + ) -> DescribeToolResponse: + """Get detailed information about a specific tool. 
+ + This implements the describe_tool meta-tool functionality with: + - Tool resolution by name + - Schema and metadata fetching + - Optional metrics fetching + - Scope verification + + Args: + tool_name: Name of the tool to describe + include_metrics: Whether to include execution metrics + user_email: Email of requesting user + token_teams: Team IDs from JWT token + is_admin: Whether user is an admin + scope: Optional scope filter + + Returns: + DescribeToolResponse with tool details + + Raises: + ValueError: If tool not found or access denied + """ + # Resolve tool by name with scope verification + tool = await self._resolve_tool(tool_name, user_email, token_teams, is_admin, scope) + + if not tool: + raise ValueError(f"Tool not found: {tool_name}") + + # Fetch server information — prefer gateway (MCP backend) over M2M servers + server_id = None + server_name = None + if tool.gateway_id: + server_id = tool.gateway_id + try: + if tool.gateway: + server_name = tool.gateway.name + except Exception: + pass + elif tool.servers: + server = tool.servers[0] + server_id = server.id + server_name = server.name + + # Fetch metrics if requested + metrics = None + if include_metrics: + metrics = await self._fetch_tool_metrics(tool.id) + + # Extract tag strings from database format + # Tags may be stored as [{'id': 'tag', 'label': 'tag'}, ...] or ['tag', ...] 
+ tags_list = tool.tags or [] + if tags_list and isinstance(tags_list[0], dict): + tags_list = [tag.get("id") or tag.get("label") for tag in tags_list if isinstance(tag, dict)] + + # Inherit tags from gateway if the tool has no tags of its own + if not tags_list and tool.gateway_id: + try: + if tool.gateway and tool.gateway.tags: + gw_tags = tool.gateway.tags + if gw_tags and isinstance(gw_tags[0], dict): + tags_list = [t.get("id") or t.get("label") for t in gw_tags if isinstance(t, dict)] + else: + tags_list = list(gw_tags) + except Exception: + pass + + # Build response + response = DescribeToolResponse( + name=tool.name, + description=tool.description or tool.original_description, + input_schema=tool.input_schema, + output_schema=tool.output_schema, + server_id=server_id, + server_name=server_name, + tags=tags_list, + metrics=metrics, + annotations=tool.annotations, + ) + + return response + + async def execute_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + user_email: Optional[str] = None, + token_teams: Optional[List[str]] = None, + is_admin: bool = False, + scope: Optional[str] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> ExecuteToolResponse: + """Execute a tool with argument validation and routing. 
+ + This implements the execute_tool meta-tool functionality with: + - Tool resolution + - Argument validation against JSON schema + - Routing to backend server + - Safe header forwarding + - Execution metadata + + Args: + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + user_email: Email of requesting user + token_teams: Team IDs from JWT token + is_admin: Whether user is an admin + scope: Optional scope filter + request_headers: Headers from the original request + + Returns: + ExecuteToolResponse with execution result and metadata + + Raises: + ValueError: If tool not found, validation fails, or execution fails + PermissionError: If access is denied + """ + start_time = time.time() + + # Resolve tool with scope verification + tool = await self._resolve_tool(tool_name, user_email, token_teams, is_admin, scope) + + if not tool: + raise ValueError(f"Tool not found: {tool_name}") + + # Validate arguments against input schema + if tool.input_schema: + try: + jsonschema.validate(instance=arguments, schema=tool.input_schema) + except jsonschema.ValidationError as e: + raise ValueError(f"Argument validation failed: {e.message}") + + # Execute tool via ToolService + try: + # Generate request ID for tracking + request_id = str(uuid.uuid4()) + + # Prepare metadata + meta_data = { + "request_id": request_id, + "meta_tool": "execute_tool", + } + + # Forward request to ToolService for execution + tool_result = await self.tool_service.invoke_tool( + db=self.db, + name=tool_name, + arguments=arguments, + request_headers=request_headers, + app_user_email=user_email, + user_email=user_email, + token_teams=token_teams, + meta_data=meta_data, + ) + + # Extract result content + result_data = None + if tool_result.content: + if isinstance(tool_result.content, list) and len(tool_result.content) > 0: + first_content = tool_result.content[0] + if hasattr(first_content, "text"): + result_data = first_content.text + elif hasattr(first_content, 
"model_dump"): + result_data = orjson.dumps(first_content.model_dump(by_alias=True, mode="json")).decode() + else: + result_data = str(first_content) + else: + result_data = str(tool_result.content) + + execution_time_ms = int((time.time() - start_time) * 1000) + + return ExecuteToolResponse( + tool_name=tool_name, + success=not getattr(tool_result, "isError", getattr(tool_result, "is_error", False)), + result=result_data, + error=None, + execution_time_ms=execution_time_ms, + ) + + except Exception as e: + execution_time_ms = int((time.time() - start_time) * 1000) + logger.error(f"Tool execution failed for {tool_name}: {e}") + return ExecuteToolResponse( + tool_name=tool_name, + success=False, + result=None, + error=str(e), + execution_time_ms=execution_time_ms, + ) + + async def _resolve_tool( + self, + tool_name: str, + user_email: Optional[str], + token_teams: Optional[List[str]], + is_admin: bool, + scope: Optional[str], + ) -> Optional[DbTool]: + """Resolve a tool by name with scope verification. 
+ + Args: + tool_name: Name of the tool + user_email: Email of requesting user + token_teams: Team IDs from JWT token + is_admin: Whether user is an admin + scope: Optional scope filter + + Returns: + Tool object or None if not found/accessible + """ + # Build query with eager loading of relationships + query = select(DbTool).options(joinedload(DbTool.servers), joinedload(DbTool.gateway)).where(DbTool.name == tool_name, DbTool.enabled == True) + + # Apply scope filtering if provided + # Scope filtering logic: + # - If scope is provided, filter by visibility or team + # - Admin bypass if is_admin=True + if scope and not is_admin: + # Scope can be: public, team:, private + if scope == "public": + query = query.where(DbTool.visibility == "public") + elif scope.startswith("team:"): + team_id = scope.replace("team:", "") + query = query.where(DbTool.team_id == team_id) + elif scope == "private": + query = query.where(DbTool.owner_email == user_email) + + # Apply team-based filtering if not admin + if not is_admin and token_teams is not None: + # If token_teams is empty list, only public tools + # If token_teams has values, include team tools + public tools + if len(token_teams) == 0: + query = query.where(DbTool.visibility == "public") + else: + # Third-Party + from sqlalchemy import or_ + + query = query.where(or_(DbTool.visibility == "public", DbTool.team_id.in_(token_teams))) + + result = self.db.execute(query) + tool = result.scalars().first() + + return tool + + async def _fetch_tool_metrics(self, tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch execution metrics for a tool. 
+ + Args: + tool_id: Tool ID + + Returns: + Dictionary with metrics or None + """ + try: + # Query ToolMetric for aggregated metrics + query = select(ToolMetric).where(ToolMetric.tool_id == tool_id) + result = self.db.execute(query) + metrics_records = result.scalars().all() + + if not metrics_records: + return None + + # Aggregate metrics + execution_count = len(metrics_records) + successful = sum(1 for m in metrics_records if m.success) + failed = execution_count - successful + total_time = sum(m.response_time for m in metrics_records if m.response_time) + avg_time = total_time / execution_count if execution_count > 0 else 0 + + return { + "execution_count": execution_count, + "successful_executions": successful, + "failed_executions": failed, + "success_rate": successful / execution_count if execution_count > 0 else 0, + "avg_response_time_ms": avg_time, + } + except Exception as e: + logger.warning(f"Failed to fetch metrics for tool {tool_id}: {e}") + return None diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 50b3e2448d..7fa5fa7e88 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -30,7 +30,6 @@ from mcp import ClientSession, types from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.types import GetPromptRequest, GetPromptRequestParams import orjson from pydantic import ValidationError from sqlalchemy import and_, delete, desc, not_, or_, select @@ -39,7 +38,6 @@ # First-Party from mcpgateway.common.models import Message, PromptResult, Role, TextContent -from mcpgateway.common.validators import validate_meta_data as _validate_meta_data from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import EmailTeamMember as DbEmailTeamMember @@ -143,55 +141,6 @@ def _get_registry_cache(): metrics_buffer = get_metrics_buffer_service() -def _build_get_prompt_request(name: str, 
arguments: Optional[Dict[str, str]], meta_data: Dict[str, Any]) -> "types.ClientRequest": - """Build a GetPrompt ClientRequest that carries _meta (CWE-20, CWE-284). - - Using ``by_alias=True`` ensures the Pydantic alias ``_meta`` is the only - key written into the dict so the subsequent ``model_validate`` call - resolves it correctly regardless of ``populate_by_name`` settings. - - ``send_request`` is used instead of ``session.get_prompt()`` because the - MCP SDK helper does not expose a ``_meta`` parameter; this wrapper must be - updated if the SDK later adds that capability. - - Args: - name: The prompt name. - arguments: Optional prompt arguments. - meta_data: Validated metadata dict to inject as ``_meta``. - - Returns: - A :class:`types.ClientRequest` ready to be passed to ``session.send_request``. - """ - _gp_dict = GetPromptRequestParams(name=name, arguments=arguments).model_dump(by_alias=True) - _gp_dict["_meta"] = meta_data - return types.ClientRequest(GetPromptRequest(params=GetPromptRequestParams.model_validate(_gp_dict))) - - -async def _get_prompt_with_meta(session: "ClientSession", name: str, arguments: Optional[Dict[str, str]], meta_data: Optional[Dict[str, Any]]) -> Any: - """Dispatch a get_prompt call, injecting ``_meta`` when meta_data is provided. - - Eliminates the repeated ``if meta_data: send_request … else: get_prompt`` - pattern across every transport/pool branch in this module. - - Args: - session: An active MCP :class:`ClientSession`. - name: The prompt name. - arguments: Optional prompt-rendering arguments. - meta_data: Optional validated metadata dict. When ``None`` the standard - SDK helper is used; when non-empty the low-level ``send_request`` - path is taken to carry ``_meta``. - - Returns: - The raw MCP result object (caller extracts ``.messages``). 
- """ - if meta_data: - return await session.send_request( - _build_get_prompt_request(name, arguments, meta_data), - types.GetPromptResult, - ) - return await session.get_prompt(name, arguments=arguments) - - class PromptError(Exception): """Base class for prompt-related errors.""" @@ -365,13 +314,14 @@ def _should_fetch_gateway_prompt(prompt: DbPrompt) -> bool: """ return bool(getattr(prompt, "gateway_id", None)) and not bool(getattr(prompt, "template", "")) - async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Optional[Dict[str, str]], meta_data: Optional[Dict[str, Any]] = None) -> PromptResult: + async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Optional[Dict[str, str]], meta_data: Optional[Dict[str, Any]] = None, user_identity: Optional[str] = None) -> PromptResult: """Fetch a rendered prompt from the upstream MCP gateway. Args: prompt: Gateway-backed prompt record from the catalog. arguments: Optional prompt-rendering arguments. meta_data: Optional metadata dict forwarded as ``_meta`` in the upstream MCP request. + user_identity: Effective requester email for session-pool isolation. Returns: Prompt result normalized into ContextForge models. 
@@ -403,8 +353,6 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option transport = str(getattr(gateway, "transport", "streamable_http") or "streamable_http").lower() registry_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP prompt_arguments = arguments or None - # CWE-400: Validate meta_data limits before forwarding to upstream - _validate_meta_data(meta_data) try: # #4205: Use the upstream session registry when a downstream Mcp-Session-Id @@ -423,6 +371,7 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option url=gateway_url, headers=headers, transport_type=registry_transport_type, + user_identity=pool_user_identity, ) as upstream: remote_result = await _get_prompt_with_meta(upstream.session, remote_name, prompt_arguments, meta_data) return PromptResult( @@ -437,12 +386,12 @@ async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Option async with sse_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as streams: async with ClientSession(*streams) as session: await session.initialize() - remote_result = await _get_prompt_with_meta(session, remote_name, prompt_arguments, meta_data) + remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments) else: async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - remote_result = await _get_prompt_with_meta(session, remote_name, prompt_arguments, meta_data) + remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments) return PromptResult( messages=[ @@ -1906,7 +1855,7 @@ async def get_prompt( None = unrestricted admin, [] = public-only, [...] = team-scoped. 
plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. plugin_global_context: Optional global context from middleware for consistency across hooks. - _meta_data: Optional metadata forwarded as _meta to the upstream MCP gateway during prompt retrieval. + _meta_data: Optional metadata for prompt retrieval (not used currently). Returns: Prompt result with rendered messages @@ -2067,7 +2016,7 @@ async def get_prompt( if self._should_fetch_gateway_prompt(prompt): # Release the read transaction before any remote network I/O. db.commit() - result = await self._fetch_gateway_prompt_result(prompt, arguments, meta_data=_meta_data) + result = await self._fetch_gateway_prompt_result(prompt, arguments, meta_data=_meta_data, user_identity=user) elif not arguments: result = PromptResult( messages=[ diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index c2e4bcf742..4ea233ba91 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -35,10 +35,9 @@ # Third-Party import httpx -from mcp import ClientSession, types +from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.types import ReadResourceRequest, ReadResourceRequestParams import parse from pydantic import ValidationError from sqlalchemy import and_, delete, desc, not_, or_, select @@ -48,7 +47,6 @@ # First-Party from mcpgateway.common.models import ResourceContent, ResourceContents, ResourceTemplate, TextContent from mcpgateway.common.validators import SecurityValidator -from mcpgateway.common.validators import validate_meta_data as _validate_meta_data from mcpgateway.config import settings from mcpgateway.db import EmailTeam from mcpgateway.db import EmailTeamMember as DbEmailTeamMember @@ -116,53 +114,6 @@ def _get_registry_cache(): metrics_buffer = get_metrics_buffer_service() -def 
_build_read_resource_request(uri: Any, meta_data: Dict[str, Any]) -> "types.ClientRequest": - """Build a ReadResource ClientRequest that carries _meta (CWE-20, CWE-284). - - Using ``by_alias=True`` ensures the Pydantic alias ``_meta`` is the only - key written into the dict so the subsequent ``model_validate`` call - resolves it correctly regardless of ``populate_by_name`` settings. - - ``send_request`` is used instead of ``session.read_resource()`` because the - MCP SDK helper does not expose a ``_meta`` parameter; this wrapper must be - updated if the SDK later adds that capability. - - Args: - uri: The resource URI. - meta_data: Validated metadata dict to inject as ``_meta``. - - Returns: - A :class:`types.ClientRequest` ready to be passed to ``session.send_request``. - """ - _rp_dict = ReadResourceRequestParams(uri=uri).model_dump(by_alias=True) - _rp_dict["_meta"] = meta_data - return types.ClientRequest(ReadResourceRequest(params=ReadResourceRequestParams.model_validate(_rp_dict))) - - -async def _read_resource_with_meta(session: "ClientSession", uri: Any, meta_data: Optional[Dict[str, Any]]) -> Any: - """Dispatch a read_resource call, injecting ``_meta`` when meta_data is provided. - - Eliminates the repeated ``if meta_data: send_request … else: read_resource`` - pattern across every transport/pool branch in this module. - - Args: - session: An active MCP :class:`ClientSession`. - uri: The resource URI to read. - meta_data: Optional validated metadata dict. When ``None`` the standard - SDK helper is used; when non-empty the low-level ``send_request`` - path is taken to carry ``_meta``. - - Returns: - The raw MCP result object (caller extracts ``.contents``). 
- """ - if meta_data: - return await session.send_request( - _build_read_resource_request(uri, meta_data), - types.ReadResourceResult, - ) - return await session.read_resource(uri=uri) - - class ResourceError(Exception): """Base class for resource-related errors.""" @@ -1622,7 +1573,7 @@ async def invoke_resource( # pylint: disable=unused-argument resource_uri: str, resource_template_uri: Optional[str] = None, user_identity: Optional[Union[str, Dict[str, Any]]] = None, - meta_data: Optional[Dict[str, Any]] = None, # Forwarded as _meta in upstream MCP requests + meta_data: Optional[Dict[str, Any]] = None, # Reserved for future MCP SDK support resource_obj: Optional[Any] = None, gateway_obj: Optional[Any] = None, server_id: Optional[str] = None, @@ -1736,10 +1687,6 @@ async def invoke_resource( # pylint: disable=unused-argument 'using template: /template' """ - # CWE-400: Validate meta_data limits before any further processing; invoke_resource is - # a separate entry point that must enforce the same guards as read_resource. - _validate_meta_data(meta_data) - uri = None if resource_uri and resource_template_uri: uri = resource_template_uri @@ -1976,8 +1923,8 @@ async def connect_to_sse_session(server_url: str, uri: str, authentication: Opti ``None`` instead of raising. Note: - When meta_data is provided, the request is built using send_request - with _meta injected into ReadResourceRequestParams. + MCP SDK 1.25.0 read_resource() does not support meta parameter. + When the SDK adds support, meta_data can be added back here. 
    Args:
        server_url (str):
@@ -2024,6 +1971,7 @@ async def connect_to_sse_session(server_url: str, uri: str, authentication: Opti
                 headers=authentication,
                 transport_type=TransportType.SSE,
                 httpx_client_factory=_get_httpx_client_factory,
+                user_identity=pool_user_identity,
             ) as upstream:
-                resource_response = await _read_resource_with_meta(upstream.session, uri, meta_data)
+                # _read_resource_with_meta is deleted by this change; call the SDK helper
+                # directly or the pool branch raises NameError at runtime.
+                resource_response = await upstream.session.read_resource(uri=uri)
                 return getattr(getattr(resource_response, "contents")[0], "text")
@@ -2035,7 +1983,8 @@
         ):
             async with ClientSession(read_stream, write_stream) as session:
                 _ = await session.initialize()
-                resource_response = await _read_resource_with_meta(session, uri, meta_data)
+                # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter
+                resource_response = await session.read_resource(uri=uri)
                 return getattr(getattr(resource_response, "contents")[0], "text")
     except Exception as e:
         # Sanitize error message to prevent URL secrets from leaking in logs
@@ -2057,8 +2006,8 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica
         of propagating the exception.

     Note:
-        When meta_data is provided, the request is built using send_request
-        with _meta injected into ReadResourceRequestParams.
+        MCP SDK 1.25.0 read_resource() does not support meta parameter.
+        When the SDK adds support, meta_data can be added back here.
    Args:
        server_url (str):
@@ -2102,6 +2051,7 @@ async def connect_to_streamablehttp_server(server_url: str, uri: str, authentica
                 headers=authentication,
                 transport_type=TransportType.STREAMABLE_HTTP,
                 httpx_client_factory=_get_httpx_client_factory,
+                user_identity=pool_user_identity,
             ) as upstream:
-                resource_response = await _read_resource_with_meta(upstream.session, uri, meta_data)
+                # _read_resource_with_meta is deleted by this change; call the SDK helper
+                # directly or the pool branch raises NameError at runtime.
+                resource_response = await upstream.session.read_resource(uri=uri)
                 return getattr(getattr(resource_response, "contents")[0], "text")
@@ -2114,7 +2064,8 @@
         ):
             async with ClientSession(read_stream, write_stream) as session:
                 _ = await session.initialize()
-                resource_response = await _read_resource_with_meta(session, uri, meta_data)
+                # Note: MCP SDK 1.25.0 read_resource() does not support meta parameter
+                resource_response = await session.read_resource(uri=uri)
                 return getattr(getattr(resource_response, "contents")[0], "text")
     except Exception as e:
         # Sanitize error message to prevent URL secrets from leaking in logs
@@ -2128,8 +2079,10 @@
         resource_text = ""
         if (gateway_transport).lower() == "sse":
+            # Note: meta_data not passed - MCP SDK 1.25.0 read_resource() doesn't support it
             resource_text = await connect_to_sse_session(server_url=gateway_url, authentication=headers, uri=uri)
         else:
+            # Note: meta_data not passed - MCP SDK 1.25.0 read_resource() doesn't support it
             resource_text = await connect_to_streamablehttp_server(server_url=gateway_url, authentication=headers, uri=uri)
         if span and resource_text is not None and is_output_capture_enabled("invoke.resource"):
             set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"content": resource_text}))
@@ -2241,8 +2194,6 @@ async def read_resource(
     resource_db = None
     server_scoped = False
     resource_db_gateway = None  # Only set when eager-loaded via Q2's joinedload
-    # CWE-400: Validate meta_data limits before any further processing
-    
_validate_meta_data(meta_data) content = None uri = resource_uri or "unknown" if resource_id: @@ -2421,7 +2372,8 @@ async def read_resource( async with ClientSession(read_stream, write_stream) as session: await session.initialize() - result = await _read_resource_with_meta(session, uri, meta_data) + # Note: MCP SDK read_resource() only accepts uri; _meta is not supported + result = await session.read_resource(uri=uri) # Convert MCP result to MCP-compliant content models # result.contents is a list of TextResourceContents or BlobResourceContents diff --git a/mcpgateway/services/semantic_search_service.py b/mcpgateway/services/semantic_search_service.py new file mode 100644 index 0000000000..851569e2d3 --- /dev/null +++ b/mcpgateway/services/semantic_search_service.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +"""Semantic search service stub. + +Placeholder for the full semantic search implementation (issue #2229). +Returns empty results, allowing fallback to keyword search. +""" + +import logging +from typing import Any, List, Optional + +logger = logging.getLogger(__name__) + + +class SemanticSearchService: + """Stub semantic search service that returns empty results.""" + + async def search_tools( + self, + query: str, + db: Any = None, + limit: int = 10, + threshold: float = 0.7, + ) -> List[Any]: + """Return empty results — semantic search not yet implemented.""" + logger.debug("Semantic search not available, returning empty results for query: %s", query[:100]) + return [] + + +_instance: Optional[SemanticSearchService] = None + + +def get_semantic_search_service() -> SemanticSearchService: + """Get or create the singleton semantic search service.""" + global _instance + if _instance is None: + _instance = SemanticSearchService() + return _instance diff --git a/mcpgateway/services/server_service.py b/mcpgateway/services/server_service.py index 311fa92e1d..92e5c6af52 100644 --- a/mcpgateway/services/server_service.py +++ b/mcpgateway/services/server_service.py @@ 
-431,6 +431,11 @@ def convert_server_to_read(self, server: DbServer, include_metrics: bool = False # OAuth 2.0 configuration for RFC 9728 Protected Resource Metadata "oauth_enabled": getattr(server, "oauth_enabled", False), "oauth_config": getattr(server, "oauth_config", None), + # Meta-server fields + "server_type": getattr(server, "server_type", "standard") or "standard", + "hide_underlying_tools": getattr(server, "hide_underlying_tools", True), + "meta_config": getattr(server, "meta_config", None), + "meta_scope": getattr(server, "meta_scope", None), } # Compute aggregated metrics only if requested (avoids N+1 queries in list operations) @@ -598,6 +603,11 @@ async def register_server( # OAuth 2.0 configuration for RFC 9728 Protected Resource Metadata oauth_enabled=getattr(server_in, "oauth_enabled", False) or False, oauth_config=oauth_config, + # Meta-server fields + server_type=getattr(server_in, "server_type", "standard") or "standard", + hide_underlying_tools=getattr(server_in, "hide_underlying_tools", True), + meta_config=getattr(server_in, "meta_config", None), + meta_scope=getattr(server_in, "meta_scope", None), # Metadata fields created_by=created_by, created_from_ip=created_from_ip, @@ -1326,6 +1336,16 @@ async def update_server( elif server_update.oauth_config is not None: server.oauth_config = await protect_oauth_config_for_storage(server_update.oauth_config, existing_oauth_config=server.oauth_config) + # Update meta-server fields if provided + if getattr(server_update, "server_type", None) is not None: + server.server_type = server_update.server_type + if getattr(server_update, "hide_underlying_tools", None) is not None: + server.hide_underlying_tools = server_update.hide_underlying_tools + if getattr(server_update, "meta_config", None) is not None: + server.meta_config = server_update.meta_config + if getattr(server_update, "meta_scope", None) is not None: + server.meta_scope = server_update.meta_scope + # Update metadata fields server.updated_at = 
datetime.now(timezone.utc) if modified_by: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 89a7369e8f..8a5651ad01 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -3928,9 +3928,10 @@ async def prepare_rust_mcp_tool_execution( with fresh_db_session() as token_db: token_storage = TokenStorageService(token_db) - if not app_user_email: + effective_email = app_user_email or user_email + if not effective_email: raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.") - access_token = await token_storage.get_user_token(gateway_id_str, app_user_email) + access_token = await token_storage.get_user_token(gateway_id_str, effective_email) if access_token: headers = {"Authorization": f"Bearer {access_token}"} @@ -5064,10 +5065,11 @@ async def invoke_tool( token_storage = TokenStorageService(token_db) # Get user-specific OAuth token - if not app_user_email: + effective_email = app_user_email or user_email + if not effective_email: raise ToolInvocationError(f"User authentication required for OAuth-protected gateway '{gateway_name}'. Please ensure you are authenticated.") - access_token = await token_storage.get_user_token(gateway_id_str, app_user_email) + access_token = await token_storage.get_user_token(gateway_id_str, effective_email) if access_token: headers = {"Authorization": f"Bearer {access_token}"} diff --git a/mcpgateway/services/vector_search_service.py b/mcpgateway/services/vector_search_service.py new file mode 100644 index 0000000000..a4add0fb2c --- /dev/null +++ b/mcpgateway/services/vector_search_service.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +"""Vector search service stub. + +Placeholder for the full vector search implementation (issue #2229). +Provides embedding retrieval and similarity search over tool embeddings. 
+""" + +import logging +import math +from typing import Any, List, Optional + +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +def _cosine_similarity_numpy(vec_a: List[float], vec_b: List[float]) -> float: + """Compute cosine similarity between two vectors without numpy.""" + if not vec_a or not vec_b or len(vec_a) != len(vec_b): + return 0.0 + dot = sum(a * b for a, b in zip(vec_a, vec_b)) + norm_a = math.sqrt(sum(a * a for a in vec_a)) + norm_b = math.sqrt(sum(b * b for b in vec_b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +class VectorSearchService: + """Stub vector search service for tool embeddings.""" + + def __init__(self, db: Optional[Session] = None): + self.db = db + + def get_tool_embedding(self, db: Session, tool_id: str) -> Any: + """Retrieve the stored embedding for a tool. + + Returns None when no embedding is found (stub always returns None). + """ + try: + from mcpgateway.db import ToolEmbedding + + result = db.query(ToolEmbedding).filter(ToolEmbedding.tool_id == tool_id).first() + return result + except Exception as e: + logger.debug("Failed to get tool embedding for %s: %s", tool_id, e) + return None + + async def search_similar_tools( + self, + embedding: List[float], + limit: int = 10, + db: Optional[Session] = None, + ) -> List[Any]: + """Search for tools similar to the given embedding vector. + + Returns empty list when no embeddings are available. 
+ """ + session = db or self.db + if session is None: + return [] + + try: + from mcpgateway.db import ToolEmbedding + + all_embeddings = session.query(ToolEmbedding).all() + if not all_embeddings: + return [] + + scored = [] + for te in all_embeddings: + sim = _cosine_similarity_numpy(embedding, te.embedding) + scored.append((te, sim)) + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:limit] + except Exception as e: + logger.debug("Vector similarity search failed: %s", e) + return [] diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 08fdc66641..a2eb02e8a1 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -2884,8 +2884,68 @@

    class="mt-1 px-3 py-2 block w-full rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" /> + +
    +
    + + +
    +

    + When enabled, this server exposes meta-tools (search, list, describe, execute) instead of individual underlying tools. +

    + + + +
    + + + + +
    -
    +
    +
    @@ -3036,10 +3097,13 @@

    class="mt-2 min-h-[1.25rem] text-sm font-semibold text-yellow-600" aria-live="polite" >

    -
    - + + +
    -
    +
    -
    +
    -
    - + + servers.

    + +
    + + +

    Team that owns this server.

    +
    +
    +

    + MCP OAuth Proxy (DCR Bypass) +

    +

    + Pre-register OAuth client credentials to bypass Dynamic Client Registration (RFC 7591). + Required for IdPs like Microsoft Entra ID that do not support DCR. +

    +
    + + +

    + OAuth client ID registered with the Identity Provider +

    +
    +
    + + +

    + OAuth client secret (stored encrypted). Leave blank if using public client (PKCE only). +

    +
    +
    @@ -3367,6 +3482,7 @@

    Add Server + {% endif %} @@ -4102,6 +4218,12 @@

    tools.

    + +
    + + +

    Team that owns this tool.

    +
    + +
    + + +

    Team that owns this resource.

    +
    + +
    + + +

    Team that owns this prompt.

    +
    + +
    + + +

    Team that owns this gateway.

    +
    + +
    + + +

    + How client credentials are sent to the token endpoint (RFC 6749 Section 2.3) +

    +
    @@ -6001,6 +6160,40 @@

    class="mt-1 px-3 py-2 block w-full rounded-md border border-gray-300 dark:border-gray-700 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" /> +
    + + + Glob patterns for tools to include (comma-separated, e.g., + "manage-ticket*, manage-task*"). Only matching tools will be + imported. Leave empty to include all. + + +
    +
    + + + Glob patterns for tools to exclude (comma-separated, e.g., + "manage-project*"). Matching tools will be skipped. Leave + empty to exclude none. + + +
    @@ -7515,6 +7708,12 @@

    + + +
    + + +

    Team that owns this agent.

    @@ -7617,6 +7816,7 @@

    + {% endif %} @@ -8936,6 +9136,13 @@

    filter tools.

    + +
    + + +

    Team that owns this tool.

    +
    + +
    + + +

    Team that owns this resource.

    +
    + +
    + + +

    Team that owns this prompt.

    +
    + + +
    automatically normalized.

    + +
    + + +

    Team that owns this gateway.

    +
    @@ -10343,6 +10620,25 @@

    read:user")

    + +
    + + +

    + How client credentials are sent to the token endpoint (RFC 6749 Section 2.3) +

    +
    @@ -10390,6 +10686,46 @@

    class="mt-1 px-3 py-2 block w-full rounded-md border border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:bg-gray-900 dark:placeholder-gray-300 dark:text-gray-300" /> +
    + + + Glob patterns for tools to include (comma-separated, + e.g., "manage-ticket*, manage-task*"). Only matching + tools will be imported. Leave empty to include all. + + +
    +
    + + + Glob patterns for tools to exclude (comma-separated, + e.g., "manage-project*"). Matching tools will be + skipped. Leave empty to exclude none. + + +
    + + +
    + + +

    Team that owns this agent.

    @@ -11021,6 +11364,7 @@

    + @@ -11157,6 +11501,13 @@

    filter servers.

    + +
    + + +

    Team that owns this server.

    +
    @@ -11233,8 +11584,69 @@

    {% endif %} + + +
    +
    + + +
    +

    + When enabled, this server exposes meta-tools (search, list, describe, execute) instead of individual underlying tools. +

    + + + +
    + + + + +
    -
    -
    + +
    +
    @@ -11386,6 +11800,7 @@

    >

    +
    +
    @@ -11598,6 +12014,49 @@

    Leave blank to use standard discovery from authorization server

    +
    +

    + MCP OAuth Proxy (DCR Bypass) +

    +

    + Pre-register OAuth client credentials to bypass Dynamic Client Registration (RFC 7591). + Required for IdPs like Microsoft Entra ID that do not support DCR. +

    +
    + + +

    + OAuth client ID registered with the Identity Provider +

    +
    +
    + + +

    + OAuth client secret (stored encrypted). Leave blank to keep existing secret. +

    +
    +
    @@ -14131,6 +14590,7 @@

    hx-target="#create-team-error" hx-swap="innerHTML" data-team-validation="true" + onsubmit="Admin.syncOidcGroupIds('create-oidc-groups-table', 'create-oidc-group-ids')" >
    @@ -14242,6 +14702,62 @@

    {% endif %}

    + + +
    +

    OIDC Group Sync

    +
    + + +
    + +
    @@ -14431,8 +14947,8 @@

    required class="mt-1 block w-full px-3 py-2 border border-gray-300 dark:border-gray-600 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500 dark:bg-gray-700 dark:text-white" > - - + +

    @@ -14525,8 +15041,8 @@

    required class="mt-1 block w-full px-3 py-2 border border-gray-300 dark:border-gray-600 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500 dark:bg-gray-700 dark:text-white" > - - + + diff --git a/mcpgateway/transports/streamablehttp_transport.py b/mcpgateway/transports/streamablehttp_transport.py index dff7425b01..ffe5dd9868 100644 --- a/mcpgateway/transports/streamablehttp_transport.py +++ b/mcpgateway/transports/streamablehttp_transport.py @@ -64,9 +64,9 @@ # First-Party from mcpgateway.cache.global_config_cache import global_config_cache from mcpgateway.common.models import LogLevel -from mcpgateway.common.validators import validate_meta_data as _validate_meta_data from mcpgateway.config import settings from mcpgateway.db import SessionLocal +from mcpgateway.meta_server.service import get_meta_server_service from mcpgateway.middleware.rbac import _ACCESS_DENIED_MSG from mcpgateway.observability import create_span from mcpgateway.plugins.framework.models import UserContext @@ -290,6 +290,9 @@ def _resolve_authorization_servers(oauth_config: Dict[str, Any]) -> List[str]: return [url] return [] +# Meta-server context: stores server_type for the current request +server_type_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_type", default="standard") +hide_underlying_tools_var: contextvars.ContextVar[bool] = contextvars.ContextVar("hide_underlying_tools", default=True) _shared_session_registry: Optional[Any] = None _rust_event_store_client: Optional[httpx.AsyncClient] = None @@ -1295,8 +1298,14 @@ async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user async with ClientSession(read_stream, write_stream) as session: await session.initialize() + # Prepare params with _meta if provided + params = None + if meta: + params = PaginatedRequestParams(_meta=meta) + logger.debug("Forwarding _meta to remote gateway: %s", meta) + # List tools with _meta forwarded - result = await 
session.list_tools(params=_build_paginated_params(meta))
+            result = await session.list_tools(params=params)

             return result.tools

     except Exception as e:
@@ -1343,16 +1352,21 @@ async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict,
     logger.info("Proxying resources/list to gateway %s at %s", gateway.id, gateway.url)
     if meta:
-        # CWE-532: log only key names, never values which may carry PII/tokens
-        logger.debug("Forwarding _meta to remote gateway (keys: %s)", sorted(meta.keys()) if isinstance(meta, dict) else type(meta).__name__)
+        # CWE-532: log only key names, never values which may carry PII/tokens
+        logger.debug("Forwarding _meta to remote gateway (keys: %s)", sorted(meta.keys()) if isinstance(meta, dict) else type(meta).__name__)

     # Use MCP SDK to connect and list resources
     async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
         async with ClientSession(read_stream, write_stream) as session:
             await session.initialize()

+            # Prepare params with _meta if provided
+            params = None
+            if meta:
+                # Pydantic rejects the underscore-prefixed "_meta" as a constructor kwarg,
+                # so keep the alias-aware builder rather than PaginatedRequestParams(_meta=meta),
+                # which would silently drop the metadata (CWE-20). Values are never logged (CWE-532).
+                params = _build_paginated_params(meta)
+
             # List resources with _meta forwarded
-            result = await session.list_resources(params=_build_paginated_params(meta))
+            result = await session.list_resources(params=params)
             logger.info("Received %s resources from gateway %s", len(result.resources), gateway.id)

             return result.resources
@@ -1409,8 +1423,7 @@ async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_
     logger.info("Proxying resources/read for %s to gateway %s at %s", resource_uri, gateway.id, gateway.url)
     if meta:
-        # CWE-532: log only key names, never values which may carry PII/tokens
-        logger.debug("Forwarding _meta to remote gateway (keys: %s)", sorted(meta.keys()) if isinstance(meta, dict) else type(meta).__name__)
+        # CWE-532: log only key names, never values which may carry PII/tokens
+        logger.debug("Forwarding _meta to remote gateway (keys: %s)", sorted(meta.keys()) if isinstance(meta, dict) else type(meta).__name__)

     # Use MCP SDK to connect and read resource
     async with streamablehttp_client(url=gateway.url, headers=headers, 
timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
@@ -1420,10 +1433,8 @@ async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_
             # Prepare request params with _meta if provided
             if meta:
                 # Create params and inject _meta
-                # by_alias=True ensures the alias "_meta" key is written so
-                # model_validate resolves it correctly (fixes CWE-20 silent drop)
                 request_params = ReadResourceRequestParams(uri=resource_uri)
-                request_params_dict = request_params.model_dump(by_alias=True)
+                # by_alias=True must stay: it writes the "_meta" alias key so model_validate
+                # resolves it; plain model_dump() silently drops _meta (CWE-20 regression)
+                request_params_dict = request_params.model_dump(by_alias=True)
                 request_params_dict["_meta"] = meta

                 # Send request with _meta
@@ -1535,6 +1546,9 @@ async def call_tool(name: str, arguments: dict) -> Union[
     token_teams = user_context.get("teams") if user_context else None
     is_admin = user_context.get("is_admin", False) if user_context else False

+    # Preserve actual email for OAuth token lookup before admin bypass nulls it
+    actual_user_email = user_email
+
     # Admin bypass - only when token has NO team restrictions (token_teams is None)
     # If token has explicit team scope (even empty [] for public-only), respect it
     if is_admin and token_teams is None:
@@ -1567,6 +1581,20 @@
         if not has_execute_permission:
             raise PermissionError(_ACCESS_DENIED_MSG)

+    # Check if this is a meta-tool call on a meta-server
+    current_server_type = server_type_var.get()
+    meta_service = get_meta_server_service()
+    if meta_service.is_meta_server(current_server_type) and meta_service.is_meta_tool(name):
+        # Dispatch to meta-tool stub handler
+        # Use actual_user_email (not RBAC-filtered user_email) so OAuth token lookup works
+        result_data = await meta_service.handle_meta_tool_call(
+            name, arguments,
+            user_email=actual_user_email,
+            token_teams=token_teams,
+            request_headers=request_headers,
+        )
+        return [types.TextContent(type="text", text=orjson.dumps(result_data).decode())]
+
+    # Check if we're in direct_proxy 
mode by looking for X-Context-Forge-Gateway-Id header gateway_id_from_header = extract_gateway_id_from_headers(request_headers) @@ -2078,6 +2106,15 @@ async def list_tools() -> List[types.Tool]: if not settings.mcp_require_auth: await _check_server_oauth_enforcement(server_id, user_context) + # Check if this is a meta-server that should expose meta-tools instead + current_server_type = server_type_var.get() + current_hide_underlying = hide_underlying_tools_var.get() + meta_service = get_meta_server_service() + if meta_service.should_hide_underlying_tools(current_server_type, current_hide_underlying): + # Return meta-tools instead of underlying real tools + meta_tool_defs = meta_service.get_meta_tool_definitions() + return [types.Tool(name=td["name"], description=td["description"], inputSchema=td["inputSchema"]) for td in meta_tool_defs] + if server_id: try: async with get_db() as db: @@ -2525,20 +2562,23 @@ async def read_resource(resource_uri: str) -> Union[str, bytes]: return "" # Direct proxy mode: forward request to remote MCP server - # SECURITY: CWE-532 protection - Log only meta_data key names, NEVER values - # Metadata may contain PII, authentication tokens, or sensitive context that - # MUST NOT be written to logs. This is a critical security control. 
- logger.debug( - "Using direct_proxy mode for resources/read %s, server %s, gateway %s (from %s header), forwarding _meta keys: %s", - resource_uri, - server_id, - gateway.id, - GATEWAY_ID_HEADER, - sorted(meta_data.keys()) if meta_data else None, - ) - # CWE-400: validate _meta limits before network I/O (bypassed in direct-proxy branch) - _validate_meta_data(meta_data) - contents = await _proxy_read_resource_to_gateway(gateway, str(resource_uri), user_context, meta_data) + # Get _meta from request context if available + meta = None + try: + request_ctx = mcp_app.request_context + meta = request_ctx.meta + logger.info( + "Using direct_proxy mode for resources/read %s, server %s, gateway %s (from %s header), forwarding _meta: %s", + resource_uri, + server_id, + gateway.id, + GATEWAY_ID_HEADER, + meta, + ) + except (LookupError, AttributeError) as e: + logger.debug("No request context available for _meta extraction: %s", e) + + contents = await _proxy_read_resource_to_gateway(gateway, str(resource_uri), user_context, meta) if contents: # Return first content (text or blob) first_content = contents[0] @@ -4267,6 +4307,21 @@ async def handle_streamable_http( # noqa: PLR0911,PLR0912,PLR0915 — pylint: d server_id_var.set(validated) + # Load server metadata for meta-server tool hiding + if validated: + try: + from mcpgateway.db import Server as DbServer # pylint: disable=import-outside-toplevel + db = SessionLocal() + try: + srv = db.query(DbServer).filter(DbServer.id == validated).first() + if srv: + server_type_var.set(getattr(srv, "server_type", "standard") or "standard") + hide_underlying_tools_var.set(getattr(srv, "hide_underlying_tools", True)) + finally: + db.close() + except Exception as e: + logger.debug("Failed to load server metadata for meta-server: %s", e) + # For session affinity: wrap send to capture session ID from response headers # This allows us to register ownership for new sessions created by the SDK captured_session_id: Optional[str] = None diff 
--git a/mcpgateway/utils/pgvector.py b/mcpgateway/utils/pgvector.py new file mode 100644 index 0000000000..95bea652df --- /dev/null +++ b/mcpgateway/utils/pgvector.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +"""pgvector compatibility shim. + +Provides HAS_PGVECTOR flag and Vector type for optional pgvector support. +When pgvector is not installed, falls back to JSON column storage. +""" + +import logging + +logger = logging.getLogger(__name__) + +try: + from pgvector.sqlalchemy import Vector # type: ignore[import-untyped] + + HAS_PGVECTOR = True + logger.debug("pgvector extension available") +except ImportError: + HAS_PGVECTOR = False + Vector = None # type: ignore[assignment,misc] + logger.debug("pgvector not available, using JSON fallback for embeddings") diff --git a/tests/unit/mcpgateway/services/test_meta_tool_service.py b/tests/unit/mcpgateway/services/test_meta_tool_service.py new file mode 100644 index 0000000000..4785490490 --- /dev/null +++ b/tests/unit/mcpgateway/services/test_meta_tool_service.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/services/test_meta_tool_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Unit tests for the Meta-Tool Service with mocked database layer. 
+""" + +# Standard +from unittest.mock import AsyncMock, MagicMock, patch +import uuid + +# Third-Party +import pytest + +# First-Party +from mcpgateway.meta_server.schemas import DescribeToolResponse, ExecuteToolResponse +from mcpgateway.services.meta_tool_service import MetaToolService + + +class TestDescribeTool: + """Tests for describe_tool functionality.""" + + @pytest.mark.asyncio + async def test_describe_tool_success(self, test_db): + """Test successful tool description.""" + service = MetaToolService(test_db) + + # Create mock server + mock_server = MagicMock() + mock_server.id = "server-123" + mock_server.name = "test-server" + + # Create mock tool + mock_tool = MagicMock() + mock_tool.id = str(uuid.uuid4()) + mock_tool.name = "test_tool" + mock_tool.description = "Test tool description" + mock_tool.input_schema = {"type": "object", "properties": {"arg1": {"type": "string"}}} + mock_tool.output_schema = {"type": "object"} + mock_tool.tags = ["test", "sample"] + mock_tool.annotations = {"example": "data"} + mock_tool.servers = [mock_server] + + # Mock the _resolve_tool method + with patch.object(service, '_resolve_tool', new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = mock_tool + + response = await service.describe_tool( + tool_name="test_tool", + include_metrics=False, + user_email="test@example.com", + token_teams=[], + is_admin=False, + scope=None, + ) + + assert isinstance(response, DescribeToolResponse) + assert response.name == "test_tool" + assert response.description == "Test tool description" + assert response.server_name == "test-server" + assert "test" in response.tags + + @pytest.mark.asyncio + async def test_describe_tool_not_found(self, test_db): + """Test describe_tool with non-existent tool.""" + service = MetaToolService(test_db) + + with patch.object(service, '_resolve_tool', new_callable=AsyncMock, return_value=None): + with pytest.raises(ValueError, match="Tool not found"): + await service.describe_tool( + 
tool_name="nonexistent_tool", + include_metrics=False, + user_email="test@example.com", + token_teams=[], + is_admin=False, + scope=None, + ) + + +class TestExecuteTool: + """Tests for execute_tool functionality.""" + + @pytest.mark.asyncio + async def test_execute_tool_validation_error_returns_400(self, test_db): + """Test execute_tool returns validation error for invalid arguments.""" + service = MetaToolService(test_db) + + # Create mock tool with strict schema + mock_tool = MagicMock() + mock_tool.id = str(uuid.uuid4()) + mock_tool.name = "strict_tool" + mock_tool.input_schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + + with patch.object(service, '_resolve_tool', new_callable=AsyncMock, return_value=mock_tool): + # Missing required argument should raise ValueError + with pytest.raises(ValueError, match="Argument validation failed"): + await service.execute_tool( + tool_name="strict_tool", + arguments={}, # Missing 'name' + user_email="test@example.com", + token_teams=[], + is_admin=False, + scope=None, + ) + + @pytest.mark.asyncio + async def test_execute_tool_backend_error_surfaces_cleanly(self, test_db): + """Test backend errors are surfaced cleanly in response.""" + service = MetaToolService(test_db) + + # Create mock tool + mock_tool = MagicMock() + mock_tool.id = str(uuid.uuid4()) + mock_tool.name = "failing_tool" + mock_tool.input_schema = {} + + with patch.object(service, '_resolve_tool', new_callable=AsyncMock, return_value=mock_tool): + # Mock tool_service.invoke_tool to raise an exception + with patch.object(service.tool_service, 'invoke_tool', new_callable=AsyncMock) as mock_invoke: + mock_invoke.side_effect = Exception("Backend connection failed") + + response = await service.execute_tool( + tool_name="failing_tool", + arguments={}, + user_email="test@example.com", + token_teams=[], + is_admin=False, + scope=None, + ) + + assert response.success is False + assert response.error == "Backend 
connection failed" + assert response.execution_time_ms is not None + + @pytest.mark.asyncio + async def test_execute_tool_metadata_present(self, test_db): + """Test execution metadata is present in response.""" + service = MetaToolService(test_db) + + # Create mock tool + mock_tool = MagicMock() + mock_tool.id = str(uuid.uuid4()) + mock_tool.name = "meta_tool" + mock_tool.input_schema = {} + + # Create mock result + mock_result = MagicMock() + mock_result.isError = False + mock_content = MagicMock() + mock_content.text = "success" + mock_result.content = [mock_content] + + with patch.object(service, '_resolve_tool', new_callable=AsyncMock, return_value=mock_tool): + with patch.object(service.tool_service, 'invoke_tool', new_callable=AsyncMock, return_value=mock_result) as mock_invoke: + response = await service.execute_tool( + tool_name="meta_tool", + arguments={}, + user_email="test@example.com", + token_teams=[], + is_admin=False, + scope=None, + ) + + # Verify metadata + assert response.tool_name == "meta_tool" + assert response.execution_time_ms is not None + assert isinstance(response.execution_time_ms, (int, float)) + assert response.execution_time_ms >= 0 + + # Verify invoke_tool was called with proper metadata + mock_invoke.assert_called_once() + call_kwargs = mock_invoke.call_args.kwargs + assert "meta_data" in call_kwargs + assert call_kwargs["meta_data"]["meta_tool"] == "execute_tool" + assert "request_id" in call_kwargs["meta_data"] diff --git a/tests/unit/mcpgateway/test_meta_server.py b/tests/unit/mcpgateway/test_meta_server.py new file mode 100644 index 0000000000..824cbff87e --- /dev/null +++ b/tests/unit/mcpgateway/test_meta_server.py @@ -0,0 +1,2259 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/test_meta_server.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Unit tests for the Meta-Server feature. 
+ +Tests cover: +- Meta-server schema validation (ServerType, MetaToolScope, MetaConfig) +- Meta-tool request/response schema contracts +- Meta-server creation with server_type='meta' +- Config validation (limits, ranges) +- Meta-tools appearing when server_type == 'meta' +- Underlying tools hidden when hide_underlying_tools is enabled +- MetaServerService stub handlers +- search_tools: hybrid semantic + keyword search, merge, ranking, scope, pagination +- get_similar_tools: vector similarity with self-filtering and scope +- _apply_scope_filtering: all 7 scope fields with AND semantics +- Helper methods: _get_tool_metadata, _get_tools_matching_tags, _map_to_tool_summaries +""" + +# Standard +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +# Third-Party +import pytest +from pydantic import ValidationError + +# First-Party +from mcpgateway.meta_server.schemas import ( + DescribeToolRequest, + DescribeToolResponse, + ExecuteToolRequest, + ExecuteToolResponse, + GetPromptRequest, + GetPromptResponse, + GetSimilarToolsRequest, + GetSimilarToolsResponse, + GetToolCategoriesRequest, + GetToolCategoriesResponse, + ListPromptsRequest, + ListPromptsResponse, + ListResourcesRequest, + ListResourcesResponse, + ListToolsRequest, + ListToolsResponse, + META_TOOL_DEFINITIONS, + MetaConfig, + MetaToolScope, + ReadResourceRequest, + ReadResourceResponse, + SearchToolsRequest, + SearchToolsResponse, + ServerType, + ToolSummary, +) +from mcpgateway.meta_server.service import MetaServerService, get_meta_server_service +from mcpgateway.schemas import ToolSearchResult + +# ServerType Enum Tests + +class TestServerType: + """Tests for the ServerType enum.""" + + def test_standard_value(self): + """Test standard server type value.""" + assert ServerType.STANDARD.value == "standard" + + def test_meta_value(self): + """Test meta server type value.""" + assert ServerType.META.value == "meta" + + def test_from_string_meta(self): + 
"""Test creating ServerType from string 'meta'.""" + assert ServerType("meta") == ServerType.META + + def test_from_string_standard(self): + """Test creating ServerType from string 'standard'.""" + assert ServerType("standard") == ServerType.STANDARD + + def test_invalid_type_raises(self): + """Test that invalid server type raises ValueError.""" + with pytest.raises(ValueError): + ServerType("invalid") + +# MetaToolScope Tests + +class TestMetaToolScope: + """Tests for the MetaToolScope configuration model.""" + + def test_default_scope(self): + """Test that default scope has empty lists.""" + scope = MetaToolScope() + assert scope.include_tags == [] + assert scope.exclude_tags == [] + assert scope.include_servers == [] + assert scope.exclude_servers == [] + assert scope.include_visibility == [] + assert scope.include_teams == [] + assert scope.name_patterns == [] + + def test_scope_with_tags(self): + """Test scope with tag filters.""" + scope = MetaToolScope(include_tags=["prod", "stable"], exclude_tags=["deprecated"]) + assert scope.include_tags == ["prod", "stable"] + assert scope.exclude_tags == ["deprecated"] + + def test_scope_with_servers(self): + """Test scope with server filters.""" + scope = MetaToolScope(include_servers=["s1", "s2"], exclude_servers=["s3"]) + assert scope.include_servers == ["s1", "s2"] + assert scope.exclude_servers == ["s3"] + + def test_scope_with_visibility(self): + """Test scope with valid visibility values.""" + scope = MetaToolScope(include_visibility=["public", "team"]) + assert scope.include_visibility == ["public", "team"] + + def test_scope_invalid_visibility_raises(self): + """Test that invalid visibility value raises ValidationError.""" + with pytest.raises(ValidationError): + MetaToolScope(include_visibility=["invalid_level"]) + + def test_scope_serialization(self): + """Test scope serializes correctly with camelCase aliases.""" + scope = MetaToolScope(include_tags=["test"], name_patterns=["db_*"]) + data = 
scope.model_dump(by_alias=True) + assert "includeTags" in data + assert "namePatterns" in data + assert data["includeTags"] == ["test"] + + def test_scope_with_teams(self): + """Test scope with team filters.""" + scope = MetaToolScope(include_teams=["team-1", "team-2"]) + assert scope.include_teams == ["team-1", "team-2"] + + def test_scope_name_patterns(self): + """Test scope with name patterns.""" + scope = MetaToolScope(name_patterns=["db_*", "*_tool"]) + assert scope.name_patterns == ["db_*", "*_tool"] + + +# MetaConfig Tests + +class TestMetaConfig: + """Tests for the MetaConfig configuration model.""" + + def test_default_config(self): + """Test default config values.""" + config = MetaConfig() + assert config.enable_semantic_search is False + assert config.enable_categories is False + assert config.enable_similar_tools is False + assert config.default_search_limit == 50 + assert config.max_search_limit == 200 + assert config.include_metrics_in_search is False + + def test_custom_config(self): + """Test custom config values.""" + config = MetaConfig( + enable_semantic_search=True, + enable_categories=True, + enable_similar_tools=True, + default_search_limit=25, + max_search_limit=500, + include_metrics_in_search=True, + ) + assert config.enable_semantic_search is True + assert config.default_search_limit == 25 + assert config.max_search_limit == 500 + + def test_config_search_limit_range(self): + """Test that default_search_limit respects range constraints.""" + with pytest.raises(ValidationError): + MetaConfig(default_search_limit=0) # Must be >= 1 + + def test_config_max_search_limit_range(self): + """Test that max_search_limit respects range constraints.""" + with pytest.raises(ValidationError): + MetaConfig(max_search_limit=0) # Must be >= 1 + + def test_config_max_less_than_default_raises(self): + """Test that max_search_limit < default_search_limit raises ValidationError.""" + with pytest.raises(ValidationError): + MetaConfig(default_search_limit=100, 
max_search_limit=50) + + def test_config_serialization(self): + """Test config serializes correctly with camelCase aliases.""" + config = MetaConfig(enable_semantic_search=True) + data = config.model_dump(by_alias=True) + assert "enableSemanticSearch" in data + assert data["enableSemanticSearch"] is True + + def test_config_max_equals_default(self): + """Test that max_search_limit == default_search_limit is valid.""" + config = MetaConfig(default_search_limit=100, max_search_limit=100) + assert config.max_search_limit == 100 + + +# Meta-Tool Request/Response Schema Tests + +class TestSearchToolsSchemas: + """Tests for search_tools request/response schemas.""" + + def test_request_minimal(self): + """Test minimal search request.""" + req = SearchToolsRequest(query="database") + assert req.query == "database" + assert req.limit == 50 + assert req.offset == 0 + + def test_request_with_all_fields(self): + """Test search request with all fields.""" + req = SearchToolsRequest(query="test", limit=10, offset=5, tags=["db"], include_metrics=True) + assert req.limit == 10 + assert req.tags == ["db"] + + def test_request_empty_query_raises(self): + """Test that empty query raises ValidationError.""" + with pytest.raises(ValidationError): + SearchToolsRequest(query="") + + def test_response_empty(self): + """Test empty search response.""" + resp = SearchToolsResponse(tools=[], total_count=0, query="test", has_more=False) + assert resp.total_count == 0 + assert resp.has_more is False + + +class TestListToolsSchemas: + """Tests for list_tools request/response schemas.""" + + def test_request_defaults(self): + """Test list request defaults.""" + req = ListToolsRequest() + assert req.limit == 50 + assert req.offset == 0 + + def test_response_with_tools(self): + """Test list response with tool summaries.""" + tool = ToolSummary(name="my_tool", description="A test tool", server_id="s1", server_name="Server 1") + resp = ListToolsResponse(tools=[tool], total_count=1, has_more=False) + 
assert len(resp.tools) == 1 + assert resp.tools[0].name == "my_tool" + + +class TestDescribeToolSchemas: + """Tests for describe_tool request/response schemas.""" + + def test_request(self): + """Test describe request.""" + req = DescribeToolRequest(tool_name="query_db") + assert req.tool_name == "query_db" + + def test_request_empty_name_raises(self): + """Test that empty tool_name raises ValidationError.""" + with pytest.raises(ValidationError): + DescribeToolRequest(tool_name="") + + def test_response(self): + """Test describe response.""" + resp = DescribeToolResponse(name="query_db", description="Run SQL queries") + assert resp.name == "query_db" + assert resp.input_schema is None + + +class TestExecuteToolSchemas: + """Tests for execute_tool request/response schemas.""" + + def test_request(self): + """Test execute request.""" + req = ExecuteToolRequest(tool_name="query_db", arguments={"sql": "SELECT 1"}) + assert req.tool_name == "query_db" + assert req.arguments["sql"] == "SELECT 1" + + def test_response_success(self): + """Test successful execute response.""" + resp = ExecuteToolResponse(tool_name="query_db", success=True, result={"rows": []}) + assert resp.success is True + assert resp.error is None + + def test_response_failure(self): + """Test failed execute response.""" + resp = ExecuteToolResponse(tool_name="query_db", success=False, error="Connection failed") + assert resp.success is False + assert resp.error == "Connection failed" + + +class TestGetToolCategoriesSchemas: + """Tests for get_tool_categories request/response schemas.""" + + def test_request_defaults(self): + """Test categories request defaults.""" + req = GetToolCategoriesRequest() + assert req.include_counts is True + + def test_response_empty(self): + """Test empty categories response.""" + resp = GetToolCategoriesResponse(categories=[], total_categories=0) + assert resp.total_categories == 0 + + +class TestGetSimilarToolsSchemas: + """Tests for get_similar_tools request/response 
schemas.""" + + def test_request(self): + """Test similar tools request.""" + req = GetSimilarToolsRequest(tool_name="query_db", limit=5) + assert req.tool_name == "query_db" + assert req.limit == 5 + + def test_response_empty(self): + """Test empty similar tools response.""" + resp = GetSimilarToolsResponse(reference_tool="query_db", similar_tools=[], total_found=0) + assert resp.reference_tool == "query_db" + assert resp.total_found == 0 + + +# META_TOOL_DEFINITIONS Tests + +class TestMetaToolDefinitions: + """Tests for the META_TOOL_DEFINITIONS registry.""" + + def test_all_six_tools_defined(self): + """Test that all six meta-tools are defined.""" + expected = {"search_tools", "list_tools", "describe_tool", "execute_tool", "get_tool_categories", "get_similar_tools"} + assert set(META_TOOL_DEFINITIONS.keys()) == expected + + def test_each_has_description(self): + """Test that each meta-tool has a description.""" + for name, defn in META_TOOL_DEFINITIONS.items(): + assert "description" in defn, f"{name} missing description" + assert isinstance(defn["description"], str) + + def test_each_has_input_schema(self): + """Test that each meta-tool has an input_schema.""" + for name, defn in META_TOOL_DEFINITIONS.items(): + assert "input_schema" in defn, f"{name} missing input_schema" + assert isinstance(defn["input_schema"], dict) + + +# MetaServerService Tests + +class TestMetaServerService: + """Tests for the MetaServerService.""" + + def test_get_meta_tool_definitions(self): + """Test that meta-tool definitions are returned correctly.""" + service = MetaServerService() + defs = service.get_meta_tool_definitions() + assert len(defs) == 12 + names = {d["name"] for d in defs} + assert "search_tools" in names + assert "execute_tool" in names + assert "list_resources" in names + assert "authorize_gateway" in names + + def test_is_meta_server(self): + """Test is_meta_server check.""" + service = MetaServerService() + assert service.is_meta_server("meta") is True + assert 
service.is_meta_server("standard") is False + assert service.is_meta_server(None) is False + + def test_should_hide_underlying_tools(self): + """Test should_hide_underlying_tools logic.""" + service = MetaServerService() + assert service.should_hide_underlying_tools("meta", True) is True + assert service.should_hide_underlying_tools("meta", False) is False + assert service.should_hide_underlying_tools("standard", True) is False + assert service.should_hide_underlying_tools(None, True) is False + + def test_is_meta_tool(self): + """Test is_meta_tool check.""" + service = MetaServerService() + assert service.is_meta_tool("search_tools") is True + assert service.is_meta_tool("list_tools") is True + assert service.is_meta_tool("some_random_tool") is False + + def test_stub_search_tools(self): + """Test search_tools returns empty results when both search sources return nothing.""" + service = MetaServerService() + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=[]) + + def mock_get_db(): + db = MagicMock() + db.query.return_value.filter.return_value.limit.return_value.all.return_value = [] + yield db + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "database"})) + assert result["query"] == "database" + assert result["tools"] == [] + assert result["totalCount"] == 0 + + def test_list_tools_returns_empty_when_no_tools(self): + """Test list_tools returns empty results when no tools exist.""" + service = MetaServerService() + + def mock_get_db(): + db = MagicMock() + yield db + + # Mock ToolService.list_tools to return empty list + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, 
return_value=([], None)), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {})) + + assert result["tools"] == [] + assert result["totalCount"] == 0 + assert result["hasMore"] is False + + def test_stub_describe_tool(self): + """Test describe_tool stub returns placeholder response.""" + service = MetaServerService() + + def mock_get_db(): + db = MagicMock() + yield db + + mock_response = DescribeToolResponse(name="my_tool", description="Stub description for my_tool") + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.services.meta_tool_service.MetaToolService.describe_tool", new_callable=AsyncMock, return_value=mock_response), + ): + result = asyncio.run(service.handle_meta_tool_call("describe_tool", {"tool_name": "my_tool"})) + assert result["name"] == "my_tool" + assert "Stub description" in result["description"] + + def test_stub_execute_tool(self): + """Test execute_tool stub returns not-implemented response.""" + service = MetaServerService() + + def mock_get_db(): + db = MagicMock() + yield db + + mock_response = ExecuteToolResponse(tool_name="my_tool", success=False, error="This action is not yet implemented") + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.services.meta_tool_service.MetaToolService.execute_tool", new_callable=AsyncMock, return_value=mock_response), + ): + result = asyncio.run(service.handle_meta_tool_call("execute_tool", {"tool_name": "my_tool"})) + assert result["toolName"] == "my_tool" + assert result["success"] is False + assert "not yet implemented" in result["error"] + + def test_stub_get_tool_categories(self): + """Test get_tool_categories stub returns placeholder response.""" + service = MetaServerService() + result = asyncio.run(service.handle_meta_tool_call("get_tool_categories", {})) + assert result["categories"] == [] + assert result["totalCategories"] == 0 + + def test_stub_get_similar_tools(self): + """Test 
get_similar_tools returns empty when tool not found in DB.""" + service = MetaServerService() + + def mock_get_db(): + db = MagicMock() + db.query.return_value.filter.return_value.first.return_value = None + yield db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "my_tool"})) + assert result["referenceTool"] == "my_tool" + assert result["similarTools"] == [] + + def test_unknown_meta_tool_raises(self): + """Test that unknown meta-tool name raises ValueError.""" + service = MetaServerService() + with pytest.raises(ValueError, match="Unknown meta-tool"): + asyncio.run(service.handle_meta_tool_call("nonexistent_tool", {})) + + def test_singleton_service(self): + """Test that get_meta_server_service returns a singleton.""" + s1 = get_meta_server_service() + s2 = get_meta_server_service() + assert s1 is s2 + + +# Server Schema Integration Tests (ServerCreate with server_type) + +class TestServerCreateMetaType: + """Tests for ServerCreate schema with meta server type support.""" + + def test_default_server_type(self): + """Test that default server_type is 'standard'.""" + from mcpgateway.schemas import ServerCreate + + server = ServerCreate(name="Test Server") + assert server.server_type == "standard" + + def test_meta_server_type(self): + """Test creating a server with type 'meta'.""" + from mcpgateway.schemas import ServerCreate + + server = ServerCreate(name="Meta Server", server_type="meta") + assert server.server_type == "meta" + + def test_invalid_server_type_raises(self): + """Test that invalid server_type raises ValidationError.""" + from mcpgateway.schemas import ServerCreate + + with pytest.raises(ValidationError): + ServerCreate(name="Bad Server", server_type="invalid") + + def test_hide_underlying_tools_default(self): + """Test that hide_underlying_tools defaults to True.""" + from mcpgateway.schemas import ServerCreate + + server = 
ServerCreate(name="Test Server") + assert server.hide_underlying_tools is True + + def test_meta_config_field(self): + """Test that meta_config can be set.""" + from mcpgateway.schemas import ServerCreate + + config = {"enable_semantic_search": True, "default_search_limit": 25} + server = ServerCreate(name="Meta Server", server_type="meta", meta_config=config) + assert server.meta_config == config + + def test_meta_scope_field(self): + """Test that meta_scope can be set.""" + from mcpgateway.schemas import ServerCreate + + scope = {"include_tags": ["production"], "exclude_servers": ["legacy"]} + server = ServerCreate(name="Meta Server", server_type="meta", meta_scope=scope) + assert server.meta_scope == scope + + +class TestServerUpdateMetaType: + """Tests for ServerUpdate schema with meta server type support.""" + + def test_update_server_type(self): + """Test updating server_type.""" + from mcpgateway.schemas import ServerUpdate + + update = ServerUpdate(server_type="meta") + assert update.server_type == "meta" + + def test_update_invalid_server_type_raises(self): + """Test that invalid server_type raises ValidationError on update.""" + from mcpgateway.schemas import ServerUpdate + + with pytest.raises(ValidationError): + ServerUpdate(server_type="bad_type") + + def test_update_meta_config(self): + """Test updating meta_config.""" + from mcpgateway.schemas import ServerUpdate + + update = ServerUpdate(meta_config={"enable_categories": True}) + assert update.meta_config == {"enable_categories": True} + + +class TestServerReadMetaFields: + """Tests for ServerRead schema meta-server fields.""" + + def test_read_defaults(self): + """Test that ServerRead has correct meta field defaults.""" + from datetime import datetime, timezone + + from mcpgateway.schemas import ServerRead + + now = datetime.now(timezone.utc) + read = ServerRead( + id="test-id", + name="Test Server", + description=None, + icon=None, + created_at=now, + updated_at=now, + enabled=True, + ) + assert 
read.server_type == "standard" + assert read.hide_underlying_tools is True + assert read.meta_config is None + assert read.meta_scope is None + + def test_read_meta_server(self): + """Test ServerRead with meta server fields populated.""" + from datetime import datetime, timezone + + from mcpgateway.schemas import ServerRead + + now = datetime.now(timezone.utc) + read = ServerRead( + id="test-id", + name="Meta Server", + description="A meta server", + icon=None, + created_at=now, + updated_at=now, + enabled=True, + server_type="meta", + hide_underlying_tools=True, + meta_config={"enable_semantic_search": True}, + meta_scope={"include_tags": ["production"]}, + ) + assert read.server_type == "meta" + assert read.hide_underlying_tools is True + assert read.meta_config["enable_semantic_search"] is True + assert read.meta_scope["include_tags"] == ["production"] + + +# DB Model Integration Tests + +class TestServerDBModelMetaFields: + """Tests for Server DB model meta-server fields.""" + + def test_server_db_has_meta_fields(self, test_db): + """Test that Server DB model has meta-server columns.""" + from mcpgateway.db import Server as DbServer + + server = DbServer( + name="Meta Test Server", + server_type="meta", + hide_underlying_tools=True, + meta_config={"enable_categories": True}, + meta_scope={"include_tags": ["test"]}, + ) + test_db.add(server) + test_db.commit() + test_db.refresh(server) + + assert server.server_type == "meta" + assert server.hide_underlying_tools is True + assert server.meta_config == {"enable_categories": True} + assert server.meta_scope == {"include_tags": ["test"]} + + def test_server_db_default_type_standard(self, test_db): + """Test that Server DB model defaults to server_type='standard'.""" + from mcpgateway.db import Server as DbServer + + server = DbServer(name="Standard Server") + test_db.add(server) + test_db.commit() + test_db.refresh(server) + + assert server.server_type == "standard" + assert server.hide_underlying_tools is True # 
Default True + assert server.meta_config is None + assert server.meta_scope is None + + +# --------------------------------------------------------------------------- +# Helpers for mocking DB and services used by search/similar +# --------------------------------------------------------------------------- + +def _make_tool_search_result(name, description="desc", server_id="s1", server_name="Server1", score=0.8): + """Shorthand factory for ToolSearchResult.""" + return ToolSearchResult( + tool_name=name, + description=description, + server_id=server_id, + server_name=server_name, + similarity_score=score, + ) + + +def _make_mock_tool(name, description="desc", gateway_id="s1", tags=None, visibility="public", team_id=None, enabled=True, input_schema=None): + """Create a mock Tool ORM object.""" + tool = MagicMock() + tool.name = name + tool._computed_name = name + tool.description = description + tool.gateway_id = gateway_id + tool.gateway = SimpleNamespace(name="Server1") + tool.tags = tags or [] + tool.visibility = visibility + tool.team_id = team_id + tool.enabled = enabled + tool.input_schema = input_schema + tool.id = f"id-{name}" + return tool + + +def _mock_get_db_with_tools(tools): + """Return a mock get_db generator that supports query().filter().* patterns. 
+ + The mock DB handles several query patterns used across the service: + - .filter(...).limit(...).all() → returns tools (keyword search) + - .filter(...).all() → returns tools (metadata / tag queries) + - .filter(...).first() → returns first tool or None (tool lookup) + """ + def mock_get_db(): + db = MagicMock() + query = db.query.return_value + + # Chain .filter() calls (supports multiple chained filters) + filter_mock = MagicMock() + query.filter.return_value = filter_mock + filter_mock.filter.return_value = filter_mock # support chained .filter().filter() + + # .limit().all() for keyword search + filter_mock.limit.return_value.all.return_value = tools + # .all() for metadata / tag queries + filter_mock.all.return_value = tools + # .first() for single-tool lookup + filter_mock.first.return_value = tools[0] if tools else None + + yield db + + return mock_get_db + + +# --------------------------------------------------------------------------- +# search_tools comprehensive tests +# --------------------------------------------------------------------------- + +class TestSearchToolsImplementation: + """Comprehensive tests for the _search_tools implementation.""" + + def test_search_tools_semantic_results_returned(self): + """Test that semantic search results are included in response.""" + service = MetaServerService() + semantic_results = [ + _make_tool_search_result("tool_a", score=0.9), + _make_tool_search_result("tool_b", score=0.7), + ] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [_make_mock_tool("tool_a"), _make_mock_tool("tool_b")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "test"})) + + assert result["query"] == "test" + assert result["totalCount"] 
== 2 + assert len(result["tools"]) == 2 + + def test_search_tools_keyword_fallback_when_semantic_fails(self): + """Test keyword search works when semantic search raises an exception.""" + service = MetaServerService() + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(side_effect=RuntimeError("Embedding service down")) + + mock_tools = [_make_mock_tool("db_query", description="Query a database")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "db_query"})) + + # Keyword fallback should still produce results + assert result["totalCount"] >= 1 + tool_names = [t["name"] for t in result["tools"]] + assert "db_query" in tool_names + + def test_search_tools_both_fail_returns_empty(self): + """Test that when both semantic and keyword search fail, empty results returned.""" + service = MetaServerService() + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(side_effect=RuntimeError("fail")) + + def broken_get_db(): + raise RuntimeError("DB down") + yield # noqa: unreachable - needed to make it a generator + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", broken_get_db), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "anything"})) + + assert result["tools"] == [] + assert result["totalCount"] == 0 + + def test_search_tools_merge_dedup_keeps_higher_score(self): + """Test that duplicates are merged keeping the higher score.""" + service = MetaServerService() + semantic_results = [_make_tool_search_result("shared_tool", score=0.9)] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + # Keyword search also finds 
"shared_tool" with a lower score + mock_tools = [_make_mock_tool("shared_tool")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "shared_tool"})) + + # Should have one result, not two + assert result["totalCount"] == 1 + assert len(result["tools"]) == 1 + assert result["tools"][0]["name"] == "shared_tool" + + def test_search_tools_ranking_descending_by_score(self): + """Test that results are sorted descending by similarity score.""" + service = MetaServerService() + semantic_results = [ + _make_tool_search_result("low_score", score=0.3), + _make_tool_search_result("high_score", score=0.95), + _make_tool_search_result("mid_score", score=0.6), + ] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [ + _make_mock_tool("low_score"), + _make_mock_tool("high_score"), + _make_mock_tool("mid_score"), + ] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "test"})) + + names = [t["name"] for t in result["tools"]] + assert names == ["high_score", "mid_score", "low_score"] + + def test_search_tools_pagination_offset_and_limit(self): + """Test pagination with offset and limit.""" + service = MetaServerService() + # Create 5 results + semantic_results = [ + _make_tool_search_result(f"tool_{i}", score=1.0 - i * 0.1) + for i in range(5) + ] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [_make_mock_tool(f"tool_{i}") for i in range(5)] + + with ( + 
patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", "limit": 2, "offset": 1, + })) + + assert result["totalCount"] == 5 + assert len(result["tools"]) == 2 + assert result["hasMore"] is True + + def test_search_tools_pagination_no_more_results(self): + """Test has_more is False when all results fit.""" + service = MetaServerService() + semantic_results = [_make_tool_search_result("tool_a", score=0.8)] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [_make_mock_tool("tool_a")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", "limit": 50, "offset": 0, + })) + + assert result["hasMore"] is False + assert result["totalCount"] == 1 + + def test_search_tools_pagination_offset_beyond_results(self): + """Test offset beyond total results returns empty tools list.""" + service = MetaServerService() + semantic_results = [_make_tool_search_result("tool_a", score=0.8)] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [_make_mock_tool("tool_a")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", "limit": 10, "offset": 100, + })) + + assert result["tools"] == [] + assert result["totalCount"] == 1 + assert result["hasMore"] is False + + def 
test_search_tools_tag_filter(self): + """Test tag filtering narrows results to tools with matching tags.""" + service = MetaServerService() + semantic_results = [ + _make_tool_search_result("tagged_tool", score=0.9), + _make_tool_search_result("untagged_tool", score=0.8), + ] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [ + _make_mock_tool("tagged_tool", tags=["database"]), + _make_mock_tool("untagged_tool", tags=[]), + ] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", "tags": ["database"], + })) + + tool_names = [t["name"] for t in result["tools"]] + assert "tagged_tool" in tool_names + assert "untagged_tool" not in tool_names + + def test_search_tools_scope_filtering_applied(self): + """Test that scope filtering is applied to search results.""" + service = MetaServerService() + semantic_results = [ + _make_tool_search_result("public_tool", server_id="s1", score=0.9), + _make_tool_search_result("private_tool", server_id="s2", score=0.8), + ] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [ + _make_mock_tool("public_tool", visibility="public"), + _make_mock_tool("private_tool", visibility="private"), + ] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", + "scope": {"include_visibility": ["public"]}, + })) + + tool_names = [t["name"] for t in result["tools"]] + assert "public_tool" in tool_names + assert "private_tool" not in tool_names + + def 
test_search_tools_keyword_exact_match_scores_highest(self): + """Test keyword search gives 1.0 score for exact name match.""" + service = MetaServerService() + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=[]) + + exact_tool = _make_mock_tool("db_query") + partial_tool = _make_mock_tool("db_query_extended") + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools([exact_tool, partial_tool])), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "db_query"})) + + # Exact match should be first (score 1.0 > 0.7) + if len(result["tools"]) >= 2: + assert result["tools"][0]["name"] == "db_query" + + def test_search_tools_include_metrics_parameter_passed(self): + """Test that include_metrics is forwarded correctly.""" + service = MetaServerService() + semantic_results = [_make_tool_search_result("tool_a", score=0.9)] + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=semantic_results) + + mock_tools = [_make_mock_tool("tool_a")] + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", { + "query": "test", "include_metrics": True, + })) + + # Metrics are currently None (TODO), but the call should succeed + assert len(result["tools"]) == 1 + + def test_search_tools_response_is_camel_case(self): + """Test response uses camelCase aliases for serialization.""" + service = MetaServerService() + mock_semantic = AsyncMock() + mock_semantic.search_tools = AsyncMock(return_value=[]) + + with ( + patch("mcpgateway.meta_server.service.get_semantic_search_service", return_value=mock_semantic), + patch("mcpgateway.meta_server.service.get_db", 
_mock_get_db_with_tools([])), + ): + result = asyncio.run(service.handle_meta_tool_call("search_tools", {"query": "x"})) + + assert "totalCount" in result + assert "hasMore" in result + assert "query" in result + assert "tools" in result + + +# --------------------------------------------------------------------------- +# get_similar_tools comprehensive tests +# --------------------------------------------------------------------------- + +class TestGetSimilarToolsImplementation: + """Comprehensive tests for the _get_similar_tools implementation.""" + + def test_similar_tools_empty_tool_name_returns_empty(self): + """Test that empty tool_name returns empty results immediately.""" + service = MetaServerService() + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": ""})) + assert result["referenceTool"] == "" + assert result["similarTools"] == [] + assert result["totalFound"] == 0 + + def test_similar_tools_tool_not_found(self): + """Test that a non-existent reference tool returns empty results.""" + service = MetaServerService() + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools([])): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "nonexistent"})) + + assert result["referenceTool"] == "nonexistent" + assert result["similarTools"] == [] + assert result["totalFound"] == 0 + + def test_similar_tools_no_embedding_returns_empty(self): + """Test that tool without embedding returns empty results.""" + service = MetaServerService() + ref_tool = _make_mock_tool("my_tool") + + call_count = [0] + + def mock_get_db(): + call_count[0] += 1 + db = MagicMock() + query = db.query.return_value + filter_mock = MagicMock() + query.filter.return_value = filter_mock + filter_mock.filter.return_value = filter_mock + + if call_count[0] == 1: + # First call: resolve reference tool + filter_mock.first.return_value = ref_tool + else: + # Second call: get embedding — return None + pass + 
yield db + + mock_vector_service = MagicMock() + mock_vector_service.get_tool_embedding.return_value = None + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.meta_server.service.VectorSearchService", return_value=mock_vector_service), + ): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "my_tool"})) + + assert result["similarTools"] == [] + assert result["totalFound"] == 0 + + def test_similar_tools_filters_out_reference_tool(self): + """Test that the reference tool itself is excluded from similar results.""" + service = MetaServerService() + ref_tool = _make_mock_tool("my_tool") + + similar_results = [ + _make_tool_search_result("my_tool", score=1.0), # self — should be filtered + _make_tool_search_result("similar_tool_a", score=0.9), + _make_tool_search_result("similar_tool_b", score=0.8), + ] + + call_count = [0] + + def mock_get_db(): + call_count[0] += 1 + db = MagicMock() + query = db.query.return_value + filter_mock = MagicMock() + query.filter.return_value = filter_mock + filter_mock.filter.return_value = filter_mock + filter_mock.first.return_value = ref_tool + filter_mock.all.return_value = [ + _make_mock_tool("similar_tool_a"), + _make_mock_tool("similar_tool_b"), + ] + yield db + + mock_embedding = MagicMock() + mock_embedding.embedding = [0.1] * 128 + + mock_vector_service = MagicMock() + mock_vector_service.get_tool_embedding.return_value = mock_embedding + mock_vector_service.search_similar_tools = AsyncMock(return_value=similar_results) + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.meta_server.service.VectorSearchService", return_value=mock_vector_service), + ): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "my_tool"})) + + tool_names = [t["name"] for t in result["similarTools"]] + assert "my_tool" not in tool_names + assert "similar_tool_a" in tool_names + assert 
"similar_tool_b" in tool_names + + def test_similar_tools_respects_limit(self): + """Test that limit parameter is respected.""" + service = MetaServerService() + ref_tool = _make_mock_tool("my_tool") + + similar_results = [ + _make_tool_search_result(f"similar_{i}", score=0.9 - i * 0.1) + for i in range(5) + ] + + call_count = [0] + + def mock_get_db(): + call_count[0] += 1 + db = MagicMock() + query = db.query.return_value + filter_mock = MagicMock() + query.filter.return_value = filter_mock + filter_mock.filter.return_value = filter_mock + filter_mock.first.return_value = ref_tool + filter_mock.all.return_value = [_make_mock_tool(f"similar_{i}") for i in range(2)] + yield db + + mock_embedding = MagicMock() + mock_embedding.embedding = [0.1] * 128 + + mock_vector_service = MagicMock() + mock_vector_service.get_tool_embedding.return_value = mock_embedding + mock_vector_service.search_similar_tools = AsyncMock(return_value=similar_results) + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.meta_server.service.VectorSearchService", return_value=mock_vector_service), + ): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", { + "tool_name": "my_tool", "limit": 2, + })) + + # Limit should cap the results (after self-filtering) + assert len(result["similarTools"]) <= 2 + + def test_similar_tools_scope_filtering_applied(self): + """Test that scope filtering is applied to similar tools results.""" + service = MetaServerService() + ref_tool = _make_mock_tool("my_tool") + + similar_results = [ + _make_tool_search_result("public_similar", server_id="s1", score=0.9), + _make_tool_search_result("private_similar", server_id="s2", score=0.8), + ] + + call_count = [0] + + def mock_get_db(): + call_count[0] += 1 + db = MagicMock() + query = db.query.return_value + filter_mock = MagicMock() + query.filter.return_value = filter_mock + filter_mock.filter.return_value = filter_mock + filter_mock.first.return_value = 
ref_tool + filter_mock.all.return_value = [ + _make_mock_tool("public_similar", visibility="public"), + _make_mock_tool("private_similar", visibility="private"), + ] + yield db + + mock_embedding = MagicMock() + mock_embedding.embedding = [0.1] * 128 + + mock_vector_service = MagicMock() + mock_vector_service.get_tool_embedding.return_value = mock_embedding + mock_vector_service.search_similar_tools = AsyncMock(return_value=similar_results) + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch("mcpgateway.meta_server.service.VectorSearchService", return_value=mock_vector_service), + ): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", { + "tool_name": "my_tool", + "scope": {"include_visibility": ["public"]}, + })) + + tool_names = [t["name"] for t in result["similarTools"]] + assert "public_similar" in tool_names + assert "private_similar" not in tool_names + + def test_similar_tools_db_error_returns_empty(self): + """Test that DB error during tool lookup returns empty results gracefully.""" + service = MetaServerService() + + def broken_get_db(): + raise RuntimeError("DB connection failed") + yield # noqa: unreachable + + with patch("mcpgateway.meta_server.service.get_db", broken_get_db): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "my_tool"})) + + assert result["similarTools"] == [] + assert result["totalFound"] == 0 + + def test_similar_tools_response_is_camel_case(self): + """Test response uses camelCase aliases.""" + service = MetaServerService() + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools([])): + result = asyncio.run(service.handle_meta_tool_call("get_similar_tools", {"tool_name": "x"})) + + assert "referenceTool" in result + assert "similarTools" in result + assert "totalFound" in result + + +# --------------------------------------------------------------------------- +# _apply_scope_filtering tests (all 7 scope fields + AND 
semantics) +# --------------------------------------------------------------------------- + +class TestApplyScopeFiltering: + """Tests for _apply_scope_filtering with all MetaToolScope fields.""" + + def setup_method(self): + """Create a service and standard test results.""" + self.service = MetaServerService() + self.results = [ + _make_tool_search_result("tool_a", server_id="server_1", score=0.9), + _make_tool_search_result("tool_b", server_id="server_2", score=0.8), + _make_tool_search_result("tool_c", server_id="server_1", score=0.7), + ] + self.mock_tools = [ + _make_mock_tool("tool_a", tags=["database", "production"], visibility="public", team_id="team1"), + _make_mock_tool("tool_b", tags=["deprecated"], visibility="private", team_id="team2"), + _make_mock_tool("tool_c", tags=["database"], visibility="team", team_id="team1"), + ] + + def test_no_scope_passes_all(self): + """Test that None scope passes all results through.""" + result = self.service._apply_scope_filtering(self.results, None) + assert len(result) == 3 + + def test_empty_scope_passes_all(self): + """Test that empty scope dict passes all results through.""" + result = self.service._apply_scope_filtering(self.results, {}) + assert len(result) == 3 + + def test_include_tags_filter(self): + """Test include_tags: tool must have at least one matching tag.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"include_tags": ["production"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names # has "production" + assert "tool_b" not in names # has "deprecated" only + assert "tool_c" not in names # has "database" only + + def test_exclude_tags_filter(self): + """Test exclude_tags: tool must NOT have any excluded tag.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, 
{"exclude_tags": ["deprecated"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names + assert "tool_b" not in names # has "deprecated" + assert "tool_c" in names + + def test_include_servers_filter(self): + """Test include_servers: tool must be from one of these servers.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"include_servers": ["server_1"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names # server_1 + assert "tool_b" not in names # server_2 + assert "tool_c" in names # server_1 + + def test_exclude_servers_filter(self): + """Test exclude_servers: tool must NOT be from excluded servers.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"exclude_servers": ["server_2"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names + assert "tool_b" not in names # server_2 + assert "tool_c" in names + + def test_include_visibility_filter(self): + """Test include_visibility: tool must have one of these visibility levels.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"include_visibility": ["public"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names # public + assert "tool_b" not in names # private + assert "tool_c" not in names # team + + def test_include_teams_filter(self): + """Test include_teams: tool must belong to one of these teams.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"include_teams": ["team1"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names # team1 + assert "tool_b" not in names # team2 + assert "tool_c" in names # 
team1 + + def test_name_patterns_filter(self): + """Test name_patterns: tool name must match at least one glob pattern.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"name_patterns": ["tool_a"]}) + names = [r.tool_name for r in result] + assert names == ["tool_a"] + + def test_name_patterns_wildcard(self): + """Test name_patterns with glob wildcards.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, {"name_patterns": ["tool_*"]}) + assert len(result) == 3 # All match tool_* + + def test_combined_and_semantics(self): + """Test that multiple scope fields combine with AND semantics.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, { + "include_tags": ["database"], + "include_visibility": ["public"], + "include_teams": ["team1"], + }) + names = [r.tool_name for r in result] + # Only tool_a has database tag AND public visibility AND team1 + assert names == ["tool_a"] + + def test_scope_excludes_tool_not_in_db(self): + """Test that tools not found in DB are excluded from scoped results.""" + # Only return tool_a from DB — tool_b, tool_c should be excluded + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools([self.mock_tools[0]])): + result = self.service._apply_scope_filtering(self.results, {"include_tags": ["database"]}) + names = [r.tool_name for r in result] + assert "tool_a" in names + assert "tool_b" not in names + assert "tool_c" not in names + + def test_scope_empty_results_input(self): + """Test scope filtering with empty results list.""" + result = self.service._apply_scope_filtering([], {"include_tags": ["database"]}) + assert result == [] + + def test_scope_all_fields_combined_strict(self): + """Test 
that strict AND across all 7 fields filters aggressively.""" + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(self.mock_tools)): + result = self.service._apply_scope_filtering(self.results, { + "include_tags": ["database"], + "exclude_tags": ["deprecated"], + "include_servers": ["server_1"], + "exclude_servers": ["server_3"], # doesn't affect any + "include_visibility": ["public", "team"], + "include_teams": ["team1"], + "name_patterns": ["tool_*"], + }) + names = [r.tool_name for r in result] + # tool_a: database=✓, not deprecated=✓, server_1=✓, public=✓, team1=✓, tool_*=✓ → ✓ + # tool_b: deprecated=✗ (excluded by exclude_tags) + # tool_c: database=✓, not deprecated=✓, server_1=✓, team=✓, team1=✓, tool_*=✓ → ✓ + assert "tool_a" in names + assert "tool_c" in names + assert "tool_b" not in names + + +# --------------------------------------------------------------------------- +# _get_tool_metadata tests +# --------------------------------------------------------------------------- + +class TestGetToolMetadata: + """Tests for _get_tool_metadata helper.""" + + def test_returns_metadata_for_found_tools(self): + """Test that metadata is returned for tools found in DB.""" + service = MetaServerService() + mock_tools = [ + _make_mock_tool("tool_a", tags=["db"], visibility="public", team_id="t1", input_schema={"type": "object"}), + ] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + result = service._get_tool_metadata(["tool_a"]) + + assert "tool_a" in result + assert result["tool_a"]["tags"] == ["db"] + assert result["tool_a"]["visibility"] == "public" + assert result["tool_a"]["team_id"] == "t1" + assert result["tool_a"]["input_schema"] == {"type": "object"} + + def test_empty_input_returns_empty(self): + """Test that empty tool names list returns empty dict.""" + service = MetaServerService() + result = service._get_tool_metadata([]) + assert result == {} + + def 
test_db_error_returns_empty(self): + """Test that DB error returns empty dict gracefully.""" + service = MetaServerService() + + def broken_get_db(): + raise RuntimeError("DB down") + yield # noqa: unreachable + + with patch("mcpgateway.meta_server.service.get_db", broken_get_db): + result = service._get_tool_metadata(["tool_a"]) + + assert result == {} + + def test_missing_tool_not_in_result(self): + """Test that tools not in DB are not in result dict.""" + service = MetaServerService() + mock_tools = [_make_mock_tool("tool_a")] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + result = service._get_tool_metadata(["tool_a", "tool_missing"]) + + assert "tool_a" in result + assert "tool_missing" not in result + + def test_null_tags_default_to_empty_list(self): + """Test that tools with None tags default to empty list.""" + service = MetaServerService() + mock_tools = [_make_mock_tool("tool_a", tags=None)] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + result = service._get_tool_metadata(["tool_a"]) + + assert result["tool_a"]["tags"] == [] + + +# --------------------------------------------------------------------------- +# _get_tools_matching_tags tests +# --------------------------------------------------------------------------- + +class TestGetToolsMatchingTags: + """Tests for _get_tools_matching_tags helper.""" + + def test_returns_matching_tool_names(self): + """Test that tools with matching tags are returned.""" + service = MetaServerService() + mock_tools = [ + _make_mock_tool("tool_a", tags=["database", "prod"]), + _make_mock_tool("tool_b", tags=["messaging"]), + _make_mock_tool("tool_c", tags=["database"]), + ] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + result = service._get_tools_matching_tags(["database"]) + + assert "tool_a" in result + assert "tool_c" in result + assert "tool_b" not in result + + def 
test_no_matching_tags_returns_empty(self): + """Test that no matching tags returns empty set.""" + service = MetaServerService() + mock_tools = [_make_mock_tool("tool_a", tags=["other"])] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + result = service._get_tools_matching_tags(["nonexistent"]) + + assert len(result) == 0 + + def test_db_error_returns_empty_set(self): + """Test that DB error returns empty set gracefully.""" + service = MetaServerService() + + def broken_get_db(): + raise RuntimeError("DB down") + yield # noqa: unreachable + + with patch("mcpgateway.meta_server.service.get_db", broken_get_db): + result = service._get_tools_matching_tags(["database"]) + + assert result == set() + + +# --------------------------------------------------------------------------- +# _map_to_tool_summaries tests +# --------------------------------------------------------------------------- + +class TestMapToToolSummaries: + """Tests for _map_to_tool_summaries helper.""" + + def test_maps_results_to_summaries(self): + """Test that ToolSearchResult objects are mapped to ToolSummary objects.""" + service = MetaServerService() + results = [ + _make_tool_search_result("tool_a", description="Tool A desc", server_id="s1", server_name="Server1"), + ] + mock_tools = [ + _make_mock_tool("tool_a", tags=["db"], input_schema={"type": "object"}), + ] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + summaries = service._map_to_tool_summaries(results) + + assert len(summaries) == 1 + assert summaries[0].name == "tool_a" + assert summaries[0].description == "Tool A desc" + assert summaries[0].server_id == "s1" + assert summaries[0].server_name == "Server1" + assert summaries[0].tags == ["db"] + assert summaries[0].input_schema == {"type": "object"} + + def test_empty_results_returns_empty(self): + """Test that empty results list returns empty summaries list.""" + service = MetaServerService() + 
summaries = service._map_to_tool_summaries([]) + assert summaries == [] + + def test_tool_not_in_db_gets_default_metadata(self): + """Test that tools not found in DB get default empty metadata.""" + service = MetaServerService() + results = [_make_tool_search_result("missing_tool")] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools([])): + summaries = service._map_to_tool_summaries(results) + + assert len(summaries) == 1 + assert summaries[0].name == "missing_tool" + assert summaries[0].tags == [] + assert summaries[0].input_schema is None + + def test_multiple_results_mapped_in_order(self): + """Test that multiple results preserve order.""" + service = MetaServerService() + results = [ + _make_tool_search_result("tool_a", score=0.9), + _make_tool_search_result("tool_b", score=0.8), + _make_tool_search_result("tool_c", score=0.7), + ] + mock_tools = [ + _make_mock_tool("tool_a"), + _make_mock_tool("tool_b"), + _make_mock_tool("tool_c"), + ] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + summaries = service._map_to_tool_summaries(results) + + assert [s.name for s in summaries] == ["tool_a", "tool_b", "tool_c"] + + def test_metrics_is_none_by_default(self): + """Test that metrics is None (TODO pending ToolMetric implementation).""" + service = MetaServerService() + results = [_make_tool_search_result("tool_a")] + mock_tools = [_make_mock_tool("tool_a")] + + with patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_tools)): + summaries = service._map_to_tool_summaries(results, include_metrics=True) + + assert summaries[0].metrics is None + +# --------------------------------------------------------------------------- +# list_tools comprehensive tests +# --------------------------------------------------------------------------- + + +class TestListToolsImplementation: + """Comprehensive tests for the _list_tools implementation.""" + + def 
test_list_tools_returns_results(self): + """Test that list_tools returns tools from ToolService.""" + service = MetaServerService() + + # Create mock ToolRead objects + mock_tool_a = MagicMock() + mock_tool_a.name = "tool_a" + mock_tool_a.description = "Tool A description" + mock_tool_a.gateway = SimpleNamespace(id="server_1", name="Server 1") + mock_tool_a.tags = ["database"] + mock_tool_a.input_schema = {"type": "object"} + + mock_tool_b = MagicMock() + mock_tool_b.name = "tool_b" + mock_tool_b.description = "Tool B description" + mock_tool_b.gateway = SimpleNamespace(id="server_1", name="Server 1") + mock_tool_b.tags = ["api"] + mock_tool_b.input_schema = {"type": "object"} + + tools = [mock_tool_a, mock_tool_b] + + def mock_get_db(): + db = MagicMock() + yield db + + mock_db_tools = [ + _make_mock_tool("tool_a", tags=["database"], input_schema={"type": "object"}), + _make_mock_tool("tool_b", tags=["api"], input_schema={"type": "object"}), + ] + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=(tools, None)), + patch.object(service, "_get_tool_metadata", return_value={ + "tool_a": {"tags": ["database"], "input_schema": {"type": "object"}, "visibility": "public", "team_id": None}, + "tool_b": {"tags": ["api"], "input_schema": {"type": "object"}, "visibility": "public", "team_id": None}, + }), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {})) + + assert result["totalCount"] == 2 + assert len(result["tools"]) == 2 + tool_names = [t["name"] for t in result["tools"]] + assert "tool_a" in tool_names + assert "tool_b" in tool_names + + def test_list_tools_with_pagination(self): + """Test list_tools respects limit and offset.""" + service = MetaServerService() + + # Create 5 mock tools + tools = [] + for i in range(5): + tool = MagicMock() + tool.name = f"tool_{i}" + tool.description 
= f"Tool {i}" + tool.gateway = SimpleNamespace(id="server_1", name="Server 1") + tool.tags = [] + tool.input_schema = {} + tools.append(tool) + + def mock_get_db(): + db = MagicMock() + yield db + + mock_db_tools = [_make_mock_tool(f"tool_{i}") for i in range(5)] + metadata = {f"tool_{i}": {"tags": [], "input_schema": {}, "visibility": "public", "team_id": None} for i in range(5)} + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=(tools, None)), + patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {"limit": 2, "offset": 1})) + + assert result["totalCount"] == 5 + assert len(result["tools"]) == 2 + assert result["hasMore"] is True + + def test_list_tools_with_tag_filter(self): + """Test list_tools respects tag filter.""" + service = MetaServerService() + + # Create tools with different tags + tool_a = MagicMock() + tool_a.name = "db_tool" + tool_a.description = "Database tool" + tool_a.gateway = SimpleNamespace(id="s1", name="Server 1") + tool_a.tags = ["database"] + tool_a.input_schema = {} + + tool_b = MagicMock() + tool_b.name = "api_tool" + tool_b.description = "API tool" + tool_b.gateway = SimpleNamespace(id="s1", name="Server 1") + tool_b.tags = ["api"] + tool_b.input_schema = {} + + tools = [tool_a, tool_b] + + def mock_get_db(): + db = MagicMock() + yield db + + metadata = { + "db_tool": {"tags": ["database"], "input_schema": {}, "visibility": "public", "team_id": None}, + "api_tool": {"tags": ["api"], "input_schema": {}, "visibility": "public", "team_id": None}, + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=(tools, None)), + 
patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {"tags": ["database"]})) + + # ToolService.list_tools should be called with tags filter + # The implementation passes tags to the service + assert "tools" in result + + def test_list_tools_with_server_filter(self): + """Test list_tools respects server_id filter.""" + service = MetaServerService() + + tool = MagicMock() + tool.name = "tool_a" + tool.description = "Tool A" + tool.gateway = SimpleNamespace(id="server_1", name="Server 1") + tool.tags = [] + tool.input_schema = {} + + def mock_get_db(): + db = MagicMock() + yield db + + metadata = { + "tool_a": {"tags": [], "input_schema": {}, "visibility": "public", "team_id": None}, + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=([tool], None)), + patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {"server_id": "server_1"})) + + assert len(result["tools"]) == 1 + + def test_list_tools_with_sorting(self): + """Test list_tools respects sort_by and sort_order.""" + service = MetaServerService() + + # Create mock tools (already sorted by ToolService) + tools = [] + for i, name in enumerate(["alpha", "beta", "gamma"]): + tool = MagicMock() + tool.name = name + tool.description = f"Tool {name}" + tool.gateway = SimpleNamespace(id="s1", name="Server 1") + tool.tags = [] + tool.input_schema = {} + tools.append(tool) + + def mock_get_db(): + db = MagicMock() + yield db + + metadata = { + name: {"tags": [], "input_schema": {}, "visibility": "public", "team_id": None} + for name in ["alpha", "beta", "gamma"] + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", 
mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=(tools, None)) as mock_list, + patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", { + "sort_by": "name", + "sort_order": "asc", + })) + + # Verify ToolService.list_tools was called with correct sort params + mock_list.assert_called_once() + call_kwargs = mock_list.call_args.kwargs + assert call_kwargs["sort_by"] == "name" + assert call_kwargs["sort_order"] == "asc" + + def test_list_tools_scope_filtering_applied(self): + """Test that scope filtering is applied to list results.""" + service = MetaServerService() + + # Create tools with different visibility + public_tool = MagicMock() + public_tool.name = "public_tool" + public_tool.description = "Public tool" + public_tool.gateway = SimpleNamespace(id="s1", name="Server 1") + public_tool.tags = [] + public_tool.input_schema = {} + + private_tool = MagicMock() + private_tool.name = "private_tool" + private_tool.description = "Private tool" + private_tool.gateway = SimpleNamespace(id="s1", name="Server 1") + private_tool.tags = [] + private_tool.input_schema = {} + + tools = [public_tool, private_tool] + + def mock_get_db(): + db = MagicMock() + yield db + + mock_db_tools = [ + _make_mock_tool("public_tool", visibility="public"), + _make_mock_tool("private_tool", visibility="private"), + ] + + metadata = { + "public_tool": {"tags": [], "input_schema": {}, "visibility": "public", "team_id": None}, + "private_tool": {"tags": [], "input_schema": {}, "visibility": "private", "team_id": None}, + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", _mock_get_db_with_tools(mock_db_tools)), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=(tools, None)), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", { + "scope": 
{"include_visibility": ["public"]}, + })) + + tool_names = [t["name"] for t in result["tools"]] + assert "public_tool" in tool_names + assert "private_tool" not in tool_names + + def test_list_tools_db_error_returns_empty(self): + """Test list_tools returns empty result gracefully on DB error.""" + service = MetaServerService() + + def broken_get_db(): + raise RuntimeError("DB connection failed") + yield # noqa: unreachable + + with patch("mcpgateway.meta_server.service.get_db", broken_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {})) + + assert result["tools"] == [] + assert result["totalCount"] == 0 + assert result["hasMore"] is False + + def test_list_tools_offset_beyond_total_returns_empty(self): + """Test offset beyond total count returns empty tools list.""" + service = MetaServerService() + + tool = MagicMock() + tool.name = "tool_a" + tool.description = "Tool A" + tool.gateway = SimpleNamespace(id="s1", name="Server 1") + tool.tags = [] + tool.input_schema = {} + + def mock_get_db(): + db = MagicMock() + yield db + + metadata = { + "tool_a": {"tags": [], "input_schema": {}, "visibility": "public", "team_id": None}, + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=([tool], None)), + patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {"offset": 100})) + + assert result["tools"] == [] + assert result["totalCount"] == 1 + assert result["hasMore"] is False + + def test_list_tools_include_schema_parameter(self): + """Test include_schema parameter is passed through.""" + service = MetaServerService() + + tool = MagicMock() + tool.name = "tool_a" + tool.description = "Tool A" + tool.gateway = SimpleNamespace(id="s1", name="Server 1") + tool.tags = [] + tool.input_schema = {"type": 
"object", "properties": {"arg": {"type": "string"}}} + + def mock_get_db(): + db = MagicMock() + yield db + + metadata = { + "tool_a": {"tags": [], "input_schema": {"type": "object"}, "visibility": "public", "team_id": None}, + } + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=([tool], None)) as mock_list, + patch.object(service, "_get_tool_metadata", return_value=metadata), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {"include_schema": True})) + + # Verify ToolService.list_tools was called with include_schema=True + mock_list.assert_called_once() + assert mock_list.call_args.kwargs["include_schema"] is True + + def test_list_tools_response_is_camel_case(self): + """Test response uses camelCase aliases for serialization.""" + service = MetaServerService() + + def mock_get_db(): + db = MagicMock() + yield db + + from mcpgateway.services.tool_service import ToolService + + with ( + patch("mcpgateway.meta_server.service.get_db", mock_get_db), + patch.object(ToolService, "list_tools", new_callable=AsyncMock, return_value=([], None)), + ): + result = asyncio.run(service.handle_meta_tool_call("list_tools", {})) + + assert "totalCount" in result + assert "hasMore" in result + + +# ------------------------------------------------------------------ +# Tests for list_resources / read_resource / list_prompts / get_prompt +# ------------------------------------------------------------------ + + +def _make_mock_resource(uri="resource://test", name="test-resource", description="A test resource", + mime_type="text/markdown", text_content="# Hello", tags=None, enabled=True, size=7): + """Create a mock Resource object.""" + r = MagicMock() + r.uri = uri + r.name = name + r.description = description + r.mime_type = mime_type + r.text_content = text_content + r.binary_content = None + r.tags = 
tags or [] + r.enabled = enabled + r.size = size + r.created_at = None + return r + + +def _make_mock_prompt(name="test-prompt", description="A test prompt", template="Hello {name}", + argument_schema=None, tags=None, enabled=True): + """Create a mock Prompt object.""" + p = MagicMock() + p.name = name + p.description = description + p.template = template + p.argument_schema = argument_schema or {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} + p.tags = tags or [] + p.enabled = enabled + p.created_at = None + return p + + +class TestListResourcesMetaTool: + """Tests for list_resources meta-tool.""" + + def test_list_resources_returns_results(self): + """Test that list_resources returns resources from DB.""" + service = MetaServerService() + mock_resources = [ + _make_mock_resource(uri="resource://a", name="res-a", tags=["guide"]), + _make_mock_resource(uri="resource://b", name="res-b", tags=["docs"]), + ] + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.all.return_value = mock_resources + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_resources", {})) + + assert result["totalCount"] == 2 + assert len(result["resources"]) == 2 + assert result["resources"][0]["uri"] == "resource://a" + assert result["resources"][1]["uri"] == "resource://b" + + def test_list_resources_with_tag_filter(self): + """Test that tag filtering works.""" + service = MetaServerService() + mock_resources = [ + _make_mock_resource(uri="resource://a", name="res-a", tags=["guide"]), + _make_mock_resource(uri="resource://b", name="res-b", tags=["docs"]), + ] + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = 
mock_query + mock_query.all.return_value = mock_resources + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_resources", {"tags": ["guide"]})) + + assert result["totalCount"] == 1 + assert result["resources"][0]["name"] == "res-a" + + def test_list_resources_empty(self): + """Test list_resources with no results.""" + service = MetaServerService() + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.all.return_value = [] + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_resources", {})) + + assert result["totalCount"] == 0 + assert result["resources"] == [] + assert result["hasMore"] is False + + def test_list_resources_pagination(self): + """Test list_resources with offset and limit.""" + service = MetaServerService() + mock_resources = [ + _make_mock_resource(uri=f"resource://{i}", name=f"res-{i}") + for i in range(5) + ] + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.all.return_value = mock_resources + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_resources", {"limit": 2, "offset": 0})) + + assert result["totalCount"] == 5 + assert len(result["resources"]) == 2 + assert result["hasMore"] is True + + +class TestReadResourceMetaTool: + """Tests for read_resource meta-tool.""" + + def test_read_resource_returns_content(self): + """Test that read_resource returns text content.""" 
+ service = MetaServerService() + mock_resource = _make_mock_resource( + uri="resource://test/guide", + name="guide", + text_content="# Guide\nThis is the content.", + ) + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.first.return_value = mock_resource + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("read_resource", {"uri": "resource://test/guide"})) + + assert result["uri"] == "resource://test/guide" + assert result["name"] == "guide" + assert "Guide" in result["text"] + + def test_read_resource_not_found(self): + """Test read_resource with unknown URI.""" + service = MetaServerService() + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.first.return_value = None + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("read_resource", {"uri": "resource://not/found"})) + + assert result["uri"] == "resource://not/found" + assert "not found" in result["text"].lower() + + def test_read_resource_empty_uri(self): + """Test read_resource with empty URI returns error.""" + service = MetaServerService() + result = asyncio.run(service.handle_meta_tool_call("read_resource", {"uri": ""})) + assert "required" in result["text"].lower() + + +class TestListPromptsMetaTool: + """Tests for list_prompts meta-tool.""" + + def test_list_prompts_returns_results(self): + """Test that list_prompts returns prompts from DB.""" + service = MetaServerService() + mock_prompts = [ + _make_mock_prompt(name="summarize", description="Summarize text", tags=["utility"]), + _make_mock_prompt(name="translate", description="Translate text", tags=["language"]), + ] + + mock_db = 
MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.all.return_value = mock_prompts + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_prompts", {})) + + assert result["totalCount"] == 2 + assert len(result["prompts"]) == 2 + assert result["prompts"][0]["name"] == "summarize" + + def test_list_prompts_empty(self): + """Test list_prompts with no results.""" + service = MetaServerService() + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.all.return_value = [] + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("list_prompts", {})) + + assert result["totalCount"] == 0 + assert result["prompts"] == [] + + +class TestGetPromptMetaTool: + """Tests for get_prompt meta-tool.""" + + def test_get_prompt_returns_template(self): + """Test that get_prompt returns prompt template.""" + service = MetaServerService() + mock_prompt = _make_mock_prompt(name="greet", template="Hello {name}!") + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.first.return_value = mock_prompt + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("get_prompt", {"name": "greet"})) + + assert result["name"] == "greet" + assert result["template"] == "Hello {name}!" 
+ assert result["rendered"] is None + + def test_get_prompt_with_rendering(self): + """Test that get_prompt renders template with arguments.""" + service = MetaServerService() + mock_prompt = _make_mock_prompt(name="greet", template="Hello {name}!") + mock_prompt.validate_arguments = MagicMock() + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.first.return_value = mock_prompt + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("get_prompt", { + "name": "greet", + "arguments": {"name": "World"}, + })) + + assert result["name"] == "greet" + assert result["rendered"] == "Hello World!" + + def test_get_prompt_not_found(self): + """Test get_prompt with unknown name.""" + service = MetaServerService() + + mock_db = MagicMock() + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.first.return_value = None + mock_db.query.return_value = mock_query + + def mock_get_db(): + yield mock_db + + with patch("mcpgateway.meta_server.service.get_db", mock_get_db): + result = asyncio.run(service.handle_meta_tool_call("get_prompt", {"name": "nonexistent"})) + + assert result["name"] == "nonexistent" + assert "not found" in result["description"].lower() + + def test_get_prompt_empty_name(self): + """Test get_prompt with empty name returns error.""" + service = MetaServerService() + result = asyncio.run(service.handle_meta_tool_call("get_prompt", {"name": ""})) + assert "required" in result["description"].lower() + + +class TestMetaToolDefinitionsIncludeNewTools: + """Test that META_TOOL_DEFINITIONS includes the 4 new meta-tools.""" + + def test_list_resources_in_definitions(self): + assert "list_resources" in META_TOOL_DEFINITIONS + + def test_read_resource_in_definitions(self): + assert "read_resource" in META_TOOL_DEFINITIONS + + def 
test_list_prompts_in_definitions(self): + assert "list_prompts" in META_TOOL_DEFINITIONS + + def test_get_prompt_in_definitions(self): + assert "get_prompt" in META_TOOL_DEFINITIONS + + def test_total_meta_tools_is_11(self): + assert len(META_TOOL_DEFINITIONS) == 11 + + def test_service_returns_11_definitions(self): + service = MetaServerService() + defs = service.get_meta_tool_definitions() + assert len(defs) == 11 + names = {d["name"] for d in defs} + assert "list_resources" in names + assert "read_resource" in names + assert "list_prompts" in names + assert "get_prompt" in names \ No newline at end of file