diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 086bc7f668be..fc9df494ab96 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -54,6 +54,10 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: - generate_dashboard: Create a dashboard from chart IDs - add_chart_to_existing_dashboard: Add a chart to an existing dashboard +Database Connections: +- list_databases: List database connections with advanced filters (1-based pagination) +- get_database_info: Get detailed database connection info by ID (backend, capabilities) + Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) @@ -114,12 +118,14 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: 3. generate_explore_link(dataset_id, config) -> preview interactively 4. generate_chart(dataset_id, config, save_chart=True) -> save permanently -To find your own charts/dashboards: +To find your own charts/dashboards/databases: 1. get_instance_info -> get current_user.id 2. list_charts(filters=[{{"col": "created_by_fk", "opr": "eq", "value": current_user.id}}]) 3. Or: list_dashboards(filters=[{{"col": "created_by_fk", "opr": "eq", "value": current_user.id}}]) +4. Or: list_databases(filters=[{{"col": "created_by_fk", + "opr": "eq", "value": current_user.id}}]) To explore data with SQL: 1. list_datasets -> find a dataset and note its database_id @@ -168,6 +174,8 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: filters=[{{"col": "created_by_fk", "opr": "eq", "value": }}] - My dashboards: filters=[{{"col": "created_by_fk", "opr": "eq", "value": }}] +- My databases: + filters=[{{"col": "created_by_fk", "opr": "eq", "value": }}] To modify an existing chart (add filters, change metrics, change dimensions, etc.): 1. 
get_chart_info(chart_id) -> examine current configuration @@ -432,6 +440,10 @@ def create_mcp_app( get_dashboard_info, list_dashboards, ) +from superset.mcp_service.database.tool import ( # noqa: F401, E402 + get_database_info, + list_databases, +) from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 get_dataset_info, list_datasets, diff --git a/superset/mcp_service/common/schema_discovery.py b/superset/mcp_service/common/schema_discovery.py index dda1a0d5412c..a648f45cc7aa 100644 --- a/superset/mcp_service/common/schema_discovery.py +++ b/superset/mcp_service/common/schema_discovery.py @@ -29,6 +29,8 @@ from pydantic import BaseModel, Field from sqlalchemy.inspection import inspect +from superset.mcp_service.constants import ModelType + class ColumnMetadata(BaseModel): """Metadata for a selectable column.""" @@ -52,7 +54,7 @@ class ModelSchemaInfo(BaseModel): - Default values for each """ - model_type: Literal["chart", "dataset", "dashboard"] = Field( + model_type: ModelType = Field( ..., description="The model type this schema describes" ) select_columns: list[ColumnMetadata] = Field( @@ -82,9 +84,7 @@ class ModelSchemaInfo(BaseModel): class GetSchemaRequest(BaseModel): """Request schema for unified get_schema tool.""" - model_type: Literal["chart", "dataset", "dashboard"] = Field( - ..., description="Model type to get schema for" - ) + model_type: ModelType = Field(..., description="Model type to get schema for") class GetSchemaResponse(BaseModel): @@ -180,6 +180,7 @@ def get_columns_from_model( model_cls: Type[Any], default_columns: list[str], extra_columns: dict[str, ColumnMetadata] | None = None, + exclude_columns: set[str] | None = None, ) -> list[ColumnMetadata]: """ Dynamically extract column metadata from a SQLAlchemy model. 
@@ -188,6 +189,7 @@ def get_columns_from_model( model_cls: The SQLAlchemy model class to inspect default_columns: List of column names that should be marked as defaults extra_columns: Additional columns not on the model (e.g., computed fields) + exclude_columns: Column names to omit (e.g., sensitive fields) Returns: List of ColumnMetadata objects for all columns @@ -197,6 +199,8 @@ def get_columns_from_model( for col in mapper.columns: col_name = col.key + if exclude_columns and col_name in exclude_columns: + continue col_type = _get_sqlalchemy_type_name(col.type) # Get description from column doc, comment, or fallback mapping description = ( @@ -452,6 +456,68 @@ def get_columns_from_model( } +# Database configuration +DATABASE_DEFAULT_COLUMNS = [ + "id", + "database_name", + "backend", + "expose_in_sqllab", + "changed_on", + "changed_on_humanized", +] +DATABASE_SORTABLE_COLUMNS = [ + "id", + "database_name", + "changed_on", + "created_on", +] +DATABASE_SEARCH_COLUMNS = ["database_name"] +DATABASE_EXTRA_COLUMNS: dict[str, ColumnMetadata] = { + "backend": ColumnMetadata( + name="backend", + description="Database backend type (e.g., postgresql, mysql)", + type="str", + is_default=True, + ), + "changed_by": ColumnMetadata( + name="changed_by", + description="Last modifier username", + type="str", + is_default=False, + ), + "changed_by_name": ColumnMetadata( + name="changed_by_name", + description="Last modifier display name", + type="str", + is_default=False, + ), + "changed_on_humanized": ColumnMetadata( + name="changed_on_humanized", + description="Humanized modification time", + type="str", + is_default=True, + ), + "created_by": ColumnMetadata( + name="created_by", + description="Creator username", + type="str", + is_default=False, + ), + "created_by_name": ColumnMetadata( + name="created_by_name", + description="Creator display name", + type="str", + is_default=False, + ), + "created_on_humanized": ColumnMetadata( + name="created_on_humanized", + 
description="Humanized creation time", + type="str", + is_default=False, + ), +} + + def get_chart_columns() -> list[ColumnMetadata]: """Get column metadata for Chart model dynamically.""" from superset.models.slice import Slice @@ -477,6 +543,27 @@ def get_dashboard_columns() -> list[ColumnMetadata]: ) +# Sensitive columns that should not be exposed via schema discovery +DATABASE_EXCLUDE_COLUMNS = { + "sqlalchemy_uri", + "password", + "encrypted_extra", + "server_cert", +} + + +def get_database_columns() -> list[ColumnMetadata]: + """Get column metadata for Database model dynamically.""" + from superset.models.core import Database + + return get_columns_from_model( + Database, + DATABASE_DEFAULT_COLUMNS, + DATABASE_EXTRA_COLUMNS, + exclude_columns=DATABASE_EXCLUDE_COLUMNS, + ) + + def get_all_column_names(columns: list[ColumnMetadata]) -> list[str]: """Extract all column names from column metadata list.""" return [col.name for col in columns] @@ -487,3 +574,4 @@ def get_all_column_names(columns: list[ColumnMetadata]) -> list[str]: CHART_ALL_COLUMNS: list[str] = [] DATASET_ALL_COLUMNS: list[str] = [] DASHBOARD_ALL_COLUMNS: list[str] = [] +DATABASE_ALL_COLUMNS: list[str] = [] diff --git a/superset/mcp_service/constants.py b/superset/mcp_service/constants.py index a23a7949e948..6315ef8755f9 100644 --- a/superset/mcp_service/constants.py +++ b/superset/mcp_service/constants.py @@ -16,6 +16,11 @@ # under the License. 
"""Constants for the MCP service.""" +from typing import Literal + +# Supported model types for schema discovery and MCP tools +ModelType = Literal["chart", "dataset", "dashboard", "database"] + # Pagination defaults DEFAULT_PAGE_SIZE = 10 # Default number of items per page MAX_PAGE_SIZE = 100 # Maximum allowed page_size to prevent oversized responses diff --git a/superset/mcp_service/database/__init__.py b/superset/mcp_service/database/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/database/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/database/schemas.py b/superset/mcp_service/database/schemas.py new file mode 100644 index 000000000000..d93ea630b309 --- /dev/null +++ b/superset/mcp_service/database/schemas.py @@ -0,0 +1,364 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for database-related responses +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +import humanize +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) +from superset.utils import json + + +class DatabaseFilter(ColumnOperator): + """ + Filter object for database listing. + col: The column to filter on. Must be one of the allowed filter fields. + opr: The operator to use. Must be one of the supported operators. + value: The value to filter by (type depends on col and opr). + """ + + col: Literal[ + "database_name", + "expose_in_sqllab", + "allow_file_upload", + "created_by_fk", + "changed_by_fk", + ] = Field( + ..., + description="Column to filter on. Use get_schema(model_type='database') for " + "available filter columns. Use created_by_fk with the user " + "ID from get_instance_info's current_user to find " + "databases created by a specific user.", + ) + opr: ColumnOperatorEnum = Field( + ..., + description="Operator to use. 
Use get_schema(model_type='database') for " + "available operators.", + ) + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by (type depends on col and opr)" + ) + + +class DatabaseInfo(BaseModel): + id: int | None = Field(None, description="Database ID") + uuid: str | None = Field(None, description="Database UUID") + database_name: str | None = Field(None, description="Database connection name") + backend: str | None = Field(None, description="Database backend (e.g., postgresql)") + expose_in_sqllab: bool | None = Field( + None, description="Whether exposed in SQL Lab" + ) + allow_ctas: bool | None = Field( + None, description="Whether CREATE TABLE AS is allowed" + ) + allow_cvas: bool | None = Field( + None, description="Whether CREATE VIEW AS is allowed" + ) + allow_dml: bool | None = Field( + None, description="Whether DML statements are allowed" + ) + allow_file_upload: bool | None = Field( + None, description="Whether file upload is allowed" + ) + allow_run_async: bool | None = Field( + None, description="Whether async query execution is allowed" + ) + cache_timeout: int | None = Field( + None, description="Cache timeout override in seconds" + ) + configuration_method: str | None = Field( + None, description="Configuration method (sqlalchemy_form or dynamic_form)" + ) + force_ctas_schema: str | None = Field( + None, description="Schema to force for CTAS queries" + ) + impersonate_user: bool | None = Field( + None, description="Whether to impersonate the logged-in user" + ) + is_managed_externally: bool | None = Field( + None, description="Whether managed by an external system" + ) + external_url: str | None = Field( + None, description="URL of the external management system" + ) + extra: Dict[str, Any] | None = Field(None, description="Extra configuration") + changed_by: str | None = Field(None, description="Last modifier (username)") + changed_on: str | datetime | None = Field( + None,
description="Last modification timestamp" + ) + changed_on_humanized: str | None = Field( + None, description="Humanized modification time" + ) + created_by: str | None = Field(None, description="Database creator (username)") + created_on: str | datetime | None = Field(None, description="Creation timestamp") + created_on_humanized: str | None = Field( + None, description="Humanized creation time" + ) + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap", when_used="json") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + """Filter fields based on serialization context. + + If context contains 'select_columns', only include those fields. + Otherwise, include all fields (default behavior). + """ + data = serializer(self) + + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + + return data + + +class DatabaseList(BaseModel): + databases: List[DatabaseInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field( + default_factory=list, + description="Requested columns for the response", + ) + columns_loaded: List[str] = Field( + default_factory=list, + description="Columns that were actually loaded for each database", + ) + columns_available: List[str] = Field( + default_factory=list, + description="All columns available for selection via select_columns parameter", + ) + sortable_columns: List[str] = Field( + default_factory=list, + description="Columns that can be used with order_column parameter", + ) + filters_applied: List[DatabaseFilter] = Field( + default_factory=list, + description="List of advanced filter dicts applied to the query.", + ) + 
pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListDatabasesRequest(MetadataCacheControl): + """Request schema for list_databases with clear, unambiguous types.""" + + filters: Annotated[ + List[DatabaseFilter], + Field( + default_factory=list, + description="List of filter objects (column, operator, value). Each " + "filter is an object with 'col', 'opr', and 'value' " + "properties. Cannot be used together with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="List of columns to select. Defaults to common columns if not " + "specified.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search string to match against database fields. Cannot " + "be used together with 'filters'.", + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field( + default="desc", description="Direction to order results ('asc' or 'desc')" + ), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number for pagination (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[DatabaseFilter]: + """Accept both JSON string and list of objects.""" + return parse_json_or_model_list(v, DatabaseFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + """Accept JSON array, list, or comma-separated string.""" + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> 
"ListDatabasesRequest": + """Prevent using both search and filters simultaneously to avoid query + conflicts.""" + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' parameters simultaneously. " + "Use either 'search' for text-based searching across multiple fields, " + "or 'filters' for precise column-based filtering, but not both." + ) + return self + + +class DatabaseError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "DatabaseError": + """Create a standardized DatabaseError with timestamp.""" + from datetime import datetime, timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetDatabaseInfoRequest(MetadataCacheControl): + """Request schema for get_database_info with support for ID or UUID.""" + + identifier: Annotated[ + int | str, + Field(description="Database identifier - can be numeric ID or UUID string"), + ] + + +def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None: + """Parse a field that may be stored as a JSON string into a dict.""" + value = getattr(obj, field_name, None) + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return parsed + except (ValueError, TypeError): + pass + return None + return value + + +def _humanize_timestamp(dt: datetime | None) -> str | None: + """Convert a datetime to a humanized string like '2 hours ago'.""" + if dt is None: + return None + now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now() + return humanize.naturaltime(now - dt) + + +def _get_backend(database: Any) -> str | None: + """Safely get backend from a Database object or row proxy. 
+ + backend is a @property that decrypts sqlalchemy_uri, which fails on + row proxies returned by column-only DAO list queries. Fall back to None + when the property raises. + """ + try: + return database.backend + except (AttributeError, TypeError): + return None + + +def serialize_database_object(database: Any) -> DatabaseInfo | None: + if not database: + return None + + return DatabaseInfo( + id=getattr(database, "id", None), + uuid=str(getattr(database, "uuid", "")) + if getattr(database, "uuid", None) + else None, + database_name=getattr(database, "database_name", None), + backend=_get_backend(database), + expose_in_sqllab=getattr(database, "expose_in_sqllab", None), + allow_ctas=getattr(database, "allow_ctas", None), + allow_cvas=getattr(database, "allow_cvas", None), + allow_dml=getattr(database, "allow_dml", None), + allow_file_upload=getattr(database, "allow_file_upload", None), + allow_run_async=getattr(database, "allow_run_async", None), + cache_timeout=getattr(database, "cache_timeout", None), + configuration_method=getattr(database, "configuration_method", None), + force_ctas_schema=getattr(database, "force_ctas_schema", None), + impersonate_user=getattr(database, "impersonate_user", None), + is_managed_externally=getattr(database, "is_managed_externally", None), + external_url=getattr(database, "external_url", None), + extra=_parse_json_field(database, "extra"), + changed_by=getattr(database, "changed_by_name", None) + or ( + str(database.changed_by) if getattr(database, "changed_by", None) else None + ), + changed_on=getattr(database, "changed_on", None), + changed_on_humanized=_humanize_timestamp(getattr(database, "changed_on", None)), + created_by=getattr(database, "created_by_name", None) + or ( + str(database.created_by) if getattr(database, "created_by", None) else None + ), + created_on=getattr(database, "created_on", None), + created_on_humanized=_humanize_timestamp(getattr(database, "created_on", None)), + ) diff --git 
a/superset/mcp_service/database/tool/__init__.py b/superset/mcp_service/database/tool/__init__.py new file mode 100644 index 000000000000..bb87106908e8 --- /dev/null +++ b/superset/mcp_service/database/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_database_info import get_database_info +from .list_databases import list_databases + +__all__ = [ + "list_databases", + "get_database_info", +] diff --git a/superset/mcp_service/database/tool/get_database_info.py b/superset/mcp_service/database/tool/get_database_info.py new file mode 100644 index 000000000000..56aef1d2bb41 --- /dev/null +++ b/superset/mcp_service/database/tool/get_database_info.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get database info FastMCP tool + +This module contains the FastMCP tool for getting detailed information +about a specific database connection. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.database.schemas import ( + DatabaseError, + DatabaseInfo, + GetDatabaseInfoRequest, + serialize_database_object, +) +from superset.mcp_service.mcp_core import ModelGetInfoCore + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Database", + annotations=ToolAnnotations( + title="Get database info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_database_info( + request: GetDatabaseInfoRequest, ctx: Context +) -> DatabaseInfo | DatabaseError: + """Get database connection metadata by ID or UUID. + + Returns database configuration including backend type and capabilities + (allow_ctas, allow_dml, expose_in_sqllab, etc.). 
+ + IMPORTANT FOR LLM CLIENTS: + - Use numeric ID (e.g., 123) or UUID string (e.g., "a1b2c3d4-...") + - To find a database ID, use the list_databases tool first + + Example usage: + ```json + { + "identifier": 1 + } + ``` + + Or with UUID: + ```json + { + "identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab" + } + ``` + """ + await ctx.info( + "Retrieving database information: identifier=%s" % (request.identifier,) + ) + await ctx.debug( + "Metadata cache settings: use_cache=%s refresh_metadata=%s force_refresh=%s" + % ( + request.use_cache, + request.refresh_metadata, + request.force_refresh, + ) + ) + + try: + from superset.daos.database import DatabaseDAO + + with event_logger.log_context(action="mcp.get_database_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=DatabaseDAO, + output_schema=DatabaseInfo, + error_schema=DatabaseError, + serializer=serialize_database_object, + supports_slug=False, + logger=logger, + ) + + result = get_tool.run_tool(request.identifier) + + if isinstance(result, DatabaseInfo): + await ctx.info( + "Database information retrieved successfully: " + "database_id=%s, database_name=%s, backend=%s" + % ( + result.id, + result.database_name, + result.backend, + ) + ) + else: + await ctx.warning( + "Database retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Database information retrieval failed: identifier=%s, error=%s, " + "error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return DatabaseError( + error=f"Failed to get database info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/database/tool/list_databases.py b/superset/mcp_service/database/tool/list_databases.py new file mode 100644 index 000000000000..8db4e0fbb353 --- /dev/null +++ b/superset/mcp_service/database/tool/list_databases.py @@ -0,0 +1,166 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List databases FastMCP tool (Advanced with metadata cache control) + +This module contains the FastMCP tool for listing databases using +advanced filtering with clear, unambiguous request schema and metadata cache control. +""" + +import logging +from typing import TYPE_CHECKING + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +if TYPE_CHECKING: + from superset.models.core import Database + +from superset.extensions import event_logger +from superset.mcp_service.database.schemas import ( + DatabaseFilter, + DatabaseInfo, + DatabaseList, + ListDatabasesRequest, + serialize_database_object, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + + +@tool( + tags=["core"], + class_permission_name="Database", + annotations=ToolAnnotations( + title="List databases", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_databases(request: ListDatabasesRequest, ctx: Context) -> DatabaseList: + """List database connections with filtering and search. + + Returns database metadata including name, backend type, and permissions. 
+ + Sortable columns for order_column: id, database_name, changed_on, + created_on + """ + await ctx.info( + "Listing databases: page=%s, page_size=%s, search=%s" + % ( + request.page, + request.page_size, + request.search, + ) + ) + await ctx.debug( + "Database listing parameters: filters=%s, order_column=%s, " + "order_direction=%s, select_columns=%s" + % ( + request.filters, + request.order_column, + request.order_direction, + request.select_columns, + ) + ) + await ctx.debug( + "Metadata cache settings: use_cache=%s, refresh_metadata=%s, force_refresh=%s" + % ( + request.use_cache, + request.refresh_metadata, + request.force_refresh, + ) + ) + + try: + from superset.daos.database import DatabaseDAO + from superset.mcp_service.common.schema_discovery import ( + DATABASE_DEFAULT_COLUMNS, + DATABASE_SORTABLE_COLUMNS, + get_all_column_names, + get_database_columns, + ) + + # Get all column names dynamically from the model + all_columns = get_all_column_names(get_database_columns()) + + def _serialize_database( + obj: "Database | None", cols: list[str] | None + ) -> DatabaseInfo | None: + """Serialize database (filtering via model_serializer).""" + return serialize_database_object(obj) + + # Create tool with standard serialization + list_tool = ModelListCore( + dao_class=DatabaseDAO, + output_schema=DatabaseInfo, + item_serializer=_serialize_database, + filter_type=DatabaseFilter, + default_columns=DATABASE_DEFAULT_COLUMNS, + search_columns=["database_name"], + list_field_name="databases", + output_list_schema=DatabaseList, + all_columns=all_columns, + sortable_columns=DATABASE_SORTABLE_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_databases.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await 
ctx.info( + "Databases listed successfully: count=%s, total_count=%s, total_pages=%s" + % ( + len(result.databases) if hasattr(result, "databases") else 0, + getattr(result, "total_count", None), + getattr(result, "total_pages", None), + ) + ) + + # Apply field filtering via serialization context + columns_to_filter = result.columns_requested + await ctx.debug( + "Applying field filtering via serialization context: columns=%s" + % (columns_to_filter,) + ) + with event_logger.log_context(action="mcp.list_databases.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "Database listing failed: page=%s, page_size=%s, error=%s, error_type=%s" + % ( + request.page, + request.page_size, + str(e), + type(e).__name__, + ) + ) + raise diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index 051aa20809a7..0609ff27ce77 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -25,6 +25,7 @@ from pydantic import BaseModel from superset.daos.base import BaseDAO +from superset.mcp_service.constants import ModelType from superset.mcp_service.utils import _is_uuid # Type variables for generic model tools @@ -521,7 +522,7 @@ class ModelGetSchemaCore(BaseCore, Generic[S]): def __init__( self, - model_type: Literal["chart", "dataset", "dashboard"], + model_type: ModelType, dao_class: Type[BaseDAO[Any]], output_schema: Type[S], select_columns: List[Any], @@ -530,13 +531,14 @@ def __init__( search_columns: List[str], default_sort: str = "changed_on", default_sort_direction: Literal["asc", "desc"] = "desc", + exclude_filter_columns: set[str] | None = None, logger: logging.Logger | None = None, ) -> None: """ Initialize the schema discovery core. 
Args: - model_type: The type of model (chart, dataset, dashboard) + model_type: The type of model (chart, dataset, dashboard, database) dao_class: The DAO class to query for filter columns output_schema: Pydantic schema for the response (e.g., ModelSchemaInfo) select_columns: Column metadata (List[ColumnMetadata] or similar) @@ -545,6 +547,8 @@ def __init__( search_columns: Column names used for text search default_sort: Default sort column default_sort_direction: Default sort direction + exclude_filter_columns: Column names to omit from filter discovery + (e.g., sensitive fields like passwords or connection URIs) logger: Optional logger instance """ super().__init__(logger) @@ -557,6 +561,7 @@ def __init__( self.search_columns = search_columns self.default_sort = default_sort self.default_sort_direction = default_sort_direction + self.exclude_filter_columns = exclude_filter_columns or set() def _get_filter_columns(self) -> Dict[str, List[str]]: """Get filterable columns and operators from the DAO.""" @@ -567,16 +572,25 @@ def _get_filter_columns(self) -> Dict[str, List[str]]: return {} # Convert to dict safely - handle both dict and dict-like objects if isinstance(filterable, dict): - return dict(filterable) - # Try to convert mapping-like objects - try: - return dict(filterable) - except (TypeError, ValueError): - self._log_warning( - f"Unexpected filter columns type for {self.model_type}: " - f"{type(filterable)}" - ) - return {} + result = dict(filterable) + else: + # Try to convert mapping-like objects + try: + result = dict(filterable) + except (TypeError, ValueError): + self._log_warning( + f"Unexpected filter columns type for {self.model_type}: " + f"{type(filterable)}" + ) + return {} + # Remove excluded columns (e.g., sensitive fields) + if self.exclude_filter_columns: + result = { + k: v + for k, v in result.items() + if k not in self.exclude_filter_columns + } + return result except Exception as e: self._log_warning( f"Failed to get filter columns for 
{self.model_type}: {e}" diff --git a/superset/mcp_service/system/tool/get_schema.py b/superset/mcp_service/system/tool/get_schema.py index 57b8909c3014..a0bc8213a4ff 100644 --- a/superset/mcp_service/system/tool/get_schema.py +++ b/superset/mcp_service/system/tool/get_schema.py @@ -24,7 +24,7 @@ """ import logging -from typing import Callable, Literal +from typing import Callable from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations @@ -37,16 +37,21 @@ DASHBOARD_DEFAULT_COLUMNS, DASHBOARD_SEARCH_COLUMNS, DASHBOARD_SORTABLE_COLUMNS, + DATABASE_DEFAULT_COLUMNS, + DATABASE_SEARCH_COLUMNS, + DATABASE_SORTABLE_COLUMNS, DATASET_DEFAULT_COLUMNS, DATASET_SEARCH_COLUMNS, DATASET_SORTABLE_COLUMNS, get_chart_columns, get_dashboard_columns, + get_database_columns, get_dataset_columns, GetSchemaRequest, GetSchemaResponse, ModelSchemaInfo, ) +from superset.mcp_service.constants import ModelType from superset.mcp_service.mcp_core import ModelGetSchemaCore logger = logging.getLogger(__name__) @@ -109,14 +114,36 @@ def _get_dashboard_schema_core() -> ModelGetSchemaCore[ModelSchemaInfo]: ) +def _get_database_schema_core() -> ModelGetSchemaCore[ModelSchemaInfo]: + """Create database schema core with dynamically extracted columns.""" + # Lazy import to avoid circular dependency at module load time + from superset.daos.database import DatabaseDAO + from superset.mcp_service.common.schema_discovery import DATABASE_EXCLUDE_COLUMNS + + return ModelGetSchemaCore( + model_type="database", + dao_class=DatabaseDAO, + output_schema=ModelSchemaInfo, + select_columns=get_database_columns(), + sortable_columns=DATABASE_SORTABLE_COLUMNS, + default_columns=DATABASE_DEFAULT_COLUMNS, + search_columns=DATABASE_SEARCH_COLUMNS, + default_sort="changed_on", + default_sort_direction="desc", + exclude_filter_columns=DATABASE_EXCLUDE_COLUMNS, + logger=logger, + ) + + # Map model types to their core factory functions _SCHEMA_CORE_FACTORIES: dict[ - Literal["chart", 
"dataset", "dashboard"], + ModelType, Callable[[], ModelGetSchemaCore[ModelSchemaInfo]], ] = { "chart": _get_chart_schema_core, "dataset": _get_dataset_schema_core, "dashboard": _get_dashboard_schema_core, + "database": _get_database_schema_core, } @@ -143,7 +170,7 @@ async def get_schema(request: GetSchemaRequest, ctx: Context) -> GetSchemaRespon Column metadata is extracted dynamically from SQLAlchemy models. Args: - model_type: One of "chart", "dataset", or "dashboard" + model_type: One of "chart", "dataset", "dashboard", or "database" Returns: Comprehensive schema information for the requested model type diff --git a/tests/unit_tests/mcp_service/database/__init__.py b/tests/unit_tests/mcp_service/database/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/database/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/unit_tests/mcp_service/database/tool/__init__.py b/tests/unit_tests/mcp_service/database/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/database/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/database/tool/test_database_tools.py b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py new file mode 100644 index 000000000000..78ae81c9962e --- /dev/null +++ b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp +from superset.mcp_service.database.schemas import ListDatabasesRequest +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def create_mock_database( + database_id: int = 1, + database_name: str = "examples", + backend: str = "postgresql", + expose_in_sqllab: bool = True, + allow_ctas: bool = False, + allow_cvas: bool = False, + allow_dml: bool = False, + allow_file_upload: bool = False, + allow_run_async: bool = False, +) -> MagicMock: + """Factory function to create mock database objects with sensible defaults.""" + database = MagicMock() + database.id = database_id + database.database_name = database_name + database.backend = backend + database.verbose_name = None + database.expose_in_sqllab = expose_in_sqllab + database.allow_ctas = allow_ctas + database.allow_cvas = allow_cvas + database.allow_dml = allow_dml + database.allow_file_upload = allow_file_upload + database.allow_run_async = allow_run_async + database.cache_timeout = None + database.configuration_method = "sqlalchemy_form" + database.force_ctas_schema = None + database.impersonate_user = False + database.is_managed_externally = False + database.external_url = None + database.extra = '{"metadata_params": {}, "engine_params": {}}' + database.uuid = f"test-database-uuid-{database_id}" + database.changed_by_name = "admin" + 
database.changed_by = None + database.changed_on = None + database.created_by_name = "admin" + database.created_by = None + database.created_on = None + database.owners = [] + return database + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + from unittest.mock import Mock, patch + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +@patch("superset.daos.database.DatabaseDAO.list") +@pytest.mark.asyncio +async def test_list_databases_basic(mock_list, mcp_server): + """Test basic database listing functionality.""" + database = create_mock_database() + database._mapping = { + "id": database.id, + "database_name": database.database_name, + "backend": database.backend, + "expose_in_sqllab": database.expose_in_sqllab, + } + mock_list.return_value = ([database], 1) + async with Client(mcp_server) as client: + request = ListDatabasesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_databases", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["databases"] is not None + assert len(data["databases"]) == 1 + assert data["databases"][0]["id"] == 1 + assert data["databases"][0]["database_name"] == "examples" + + +@patch("superset.daos.database.DatabaseDAO.list") +@pytest.mark.asyncio +async def test_list_databases_with_search(mock_list, mcp_server): + """Test database listing with search functionality.""" + database = create_mock_database(database_name="production_db") + database._mapping = { + "id": database.id, + "database_name": database.database_name, + } + mock_list.return_value = ([database], 1) + async with Client(mcp_server) as client: + request = ListDatabasesRequest(page=1, page_size=10, 
search="production") + result = await client.call_tool( + "list_databases", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["databases"] is not None + assert len(data["databases"]) == 1 + assert data["databases"][0]["database_name"] == "production_db" + + +@patch("superset.daos.database.DatabaseDAO.list") +@pytest.mark.asyncio +async def test_list_databases_with_filters(mock_list, mcp_server): + """Test database listing with filters.""" + database = create_mock_database(expose_in_sqllab=True) + database._mapping = { + "id": database.id, + "database_name": database.database_name, + "expose_in_sqllab": database.expose_in_sqllab, + } + mock_list.return_value = ([database], 1) + async with Client(mcp_server) as client: + request = ListDatabasesRequest( + page=1, + page_size=10, + filters=[ + {"col": "expose_in_sqllab", "opr": "eq", "value": True}, + ], + ) + result = await client.call_tool( + "list_databases", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["databases"] is not None + assert len(data["databases"]) == 1 + + +@patch("superset.daos.database.DatabaseDAO.list") +@pytest.mark.asyncio +async def test_list_databases_api_error(mock_list, mcp_server): + """Test error handling when DAO raises an exception.""" + mock_list.side_effect = ToolError("Database error") + async with Client(mcp_server) as client: + request = ListDatabasesRequest(page=1, page_size=10) + with pytest.raises(ToolError) as excinfo: # noqa: PT012 + await client.call_tool("list_databases", {"request": request.model_dump()}) + assert "Database error" in str(excinfo.value) + + +@patch("superset.daos.database.DatabaseDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_database_info_basic(mock_find, mcp_server): + """Test basic get database info functionality.""" + database = create_mock_database() + mock_find.return_value = 
database + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_database_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["database_name"] == "examples" + assert data["backend"] == "postgresql" + + +@patch("superset.daos.database.DatabaseDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_database_info_not_found(mock_find, mcp_server): + """Test get database info when database does not exist.""" + mock_find.return_value = None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_database_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found"