From 81c76723e02ecdeb6ebaa6da553eb69fdf408283 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Thu, 12 Mar 2026 10:41:39 -0700 Subject: [PATCH 01/11] initial commit --- marimo/_dependencies/dependencies.py | 1 + marimo/_sql/engines/starrocks.py | 346 ++++++++++++++++++++++ marimo/_sql/get_engines.py | 2 + marimo/_sql/sql_quoting.py | 2 +- pyproject.toml | 2 + tests/_sql/test_starrocks.py | 421 +++++++++++++++++++++++++++ 6 files changed, 773 insertions(+), 1 deletion(-) create mode 100644 marimo/_sql/engines/starrocks.py create mode 100644 tests/_sql/test_starrocks.py diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index ba143557398..9a1a4da9aca 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -247,6 +247,7 @@ class DependencyManager: boto3 = Dependency("boto3") litellm = Dependency("litellm") redshift_connector = Dependency("redshift_connector") + starrocks = Dependency("starrocks") mcp = Dependency("mcp") pydantic_ai = Dependency( "pydantic_ai", pkg_name_to_install="pydantic-ai-slim" diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py new file mode 100644 index 00000000000..420f553fb86 --- /dev/null +++ b/marimo/_sql/engines/starrocks.py @@ -0,0 +1,346 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from marimo import _loggers +from marimo._data.models import Database, DataTable, DataTableColumn, Schema +from marimo._dependencies.dependencies import DependencyManager +from marimo._sql.engines.types import InferenceConfig, SQLConnection +from marimo._sql.sql_quoting import quote_sql_identifier +from marimo._sql.utils import convert_to_output, sql_type_to_data_type +from marimo._types.ids import VariableName + +LOGGER = _loggers.marimo_logger() + +if TYPE_CHECKING: + import pandas as pd + import polars as pl + from sqlalchemy import Engine + +# StarRocks databases (marimo Schemas) that are internal and not useful to surface. +_SYSTEM_SCHEMAS = frozenset({"information_schema", "sys", "_statistics_"}) + + +def _quote(name: str) -> str: + return quote_sql_identifier(name, dialect="starrocks") + + +class StarRocksEngine(SQLConnection["Engine"]): + """StarRocks SQL engine with multi-catalog support. + + StarRocks uses a three-level hierarchy: Catalog → Database → Table. + This maps to marimo's Database → Schema → Table model: + + - marimo ``Database`` ↔ StarRocks Catalog + - marimo ``Schema`` ↔ StarRocks Database + - marimo ``DataTable`` ↔ StarRocks Table + """ + + def __init__( + self, connection: Engine, engine_name: Optional[VariableName] = None + ) -> None: + super().__init__(connection, engine_name) + + @property + def source(self) -> str: + return "starrocks" + + @property + def dialect(self) -> str: + return "starrocks" + + @staticmethod + def is_compatible(var: Any) -> bool: + if not DependencyManager.sqlalchemy.imported(): + return False + if not DependencyManager.starrocks.imported(): + return False + + from sqlalchemy.engine import Engine + + return isinstance(var, Engine) and str(var.dialect.name) == "starrocks" + + @property + def inference_config(self) -> InferenceConfig: + return InferenceConfig( + auto_discover_schemas=True, + auto_discover_tables="auto", + auto_discover_columns=False, + ) + + def execute(self, query: str) -> Any: + from sqlalchemy import text + + sql_output_format = self.sql_output_format() + + with self._connection.connect() as conn: + result = conn.execute(text(query)) + if sql_output_format == "native": + return result + + rows = result.fetchall() if result.returns_rows else None + + try: + conn.commit() + except Exception: + LOGGER.info("Unable to commit transaction", exc_info=True) + + if rows is None: + return None + + def convert_to_polars() -> pl.DataFrame: + import polars as pl + + return pl.DataFrame(rows) + + def convert_to_pandas() -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame(rows) + + return convert_to_output( + sql_output_format=sql_output_format, + to_polars=convert_to_polars, + to_pandas=convert_to_pandas, + ) + + def get_default_database(self) -> Optional[str]: + """Return the name of the current catalog.""" + try: + from sqlalchemy import text + + with self._connection.connect() as conn: + row = conn.execute( + text("SELECT CATALOG()") + ).fetchone() + if row is not None and row[0] is not None: + return str(row[0]) + except Exception: + LOGGER.warning("Failed to get current catalog", exc_info=True) + return None + + def get_default_schema(self) -> Optional[str]: + """Return the name of the current database within the current catalog.""" + try: + from sqlalchemy import text + + with self._connection.connect() as conn: + row = conn.execute(text("SELECT DATABASE()")).fetchone() + if row is not None and row[0] is not None: + return str(row[0]) + except Exception: + LOGGER.warning("Failed to get current database", exc_info=True) + return None + + def get_databases( + self, + *, + include_schemas: Union[bool, Literal["auto"]], + include_tables: Union[bool, Literal["auto"]], + include_table_details: Union[bool, Literal["auto"]], + ) -> list[Database]: + """Return all catalogs, each containing its databases as schemas. + + Args: + include_schemas: Whether to enumerate databases within each + catalog. ``"auto"`` resolves to ``True``. + include_tables: Whether to enumerate tables within each database. + ``"auto"`` resolves to ``False`` (StarRocks catalogs can be + very large, so table discovery is opt-in). + include_table_details: Whether to fetch column-level metadata for + each table. ``"auto"`` resolves to ``False``. + """ + should_include_schemas = self._resolve_auto(include_schemas, default=True) + should_include_tables = self._resolve_auto(include_tables, default=False) + should_include_details = self._resolve_auto( + include_table_details, default=False + ) + + databases: list[Database] = [] + for catalog in self._list_catalogs(): + schemas: list[Schema] = [] + if should_include_schemas: + for db_name in self._list_databases_in_catalog(catalog): + tables: list[DataTable] = [] + if should_include_tables: + tables = self.get_tables_in_schema( + schema=db_name, + database=catalog, + include_table_details=should_include_details, + ) + schemas.append(Schema(name=db_name, tables=tables)) + databases.append( + Database( + name=catalog, + dialect=self.dialect, + schemas=schemas, + engine=self._engine_name, + ) + ) + return databases + + def get_tables_in_schema( + self, *, schema: str, database: str, include_table_details: bool + ) -> list[DataTable]: + """Return all tables in a StarRocks database. + + Args: + schema: The StarRocks database name. + database: The StarRocks catalog name. + include_table_details: Whether to fetch column metadata. + """ + try: + from sqlalchemy import text + + query = ( + f"SELECT TABLE_NAME, TABLE_TYPE " + f"FROM {_quote(database)}.information_schema.tables " + f"WHERE TABLE_SCHEMA = :schema" + ) + with self._connection.connect() as conn: + rows = conn.execute(text(query), {"schema": schema}).fetchall() + except Exception: + LOGGER.warning( + "Failed to get tables in %r.%r", + database, + schema, + exc_info=True, + ) + return [] + + tables: list[DataTable] = [] + for row in rows: + table_name = str(row[0]) + raw_type = str(row[1]).upper() if row[1] else "BASE TABLE" + table_type: Literal["table", "view"] = ( + "view" if "VIEW" in raw_type else "table" + ) + + if not include_table_details: + tables.append( + DataTable( + source_type="connection", + source=self.dialect, + name=table_name, + num_rows=None, + num_columns=None, + variable_name=None, + engine=self._engine_name, + type=table_type, + columns=[], + primary_keys=[], + indexes=[], + ) + ) + else: + table = self.get_table_details( + table_name=table_name, + schema_name=schema, + database_name=database, + ) + if table is not None: + table.type = table_type + tables.append(table) + + return tables + + def get_table_details( + self, *, table_name: str, schema_name: str, database_name: str + ) -> Optional[DataTable]: + """Fetch column-level metadata for a table. + + Args: + table_name: The table name. + schema_name: The StarRocks database name. + database_name: The StarRocks catalog name. + """ + try: + from sqlalchemy import text + + query = ( + f"SELECT COLUMN_NAME, DATA_TYPE " + f"FROM {_quote(database_name)}.information_schema.columns " + f"WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table " + f"ORDER BY ORDINAL_POSITION" + ) + with self._connection.connect() as conn: + rows = conn.execute( + text(query), + {"schema": schema_name, "table": table_name}, + ).fetchall() + except Exception: + LOGGER.warning( + "Failed to get details for %r.%r.%r", + database_name, + schema_name, + table_name, + exc_info=True, + ) + return None + + columns = [ + DataTableColumn( + name=str(row[0]), + type=sql_type_to_data_type(str(row[1])), + external_type=str(row[1]), + sample_values=[], + ) + for row in rows + ] + + return DataTable( + source_type="connection", + source=self.dialect, + name=table_name, + num_rows=None, + num_columns=len(columns), + variable_name=None, + engine=self._engine_name, + columns=columns, + primary_keys=[], + indexes=[], + ) + + def _list_catalogs(self) -> list[str]: + """Return all catalog names, excluding built-in system catalogs.""" + try: + from sqlalchemy import text + + with self._connection.connect() as conn: + rows = conn.execute(text("SHOW CATALOGS")).fetchall() + return [str(row[0]) for row in rows] + except Exception: + LOGGER.warning("Failed to list catalogs", exc_info=True) + return [] + + def _list_databases_in_catalog(self, catalog: str) -> list[str]: + """Return all database names within *catalog*, excluding system databases.""" + try: + from sqlalchemy import text + + with self._connection.connect() as conn: + rows = conn.execute( + text(f"SHOW DATABASES IN {_quote(catalog)}") + ).fetchall() + return [ + str(row[0]) + for row in rows + if str(row[0]).lower() not in _SYSTEM_SCHEMAS + ] + except Exception: + LOGGER.warning( + "Failed to list databases in catalog %r", + catalog, + exc_info=True, + ) + return [] + + @staticmethod + def _resolve_auto( + value: Union[bool, Literal["auto"]], *, default: bool + ) -> bool: + """Resolve an ``"auto"`` inference flag to a concrete boolean.""" + if value == "auto": + return default + return value diff --git a/marimo/_sql/get_engines.py b/marimo/_sql/get_engines.py index 7ddcdb23e90..8f263670849 100644 --- a/marimo/_sql/get_engines.py +++ b/marimo/_sql/get_engines.py @@ -22,6 +22,7 @@ from marimo._sql.engines.pyiceberg import PyIcebergEngine from marimo._sql.engines.redshift import RedshiftEngine from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine +from marimo._sql.engines.starrocks import StarRocksEngine from marimo._sql.engines.types import ( BaseEngine, EngineCatalog, @@ -33,6 +34,7 @@ # TODO: this is O(n) and can be O(1) using similar logic to the # formatters, but order does matter here SUPPORTED_ENGINES: list[type[BaseEngine[Any]]] = [ + StarRocksEngine, SQLAlchemyEngine, IbisEngine, DuckDBEngine, diff --git a/marimo/_sql/sql_quoting.py b/marimo/_sql/sql_quoting.py index 1d20977b5db..d392f09cc03 100644 --- a/marimo/_sql/sql_quoting.py +++ b/marimo/_sql/sql_quoting.py @@ -22,7 +22,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str: # Double-quote style: escape embedded " as "" escaped = identifier.replace('"', '""') return f'"{escaped}"' - elif dialect in ("clickhouse", "mysql", "bigquery"): + elif dialect in ("clickhouse", "mysql", "bigquery", "starrocks"): # Backtick style: escape embedded ` as `` escaped = identifier.replace("`", "``") return f"`{escaped}`" diff --git a/pyproject.toml b/pyproject.toml index 179b18f9a44..25c7c17eeab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,6 +178,7 @@ test-optional = [ "chdb>=3; platform_system != 'Windows'", # there is no suitable wheel for windows "clickhouse-connect>=0.8.18", "redshift-connector[full]>=2.1.7", + "starrocks>=1.3.0", "pandas>=1.5.3", "hvplot~=0.11.3", "geopandas>=1.1.0", @@ -499,6 +500,7 @@ banned-module-level-imports = [ "typing_extensions", "pyiceberg", "redshift_connector", + "starrocks", "pydantic_ai" ] diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py new file mode 100644 index 00000000000..5ea60b5046c --- /dev/null +++ b/tests/_sql/test_starrocks.py @@ -0,0 +1,421 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +import sys +from typing import Any +from unittest.mock import MagicMock, patch # noqa: F401 + +import pytest + +from marimo._data.models import Database, DataTable, Schema +from marimo._sql.engines.starrocks import ( + _SYSTEM_SCHEMAS, + StarRocksEngine, +) +from marimo._sql.sql_quoting import quote_sql_identifier + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _mock_sqlalchemy_if_missing(): + """Patch sys.modules with a lightweight sqlalchemy stub when the real + package is not installed, so mock-based tests can still run without it.""" + if "sqlalchemy" in sys.modules: + yield + return + + mock_sa = MagicMock() + # `text()` is used as a pass-through wrapper; the result is fed into the + # mocked conn.execute, so its exact return value doesn't matter. + mock_sa.text = MagicMock(side_effect=lambda q: q) + with patch.dict(sys.modules, {"sqlalchemy": mock_sa}): + yield + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_engine(dialect_name: str = "starrocks") -> MagicMock: + """Return a mock SQLAlchemy Engine with the given dialect name.""" + mock_engine = MagicMock() + mock_engine.dialect.name = dialect_name + return mock_engine + + +def _make_engine(dialect_name: str = "starrocks") -> StarRocksEngine: + return StarRocksEngine(_make_mock_engine(dialect_name), engine_name="sr") + + +def _mock_connection_ctx(engine: StarRocksEngine, side_effects: list[Any]): + """Patch _connection.connect() so that successive execute() calls return + the given side_effects in order (each item is the rows list for one call). + """ + conn_ctx = MagicMock() + conn = MagicMock() + conn_ctx.__enter__ = MagicMock(return_value=conn) + conn_ctx.__exit__ = MagicMock(return_value=False) + engine._connection.connect = MagicMock(return_value=conn_ctx) + + results = [] + for rows in side_effects: + result = MagicMock() + result.fetchone = MagicMock(return_value=rows[0] if rows else None) + result.fetchall = MagicMock(return_value=rows) + results.append(result) + + conn.execute = MagicMock(side_effect=results) + return conn + + +# --------------------------------------------------------------------------- +# is_compatible +# --------------------------------------------------------------------------- + + +class TestIsCompatible: + @pytest.mark.requires("sqlalchemy", "starrocks") + def test_compatible_with_starrocks_dialect(self) -> None: + import sqlalchemy as sa + + mock_engine = MagicMock(spec=sa.Engine) + mock_engine.dialect.name = "starrocks" + assert StarRocksEngine.is_compatible(mock_engine) + + @pytest.mark.requires("sqlalchemy", "starrocks") + def test_not_compatible_with_other_dialects(self) -> None: + import sqlalchemy as sa + + for dialect in ("mysql", "postgresql", "sqlite", "clickhouse"): + mock_engine = MagicMock(spec=sa.Engine) + mock_engine.dialect.name = dialect + assert not StarRocksEngine.is_compatible(mock_engine) + + @pytest.mark.requires("sqlalchemy", "starrocks") + def test_not_compatible_with_non_engine(self) -> None: + assert not StarRocksEngine.is_compatible("not_an_engine") + assert not StarRocksEngine.is_compatible(42) + assert not StarRocksEngine.is_compatible(None) + + +# --------------------------------------------------------------------------- +# source / dialect +# --------------------------------------------------------------------------- + + +class TestSourceAndDialect: + def test_source(self) -> None: + engine = _make_engine() + assert engine.source == "starrocks" + + def test_dialect(self) -> None: + engine = _make_engine() + assert engine.dialect == "starrocks" + + +# --------------------------------------------------------------------------- +# get_default_database / get_default_schema +# --------------------------------------------------------------------------- + + +class TestDefaults: + def test_get_default_database(self) -> None: + engine = _make_engine() + _mock_connection_ctx(engine, [[("default_catalog",)]]) + assert engine.get_default_database() == "default_catalog" + # Verify it uses CATALOG() not CURRENT_CATALOG() + conn = engine._connection.connect().__enter__() + call_args = conn.execute.call_args_list + assert any("CATALOG()" in str(c) for c in call_args) + + def test_get_default_database_none_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("connection failed") + assert engine.get_default_database() is None + + def test_get_default_schema(self) -> None: + engine = _make_engine() + _mock_connection_ctx(engine, [[("my_db",)]]) + assert engine.get_default_schema() == "my_db" + + def test_get_default_schema_none_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("connection failed") + assert engine.get_default_schema() is None + + +# --------------------------------------------------------------------------- +# _list_catalogs / _list_databases_in_catalog +# --------------------------------------------------------------------------- + + +class TestListCatalogs: + def test_lists_all_catalogs(self) -> None: + engine = _make_engine() + rows = [ + ("default_catalog",), + ("hive_catalog",), + ("iceberg_catalog",), + ] + _mock_connection_ctx(engine, [rows]) + result = engine._list_catalogs() + assert result == ["default_catalog", "hive_catalog", "iceberg_catalog"] + + def test_returns_empty_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("oops") + assert engine._list_catalogs() == [] + + +class TestListDatabases: + def test_lists_databases_excluding_system(self) -> None: + engine = _make_engine() + rows = [ + ("tpch",), + ("analytics",), + ("information_schema",), # excluded + ("sys",), # excluded + ("_statistics_",), # excluded + ] + _mock_connection_ctx(engine, [rows]) + result = engine._list_databases_in_catalog("default_catalog") + assert result == ["tpch", "analytics"] + + def test_returns_empty_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("oops") + assert engine._list_databases_in_catalog("default_catalog") == [] + + +# --------------------------------------------------------------------------- +# get_databases +# --------------------------------------------------------------------------- + + +class TestGetDatabases: + def test_returns_catalog_as_database(self) -> None: + engine = _make_engine() + # Call 1: SHOW CATALOGS + # Call 2: SHOW DATABASES IN `default_catalog` + catalogs_rows = [("default_catalog",), ("hive_catalog",)] + db_rows_default = [("tpch",), ("analytics",)] + db_rows_hive = [("lake",)] + + conn_ctx = MagicMock() + conn = MagicMock() + conn_ctx.__enter__ = MagicMock(return_value=conn) + conn_ctx.__exit__ = MagicMock(return_value=False) + engine._connection.connect = MagicMock(return_value=conn_ctx) + + results = [] + for rows in [catalogs_rows, db_rows_default, db_rows_hive]: + r = MagicMock() + r.fetchall = MagicMock(return_value=rows) + results.append(r) + conn.execute = MagicMock(side_effect=results) + + databases = engine.get_databases( + include_schemas=True, + include_tables=False, + include_table_details=False, + ) + + assert len(databases) == 2 + assert databases[0].name == "default_catalog" + assert databases[1].name == "hive_catalog" + assert [s.name for s in databases[0].schemas] == ["tpch", "analytics"] + assert [s.name for s in databases[1].schemas] == ["lake"] + for db in databases: + assert db.dialect == "starrocks" + assert db.engine == "sr" + + def test_no_schemas_when_include_schemas_false(self) -> None: + engine = _make_engine() + conn_ctx = MagicMock() + conn = MagicMock() + conn_ctx.__enter__ = MagicMock(return_value=conn) + conn_ctx.__exit__ = MagicMock(return_value=False) + engine._connection.connect = MagicMock(return_value=conn_ctx) + + catalogs_result = MagicMock() + catalogs_result.fetchall = MagicMock( + return_value=[("default_catalog",)] + ) + conn.execute = MagicMock(return_value=catalogs_result) + + databases = engine.get_databases( + include_schemas=False, + include_tables=False, + include_table_details=False, + ) + + assert len(databases) == 1 + assert databases[0].name == "default_catalog" + assert databases[0].schemas == [] + + def test_auto_includes_schemas_excludes_tables(self) -> None: + """'auto' should resolve to include_schemas=True, include_tables=False.""" + engine = _make_engine() + conn_ctx = MagicMock() + conn = MagicMock() + conn_ctx.__enter__ = MagicMock(return_value=conn) + conn_ctx.__exit__ = MagicMock(return_value=False) + engine._connection.connect = MagicMock(return_value=conn_ctx) + + # SHOW CATALOGS → 1 catalog; SHOW DATABASES → 1 db + r1 = MagicMock() + r1.fetchall = MagicMock(return_value=[("default_catalog",)]) + r2 = MagicMock() + r2.fetchall = MagicMock(return_value=[("tpch",)]) + conn.execute = MagicMock(side_effect=[r1, r2]) + + databases = engine.get_databases( + include_schemas="auto", + include_tables="auto", + include_table_details="auto", + ) + + assert len(databases) == 1 + assert databases[0].schemas[0].name == "tpch" + # Tables should NOT be fetched (auto → False for tables) + assert databases[0].schemas[0].tables == [] + + +# --------------------------------------------------------------------------- +# get_tables_in_schema +# --------------------------------------------------------------------------- + + +class TestGetTablesInSchema: + def test_returns_tables_and_views(self) -> None: + engine = _make_engine() + rows = [ + ("orders", "BASE TABLE"), + ("lineitem", "BASE TABLE"), + ("revenue_view", "VIEW"), + ] + _mock_connection_ctx(engine, [rows]) + + tables = engine.get_tables_in_schema( + schema="tpch", + database="default_catalog", + include_table_details=False, + ) + + assert len(tables) == 3 + names = [t.name for t in tables] + assert "orders" in names + assert "lineitem" in names + assert "revenue_view" in names + view = next(t for t in tables if t.name == "revenue_view") + assert view.type == "view" + base = next(t for t in tables if t.name == "orders") + assert base.type == "table" + # No columns without details + assert base.columns == [] + + def test_returns_empty_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("fail") + result = engine.get_tables_in_schema( + schema="tpch", database="default_catalog", include_table_details=False + ) + assert result == [] + + +# --------------------------------------------------------------------------- +# get_table_details +# --------------------------------------------------------------------------- + + +class TestGetTableDetails: + def test_returns_columns(self) -> None: + engine = _make_engine() + rows = [ + ("id", "INT"), + ("name", "VARCHAR"), + ("created_at", "DATETIME"), + ("score", "DOUBLE"), + ("is_active", "BOOLEAN"), + ] + _mock_connection_ctx(engine, [rows]) + + table = engine.get_table_details( + table_name="orders", + schema_name="tpch", + database_name="default_catalog", + ) + + assert table is not None + assert table.name == "orders" + assert table.num_columns == 5 + assert len(table.columns) == 5 + + types = {c.name: c.type for c in table.columns} + assert types["id"] == "integer" + assert types["name"] == "string" + assert types["created_at"] == "datetime" + assert types["score"] == "number" + assert types["is_active"] == "boolean" + + def test_returns_none_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("fail") + result = engine.get_table_details( + table_name="orders", + schema_name="tpch", + database_name="default_catalog", + ) + assert result is None + + +# --------------------------------------------------------------------------- +# SQL quoting integration +# --------------------------------------------------------------------------- + + +class TestStarRocksQuoting: + def test_starrocks_uses_backtick_style(self) -> None: + assert quote_sql_identifier("my_catalog", dialect="starrocks") == "`my_catalog`" + assert ( + quote_sql_identifier("catalog`with`ticks", dialect="starrocks") + == "`catalog``with``ticks`" + ) + assert ( + quote_sql_identifier("catalog with spaces", dialect="starrocks") + == "`catalog with spaces`" + ) + + +# --------------------------------------------------------------------------- +# _resolve_auto +# --------------------------------------------------------------------------- + + +class TestResolveAuto: + def test_true_stays_true(self) -> None: + assert StarRocksEngine._resolve_auto(True, default=False) is True + + def test_false_stays_false(self) -> None: + assert StarRocksEngine._resolve_auto(False, default=True) is False + + def test_auto_returns_default(self) -> None: + assert StarRocksEngine._resolve_auto("auto", default=True) is True + assert StarRocksEngine._resolve_auto("auto", default=False) is False + + +# --------------------------------------------------------------------------- +# System catalog / database constants +# --------------------------------------------------------------------------- + + +class TestSystemConstants: + def test_system_schemas_excluded(self) -> None: + assert "information_schema" in _SYSTEM_SCHEMAS + assert "sys" in _SYSTEM_SCHEMAS + assert "_statistics_" in _SYSTEM_SCHEMAS From f5da6fcfaa8e0776d31a06b8934dda2355606c3c Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Thu, 12 Mar 2026 16:31:38 -0700 Subject: [PATCH 02/11] refactored inheritance --- marimo/_sql/engines/starrocks.py | 353 ++++++++++++++++--------------- tests/_sql/test_starrocks.py | 47 ++-- 2 files changed, 216 insertions(+), 184 deletions(-) diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py index 420f553fb86..9301e03b281 100644 --- a/marimo/_sql/engines/starrocks.py +++ b/marimo/_sql/engines/starrocks.py @@ -6,16 +6,14 @@ from marimo import _loggers from marimo._data.models import Database, DataTable, DataTableColumn, Schema from marimo._dependencies.dependencies import DependencyManager -from marimo._sql.engines.types import InferenceConfig, SQLConnection +from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine from marimo._sql.sql_quoting import quote_sql_identifier -from marimo._sql.utils import convert_to_output, sql_type_to_data_type +from marimo._sql.utils import sql_type_to_data_type from marimo._types.ids import VariableName LOGGER = _loggers.marimo_logger() if TYPE_CHECKING: - import pandas as pd - import polars as pl from sqlalchemy import Engine # StarRocks databases (marimo Schemas) that are internal and not useful to surface. @@ -26,9 +24,13 @@ def _quote(name: str) -> str: return quote_sql_identifier(name, dialect="starrocks") -class StarRocksEngine(SQLConnection["Engine"]): +class StarRocksEngine(SQLAlchemyEngine): """StarRocks SQL engine with multi-catalog support. + Extends :class:`SQLAlchemyEngine`, inheriting the SQLAlchemy inspector + pattern for the connected (default) catalog. External catalogs fall back + to explicit SQL because the inspector is bound to a single catalog. + StarRocks uses a three-level hierarchy: Catalog → Database → Table. This maps to marimo's Database → Schema → Table model: @@ -37,19 +39,10 @@ class StarRocksEngine(SQLConnection["Engine"]): - marimo ``DataTable`` ↔ StarRocks Table """ - def __init__( - self, connection: Engine, engine_name: Optional[VariableName] = None - ) -> None: - super().__init__(connection, engine_name) - @property def source(self) -> str: return "starrocks" - @property - def dialect(self) -> str: - return "starrocks" - @staticmethod def is_compatible(var: Any) -> bool: if not DependencyManager.sqlalchemy.imported(): @@ -61,76 +54,25 @@ def is_compatible(var: Any) -> bool: return isinstance(var, Engine) and str(var.dialect.name) == "starrocks" - @property - def inference_config(self) -> InferenceConfig: - return InferenceConfig( - auto_discover_schemas=True, - auto_discover_tables="auto", - auto_discover_columns=False, - ) - - def execute(self, query: str) -> Any: - from sqlalchemy import text - - sql_output_format = self.sql_output_format() - - with self._connection.connect() as conn: - result = conn.execute(text(query)) - if sql_output_format == "native": - return result - - rows = result.fetchall() if result.returns_rows else None - - try: - conn.commit() - except Exception: - LOGGER.info("Unable to commit transaction", exc_info=True) - - if rows is None: - return None - - def convert_to_polars() -> pl.DataFrame: - import polars as pl - - return pl.DataFrame(rows) - - def convert_to_pandas() -> pd.DataFrame: - import pandas as pd - - return pd.DataFrame(rows) - - return convert_to_output( - sql_output_format=sql_output_format, - to_polars=convert_to_polars, - to_pandas=convert_to_pandas, - ) + # ------------------------------------------------------------------ + # Default catalog / schema + # ------------------------------------------------------------------ def get_default_database(self) -> Optional[str]: - """Return the name of the current catalog.""" - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - row = conn.execute( - text("SELECT CATALOG()") - ).fetchone() - if row is not None and row[0] is not None: - return str(row[0]) - except Exception: - LOGGER.warning("Failed to get current catalog", exc_info=True) - return None + """Return the current StarRocks catalog via ``SELECT CATALOG()``. - def get_default_schema(self) -> Optional[str]: - """Return the name of the current database within the current catalog.""" + Overrides the parent which reads from the SQLAlchemy connection URL, + because StarRocks exposes catalogs rather than a single database. + """ try: from sqlalchemy import text with self._connection.connect() as conn: - row = conn.execute(text("SELECT DATABASE()")).fetchone() + row = conn.execute(text("SELECT CATALOG()")).fetchone() if row is not None and row[0] is not None: return str(row[0]) except Exception: - LOGGER.warning("Failed to get current database", exc_info=True) + LOGGER.warning("Failed to get current catalog", exc_info=True) return None def get_databases( @@ -140,36 +82,44 @@ def get_databases( include_tables: Union[bool, Literal["auto"]], include_table_details: Union[bool, Literal["auto"]], ) -> list[Database]: - """Return all catalogs, each containing its databases as schemas. - - Args: - include_schemas: Whether to enumerate databases within each - catalog. ``"auto"`` resolves to ``True``. - include_tables: Whether to enumerate tables within each database. - ``"auto"`` resolves to ``False`` (StarRocks catalogs can be - very large, so table discovery is opt-in). - include_table_details: Whether to fetch column-level metadata for - each table. ``"auto"`` resolves to ``False``. + """Return all StarRocks catalogs, each containing its databases as schemas. + + Uses the inherited inspector path for the default catalog and explicit + SQL for external catalogs. + + ``"auto"`` resolution: + - ``include_schemas`` → ``True`` (always show databases) + - ``include_tables`` → ``False`` (StarRocks catalogs can be large) + - ``include_table_details``→ ``False`` """ - should_include_schemas = self._resolve_auto(include_schemas, default=True) - should_include_tables = self._resolve_auto(include_tables, default=False) - should_include_details = self._resolve_auto( - include_table_details, default=False + should_include_schemas = ( + include_schemas if isinstance(include_schemas, bool) else True + ) + should_include_tables = self._resolve_should_auto_discover(include_tables) + should_include_details = self._resolve_should_auto_discover( + include_table_details ) databases: list[Database] = [] for catalog in self._list_catalogs(): - schemas: list[Schema] = [] if should_include_schemas: - for db_name in self._list_databases_in_catalog(catalog): - tables: list[DataTable] = [] - if should_include_tables: - tables = self.get_tables_in_schema( - schema=db_name, - database=catalog, - include_table_details=should_include_details, - ) - schemas.append(Schema(name=db_name, tables=tables)) + if catalog == self.default_database: + # Inspector-based path (inherited from SQLAlchemyEngine). + schemas = self._get_schemas( + database=catalog, + include_tables=should_include_tables, + include_table_details=should_include_details, + ) + else: + # SQL fallback for external catalogs. + schemas = self._get_external_schemas( + catalog=catalog, + include_tables=should_include_tables, + include_table_details=should_include_details, + ) + else: + schemas = [] + databases.append( Database( name=catalog, @@ -183,23 +133,141 @@ def get_databases( def get_tables_in_schema( self, *, schema: str, database: str, include_table_details: bool ) -> list[DataTable]: - """Return all tables in a StarRocks database. + """Return tables for *schema* inside *database* (a StarRocks catalog). - Args: - schema: The StarRocks database name. - database: The StarRocks catalog name. - include_table_details: Whether to fetch column metadata. + Delegates to the inherited inspector path for the default catalog; + falls back to an ``information_schema`` query for external catalogs. """ + if database == self.default_database: + return super().get_tables_in_schema( + schema=schema, + database=database, + include_table_details=include_table_details, + ) + return self._get_external_tables( + schema=schema, + database=database, + include_table_details=include_table_details, + ) + + def get_table_details( + self, *, table_name: str, schema_name: str, database_name: str + ) -> Optional[DataTable]: + """Return column metadata for a table. + + Delegates to the inherited inspector path for the default catalog; + falls back to an ``information_schema`` query for external catalogs. + """ + if database_name == self.default_database: + return super().get_table_details( + table_name=table_name, + schema_name=schema_name, + database_name=database_name, + ) + return self._get_external_table_details( + table_name=table_name, + schema_name=schema_name, + database_name=database_name, + ) + + # ------------------------------------------------------------------ + # Meta-schema filter (overrides SQLAlchemyEngine._get_meta_schemas) + # ------------------------------------------------------------------ + + def _get_schemas( + self, + *, + database: Optional[str], + include_tables: bool, + include_table_details: bool, + ) -> list[Schema]: + """Filter system schemas out of the result entirely. + + The parent implementation keeps meta-schemas in the list but skips + table discovery for them. For StarRocks we want them hidden from the + sidebar completely. + """ + schemas = super()._get_schemas( + database=database, + include_tables=include_tables, + include_table_details=include_table_details, + ) + return [s for s in schemas if s.name.lower() not in _SYSTEM_SCHEMAS] + + def _get_meta_schemas(self) -> list[str]: + return list(_SYSTEM_SCHEMAS) + + # ------------------------------------------------------------------ + # StarRocks-specific helpers + # ------------------------------------------------------------------ + + def _list_catalogs(self) -> list[str]: + """Return all catalog names via ``SHOW CATALOGS``. + + There is no SQLAlchemy inspector equivalent for catalog enumeration. + """ + try: + from sqlalchemy import text + + with self._connection.connect() as conn: + rows = conn.execute(text("SHOW CATALOGS")).fetchall() + return [str(row[0]) for row in rows] + except Exception: + LOGGER.warning("Failed to list catalogs", exc_info=True) + return [] + + def _get_external_schemas( + self, + *, + catalog: str, + include_tables: bool, + include_table_details: bool, + ) -> list[Schema]: + """List databases in an external catalog via ``SHOW DATABASES``.""" try: from sqlalchemy import text - query = ( - f"SELECT TABLE_NAME, TABLE_TYPE " - f"FROM {_quote(database)}.information_schema.tables " - f"WHERE TABLE_SCHEMA = :schema" + with self._connection.connect() as conn: + rows = conn.execute( + text(f"SHOW DATABASES IN {_quote(catalog)}") + ).fetchall() + db_names = [ + str(row[0]) + for row in rows + if str(row[0]).lower() not in _SYSTEM_SCHEMAS + ] + except Exception: + LOGGER.warning( + "Failed to list databases in catalog %r", + catalog, + exc_info=True, ) + return [] + + schemas: list[Schema] = [] + for db_name in db_names: + tables: list[DataTable] = [] + if include_tables: + tables = self._get_external_tables( + schema=db_name, + database=catalog, + include_table_details=include_table_details, + ) + schemas.append(Schema(name=db_name, tables=tables)) + return schemas + + def _get_external_tables( + self, *, schema: str, database: str, include_table_details: bool + ) -> list[DataTable]: + """List tables in an external catalog via ``SHOW FULL TABLES``.""" + try: + from sqlalchemy import text + + qualified = f"{_quote(database)}.{_quote(schema)}" with self._connection.connect() as conn: - rows = conn.execute(text(query), {"schema": schema}).fetchall() + rows = conn.execute( + text(f"SHOW FULL TABLES FROM {qualified}") + ).fetchall() except Exception: LOGGER.warning( "Failed to get tables in %r.%r", @@ -234,7 +302,7 @@ def get_tables_in_schema( ) ) else: - table = self.get_table_details( + table = self._get_external_table_details( table_name=table_name, schema_name=schema, database_name=database, @@ -245,30 +313,20 @@ def get_tables_in_schema( return tables - def get_table_details( + def _get_external_table_details( self, *, table_name: str, schema_name: str, database_name: str ) -> Optional[DataTable]: - """Fetch column-level metadata for a table. - - Args: - table_name: The table name. - schema_name: The StarRocks database name. - database_name: The StarRocks catalog name. - """ + """Describe an external-catalog table via ``DESC ..``.""" try: from sqlalchemy import text - query = ( - f"SELECT COLUMN_NAME, DATA_TYPE " - f"FROM {_quote(database_name)}.information_schema.columns " - f"WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table " - f"ORDER BY ORDINAL_POSITION" + qualified = ( + f"{_quote(database_name)}" + f".{_quote(schema_name)}" + f".{_quote(table_name)}" ) with self._connection.connect() as conn: - rows = conn.execute( - text(query), - {"schema": schema_name, "table": table_name}, - ).fetchall() + rows = conn.execute(text(f"DESC {qualified}")).fetchall() except Exception: LOGGER.warning( "Failed to get details for %r.%r.%r", @@ -301,46 +359,3 @@ def get_table_details( primary_keys=[], indexes=[], ) - - def _list_catalogs(self) -> list[str]: - """Return all catalog names, excluding built-in system catalogs.""" - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - rows = conn.execute(text("SHOW CATALOGS")).fetchall() - return [str(row[0]) for row in rows] - except Exception: - LOGGER.warning("Failed to list catalogs", exc_info=True) - return [] - - def _list_databases_in_catalog(self, catalog: str) -> list[str]: - """Return all database names within *catalog*, excluding system databases.""" - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - rows = conn.execute( - text(f"SHOW DATABASES IN {_quote(catalog)}") - ).fetchall() - return [ - str(row[0]) - for row in rows - if str(row[0]).lower() not in _SYSTEM_SCHEMAS - ] - except Exception: - LOGGER.warning( - "Failed to list databases in catalog %r", - catalog, - exc_info=True, - ) - return [] - - @staticmethod - def _resolve_auto( - value: Union[bool, Literal["auto"]], *, default: bool - ) -> bool: - """Resolve an ``"auto"`` inference flag to a concrete boolean.""" - if value == "auto": - return default - return value diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index 5ea60b5046c..5c02c3dc2c1 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -139,12 +139,16 @@ def test_get_default_database_none_on_error(self) -> None: assert engine.get_default_database() is None def test_get_default_schema(self) -> None: + # get_default_schema() is inherited from SQLAlchemyEngine and tries + # inspector.default_schema_name first. engine = _make_engine() - _mock_connection_ctx(engine, [[("my_db",)]]) + engine.inspector = MagicMock() + engine.inspector.default_schema_name = "my_db" assert engine.get_default_schema() == "my_db" def test_get_default_schema_none_on_error(self) -> None: engine = _make_engine() + engine.inspector = None engine._connection.connect.side_effect = Exception("connection failed") assert engine.get_default_schema() is None @@ -172,7 +176,7 @@ def test_returns_empty_on_error(self) -> None: assert engine._list_catalogs() == [] -class TestListDatabases: +class TestExternalSchemas: def test_lists_databases_excluding_system(self) -> None: engine = _make_engine() rows = [ @@ -183,13 +187,21 @@ def test_lists_databases_excluding_system(self) -> None: ("_statistics_",), # excluded ] _mock_connection_ctx(engine, [rows]) - result = engine._list_databases_in_catalog("default_catalog") - assert result == ["tpch", "analytics"] + schemas = engine._get_external_schemas( + catalog="hive_catalog", + include_tables=False, + include_table_details=False, + ) + assert [s.name for s in schemas] == ["tpch", "analytics"] def test_returns_empty_on_error(self) -> None: engine = _make_engine() engine._connection.connect.side_effect = Exception("oops") - assert engine._list_databases_in_catalog("default_catalog") == [] + assert engine._get_external_schemas( + catalog="hive_catalog", + include_tables=False, + include_table_details=False, + ) == [] # --------------------------------------------------------------------------- @@ -294,6 +306,7 @@ def test_auto_includes_schemas_excludes_tables(self) -> None: class TestGetTablesInSchema: def test_returns_tables_and_views(self) -> None: engine = _make_engine() + # SHOW FULL TABLES returns (Tables_in_, Table_type) rows = [ ("orders", "BASE TABLE"), ("lineitem", "BASE TABLE"), @@ -336,12 +349,13 @@ def test_returns_empty_on_error(self) -> None: class TestGetTableDetails: def test_returns_columns(self) -> None: engine = _make_engine() + # DESC output: Field, Type, Null, Key, Default, Extra, Comment rows = [ - ("id", "INT"), - ("name", "VARCHAR"), - ("created_at", "DATETIME"), - ("score", "DOUBLE"), - ("is_active", "BOOLEAN"), + ("id", "INT", "YES", "", None, "", ""), + ("name", "VARCHAR(255)", "YES", "", None, "", ""), + ("created_at", "DATETIME", "YES", "", None, "", ""), + ("score", "DOUBLE", "YES", "", None, "", ""), + ("is_active", "BOOLEAN", "YES", "", None, "", ""), ] _mock_connection_ctx(engine, [rows]) @@ -399,14 +413,17 @@ def test_starrocks_uses_backtick_style(self) -> None: class TestResolveAuto: def test_true_stays_true(self) -> None: - assert StarRocksEngine._resolve_auto(True, default=False) is True + engine = _make_engine() + assert engine._resolve_should_auto_discover(True) is True def test_false_stays_false(self) -> None: - assert StarRocksEngine._resolve_auto(False, default=True) is False + engine = _make_engine() + assert engine._resolve_should_auto_discover(False) is False - def test_auto_returns_default(self) -> None: - assert StarRocksEngine._resolve_auto("auto", default=True) is True - assert StarRocksEngine._resolve_auto("auto", default=False) is False + def test_auto_resolves_to_false_for_starrocks(self) -> None: + # StarRocks is not in CHEAP_DISCOVERY_DATABASES, so "auto" → False. + engine = _make_engine() + assert engine._resolve_should_auto_discover("auto") is False # --------------------------------------------------------------------------- From 34f70e7445380f79a65231bcdf3d86b5b1941827 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Tue, 17 Mar 2026 13:57:16 -0700 Subject: [PATCH 03/11] remove comments --- tests/_sql/test_starrocks.py | 63 ------------------------------------ 1 file changed, 63 deletions(-) diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index 5c02c3dc2c1..35b1a0f7f6c 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -15,11 +15,6 @@ from marimo._sql.sql_quoting import quote_sql_identifier -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(autouse=True) def _mock_sqlalchemy_if_missing(): """Patch sys.modules with a lightweight sqlalchemy stub when the real @@ -29,18 +24,11 @@ def _mock_sqlalchemy_if_missing(): return mock_sa = MagicMock() - # `text()` is used as a pass-through wrapper; the result is fed into the - # mocked conn.execute, so its exact return value doesn't matter. mock_sa.text = MagicMock(side_effect=lambda q: q) with patch.dict(sys.modules, {"sqlalchemy": mock_sa}): yield -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - def _make_mock_engine(dialect_name: str = "starrocks") -> MagicMock: """Return a mock SQLAlchemy Engine with the given dialect name.""" mock_engine = MagicMock() @@ -73,11 +61,6 @@ def _mock_connection_ctx(engine: StarRocksEngine, side_effects: list[Any]): return conn -# --------------------------------------------------------------------------- -# is_compatible -# --------------------------------------------------------------------------- - - class TestIsCompatible: @pytest.mark.requires("sqlalchemy", "starrocks") def test_compatible_with_starrocks_dialect(self) -> None: @@ -103,11 +86,6 @@ def test_not_compatible_with_non_engine(self) -> None: assert not StarRocksEngine.is_compatible(None) -# --------------------------------------------------------------------------- -# source / dialect -# --------------------------------------------------------------------------- - - class TestSourceAndDialect: def test_source(self) -> None: engine = _make_engine() @@ -118,11 +96,6 @@ def test_dialect(self) -> None: assert engine.dialect == "starrocks" -# --------------------------------------------------------------------------- -# get_default_database / get_default_schema -# --------------------------------------------------------------------------- - - class TestDefaults: def test_get_default_database(self) -> None: engine = _make_engine() @@ -153,11 +126,6 @@ def test_get_default_schema_none_on_error(self) -> None: assert engine.get_default_schema() is None -# --------------------------------------------------------------------------- -# _list_catalogs / _list_databases_in_catalog -# --------------------------------------------------------------------------- - - class TestListCatalogs: def test_lists_all_catalogs(self) -> None: engine = _make_engine() @@ -204,11 +172,6 @@ def test_returns_empty_on_error(self) -> None: ) == [] -# --------------------------------------------------------------------------- -# get_databases -# --------------------------------------------------------------------------- - - class TestGetDatabases: def test_returns_catalog_as_database(self) -> None: engine = _make_engine() @@ -298,11 +261,6 @@ def test_auto_includes_schemas_excludes_tables(self) -> None: assert databases[0].schemas[0].tables == [] -# --------------------------------------------------------------------------- -# get_tables_in_schema -# --------------------------------------------------------------------------- - - class TestGetTablesInSchema: def test_returns_tables_and_views(self) -> None: engine = _make_engine() @@ -340,12 +298,6 @@ def test_returns_empty_on_error(self) -> None: ) assert result == [] - -# --------------------------------------------------------------------------- -# get_table_details -# --------------------------------------------------------------------------- - - class TestGetTableDetails: def test_returns_columns(self) -> None: engine = _make_engine() @@ -388,11 +340,6 @@ def test_returns_none_on_error(self) -> None: assert result is None -# --------------------------------------------------------------------------- -# SQL quoting integration -# --------------------------------------------------------------------------- - - class TestStarRocksQuoting: def test_starrocks_uses_backtick_style(self) -> None: assert quote_sql_identifier("my_catalog", dialect="starrocks") == "`my_catalog`" @@ -406,11 +353,6 @@ def test_starrocks_uses_backtick_style(self) -> None: ) -# --------------------------------------------------------------------------- -# _resolve_auto -# --------------------------------------------------------------------------- - - class TestResolveAuto: def test_true_stays_true(self) -> None: engine = _make_engine() @@ -426,11 +368,6 @@ def test_auto_resolves_to_false_for_starrocks(self) -> None: assert engine._resolve_should_auto_discover("auto") is False -# --------------------------------------------------------------------------- -# System catalog / database constants -# --------------------------------------------------------------------------- - - class TestSystemConstants: def test_system_schemas_excluded(self) -> None: assert "information_schema" in _SYSTEM_SCHEMAS From 227f305d34725e9734f9a97be2749bbe863cac5d Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Tue, 24 Mar 2026 10:20:26 -0700 Subject: [PATCH 04/11] added generic tests --- tests/_sql/test_get_engines.py | 36 ++++++++++++++++++++++++++++++++++ tests/_sql/test_sql_quoting.py | 30 ++++++++++++++++++++++++++++ tests/_sql/test_starrocks.py | 14 ------------- 3 files changed, 66 insertions(+), 14 deletions(-) diff --git a/tests/_sql/test_get_engines.py b/tests/_sql/test_get_engines.py index 62c580d0e2a..f14005c5062 100644 --- a/tests/_sql/test_get_engines.py +++ b/tests/_sql/test_get_engines.py @@ -19,6 +19,7 @@ from marimo._sql.engines.ibis import IbisEngine from marimo._sql.engines.redshift import RedshiftEngine from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine +from marimo._sql.engines.starrocks import StarRocksEngine from marimo._sql.get_engines import ( engine_to_data_source_connection, get_engines_from_variables, @@ -33,6 +34,7 @@ HAS_REDSHIFT = DependencyManager.redshift_connector.has() HAS_PYARROW = DependencyManager.pyarrow.has() HAS_IBIS = DependencyManager.ibis.has() +HAS_STARROCKS = DependencyManager.starrocks.has() @pytest.mark.skipif(not HAS_SQLALCHEMY, reason="SQLAlchemy not installed") @@ -85,6 +87,19 @@ def test_engine_to_data_source_connection() -> None: assert connection.name == "my_postgres" assert connection.display_name == "postgresql (my_postgres)" + # Test with StarRocks engine + mock_sr_engine = MagicMock() + mock_sr_engine.dialect.name = "starrocks" + sr_engine = StarRocksEngine(mock_sr_engine, engine_name=VariableName("my_sr")) + connection = engine_to_data_source_connection( + VariableName("my_sr"), sr_engine + ) + assert isinstance(connection, DataSourceConnection) + assert connection.source == "starrocks" + assert connection.dialect == "starrocks" + assert connection.name == "my_sr" + assert connection.display_name == "starrocks (my_sr)" + # Test with Ibis engine var_name = "my_ibis" backend_name = "duckdb" @@ -474,3 +489,24 @@ def test_variables_without_datasource_engine() -> None: variables = [("deferred_for_test", deferred_for_test)] engines = get_engines_from_variables(variables) assert not engines + + +@pytest.mark.skipif( + not (HAS_SQLALCHEMY and HAS_STARROCKS), + reason="SQLAlchemy and starrocks not installed", +) +def test_get_engines_starrocks() -> None: + import sqlalchemy as sa + import starrocks + + mock_engine = MagicMock(spec=sa.Engine) + mock_engine.dialect = MagicMock() + mock_engine.dialect.name = "starrocks" + variables: list[tuple[str, object]] = [("sr_engine", mock_engine)] + + engines = get_engines_from_variables(variables) + + assert len(engines) == 1 + var_name, engine = engines[0] + assert var_name == "sr_engine" + assert isinstance(engine, StarRocksEngine) diff --git a/tests/_sql/test_sql_quoting.py b/tests/_sql/test_sql_quoting.py index a5be8661b01..7e60ac28be7 100644 --- a/tests/_sql/test_sql_quoting.py +++ b/tests/_sql/test_sql_quoting.py @@ -46,6 +46,12 @@ class TestQuoteSqlIdentifier: ("table", "bigquery", "`table`"), ("my table", "bigquery", "`my table`"), ("has`backtick", "bigquery", "`has``backtick`"), + # StarRocks uses backtick style + ("table", "starrocks", "`table`"), + ("my table", "starrocks", "`my table`"), + ("nested.namespace", "starrocks", "`nested.namespace`"), + ("has`backtick", "starrocks", "`has``backtick`"), + ('has"quotes', "starrocks", '`has"quotes`'), # Unknown dialect returns unquoted ("table", "sqlite", "table"), ("my table", "unknown", "my table"), @@ -98,6 +104,24 @@ def test_clickhouse_roundtrip_safe(self, identifier: str) -> None: inner = quoted[1:-1] assert inner.replace("``", "`") == identifier + @pytest.mark.parametrize( + "identifier", + [ + "simple", + "with spaces", + "with.dots", + "with`backticks", + 'with"quotes', + ], + ) + def test_starrocks_roundtrip_safe(self, identifier: str) -> None: + """Verify that quoting an identifier produces valid StarRocks syntax.""" + quoted = quote_sql_identifier(identifier, dialect="starrocks") + assert quoted.startswith("`") + assert quoted.endswith("`") + inner = quoted[1:-1] + assert inner.replace("``", "`") == identifier + class TestQuoteQualifiedName: @pytest.mark.parametrize( @@ -127,6 +151,12 @@ class TestQuoteQualifiedName: "redshift", '"catalog"."schema"."table"', ), + # StarRocks 3-part name (catalog.database.table) + ( + ("iceberg_catalog", "tpch", "orders"), + "starrocks", + "`iceberg_catalog`.`tpch`.`orders`", + ), # Single part ( ("just_table",), diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index 35b1a0f7f6c..a3c2bad833f 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -15,20 +15,6 @@ from marimo._sql.sql_quoting import quote_sql_identifier -@pytest.fixture(autouse=True) -def _mock_sqlalchemy_if_missing(): - """Patch sys.modules with a lightweight sqlalchemy stub when the real - package is not installed, so mock-based tests can still run without it.""" - if "sqlalchemy" in sys.modules: - yield - return - - mock_sa = MagicMock() - mock_sa.text = MagicMock(side_effect=lambda q: q) - with patch.dict(sys.modules, {"sqlalchemy": mock_sa}): - yield - - def _make_mock_engine(dialect_name: str = "starrocks") -> MagicMock: """Return a mock SQLAlchemy Engine with the given dialect name.""" mock_engine = MagicMock() From f8310c033fcd228198e6bcd3a213113495d3d9ae Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Tue, 24 Mar 2026 10:35:39 -0700 Subject: [PATCH 05/11] resolve pycheck --- marimo/_sql/engines/starrocks.py | 26 ++++++-------------------- marimo/_sql/sql.py | 2 +- tests/_sql/test_get_engines.py | 6 ++++-- tests/_sql/test_starrocks.py | 32 +++++++++++++++++++------------- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py index 9301e03b281..12523f86e43 100644 --- a/marimo/_sql/engines/starrocks.py +++ b/marimo/_sql/engines/starrocks.py @@ -1,7 +1,7 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from marimo import _loggers from marimo._data.models import Database, DataTable, DataTableColumn, Schema @@ -9,13 +9,9 @@ from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine from marimo._sql.sql_quoting import quote_sql_identifier from marimo._sql.utils import sql_type_to_data_type -from marimo._types.ids import VariableName LOGGER = _loggers.marimo_logger() -if TYPE_CHECKING: - from sqlalchemy import Engine - # StarRocks databases (marimo Schemas) that are internal and not useful to surface. _SYSTEM_SCHEMAS = frozenset({"information_schema", "sys", "_statistics_"}) @@ -54,10 +50,6 @@ def is_compatible(var: Any) -> bool: return isinstance(var, Engine) and str(var.dialect.name) == "starrocks" - # ------------------------------------------------------------------ - # Default catalog / schema - # ------------------------------------------------------------------ - def get_default_database(self) -> Optional[str]: """Return the current StarRocks catalog via ``SELECT CATALOG()``. @@ -95,7 +87,9 @@ def get_databases( should_include_schemas = ( include_schemas if isinstance(include_schemas, bool) else True ) - should_include_tables = self._resolve_should_auto_discover(include_tables) + should_include_tables = self._resolve_should_auto_discover( + include_tables + ) should_include_details = self._resolve_should_auto_discover( include_table_details ) @@ -136,7 +130,7 @@ def get_tables_in_schema( """Return tables for *schema* inside *database* (a StarRocks catalog). Delegates to the inherited inspector path for the default catalog; - falls back to an ``information_schema`` query for external catalogs. + falls back to an ``DESC`` query for external catalogs. """ if database == self.default_database: return super().get_tables_in_schema( @@ -156,7 +150,7 @@ def get_table_details( """Return column metadata for a table. Delegates to the inherited inspector path for the default catalog; - falls back to an ``information_schema`` query for external catalogs. + falls back to an ``DESC`` query for external catalogs. """ if database_name == self.default_database: return super().get_table_details( @@ -170,10 +164,6 @@ def get_table_details( database_name=database_name, ) - # ------------------------------------------------------------------ - # Meta-schema filter (overrides SQLAlchemyEngine._get_meta_schemas) - # ------------------------------------------------------------------ - def _get_schemas( self, *, @@ -197,10 +187,6 @@ def _get_schemas( def _get_meta_schemas(self) -> list[str]: return list(_SYSTEM_SCHEMAS) - # ------------------------------------------------------------------ - # StarRocks-specific helpers - # ------------------------------------------------------------------ - def _list_catalogs(self) -> list[str]: """Return all catalog names via ``SHOW CATALOGS``. diff --git a/marimo/_sql/sql.py b/marimo/_sql/sql.py index 8603faaac90..9c5e12e68cf 100644 --- a/marimo/_sql/sql.py +++ b/marimo/_sql/sql.py @@ -76,7 +76,7 @@ def sql( break else: raise ValueError( - "Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift or DBAPI 2.0 compatible engine." + "Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift, StarRocks or DBAPI 2.0 compatible engine." ) try: diff --git a/tests/_sql/test_get_engines.py b/tests/_sql/test_get_engines.py index f14005c5062..6cedef6b583 100644 --- a/tests/_sql/test_get_engines.py +++ b/tests/_sql/test_get_engines.py @@ -90,7 +90,9 @@ def test_engine_to_data_source_connection() -> None: # Test with StarRocks engine mock_sr_engine = MagicMock() mock_sr_engine.dialect.name = "starrocks" - sr_engine = StarRocksEngine(mock_sr_engine, engine_name=VariableName("my_sr")) + sr_engine = StarRocksEngine( + mock_sr_engine, engine_name=VariableName("my_sr") + ) connection = engine_to_data_source_connection( VariableName("my_sr"), sr_engine ) @@ -497,7 +499,7 @@ def test_variables_without_datasource_engine() -> None: ) def test_get_engines_starrocks() -> None: import sqlalchemy as sa - import starrocks + import starrocks # noqa: F401 mock_engine = MagicMock(spec=sa.Engine) mock_engine.dialect = MagicMock() diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index a3c2bad833f..31f84bd5fb6 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -1,13 +1,11 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -import sys from typing import Any from unittest.mock import MagicMock, patch # noqa: F401 import pytest -from marimo._data.models import Database, DataTable, Schema from marimo._sql.engines.starrocks import ( _SYSTEM_SCHEMAS, StarRocksEngine, @@ -51,8 +49,10 @@ class TestIsCompatible: @pytest.mark.requires("sqlalchemy", "starrocks") def test_compatible_with_starrocks_dialect(self) -> None: import sqlalchemy as sa + import starrocks # noqa: F401 mock_engine = MagicMock(spec=sa.Engine) + mock_engine.dialect = MagicMock() mock_engine.dialect.name = "starrocks" assert StarRocksEngine.is_compatible(mock_engine) @@ -62,6 +62,7 @@ def test_not_compatible_with_other_dialects(self) -> None: for dialect in ("mysql", "postgresql", "sqlite", "clickhouse"): mock_engine = MagicMock(spec=sa.Engine) + mock_engine.dialect = MagicMock() mock_engine.dialect.name = dialect assert not StarRocksEngine.is_compatible(mock_engine) @@ -87,10 +88,6 @@ def test_get_default_database(self) -> None: engine = _make_engine() _mock_connection_ctx(engine, [[("default_catalog",)]]) assert engine.get_default_database() == "default_catalog" - # Verify it uses CATALOG() not CURRENT_CATALOG() - conn = engine._connection.connect().__enter__() - call_args = conn.execute.call_args_list - assert any("CATALOG()" in str(c) for c in call_args) def test_get_default_database_none_on_error(self) -> None: engine = _make_engine() @@ -151,11 +148,14 @@ def test_lists_databases_excluding_system(self) -> None: def test_returns_empty_on_error(self) -> None: engine = _make_engine() engine._connection.connect.side_effect = Exception("oops") - assert engine._get_external_schemas( - catalog="hive_catalog", - include_tables=False, - include_table_details=False, - ) == [] + assert ( + engine._get_external_schemas( + catalog="hive_catalog", + include_tables=False, + include_table_details=False, + ) + == [] + ) class TestGetDatabases: @@ -280,10 +280,13 @@ def test_returns_empty_on_error(self) -> None: engine = _make_engine() engine._connection.connect.side_effect = Exception("fail") result = engine.get_tables_in_schema( - schema="tpch", database="default_catalog", include_table_details=False + schema="tpch", + database="default_catalog", + include_table_details=False, ) assert result == [] + class TestGetTableDetails: def test_returns_columns(self) -> None: engine = _make_engine() @@ -328,7 +331,10 @@ def test_returns_none_on_error(self) -> None: class TestStarRocksQuoting: def test_starrocks_uses_backtick_style(self) -> None: - assert quote_sql_identifier("my_catalog", dialect="starrocks") == "`my_catalog`" + assert ( + quote_sql_identifier("my_catalog", dialect="starrocks") + == "`my_catalog`" + ) assert ( quote_sql_identifier("catalog`with`ticks", dialect="starrocks") == "`catalog``with``ticks`" From aeead7b7814bac331ce28dac0b547678ee53581b Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Tue, 24 Mar 2026 14:17:28 -0700 Subject: [PATCH 06/11] fix copilot review --- marimo/_sql/engines/starrocks.py | 2 +- marimo/_sql/sql_quoting.py | 2 +- tests/_sql/test_starrocks.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py index 12523f86e43..98a1a015534 100644 --- a/marimo/_sql/engines/starrocks.py +++ b/marimo/_sql/engines/starrocks.py @@ -130,7 +130,7 @@ def get_tables_in_schema( """Return tables for *schema* inside *database* (a StarRocks catalog). Delegates to the inherited inspector path for the default catalog; - falls back to an ``DESC`` query for external catalogs. + falls back to a ``SHOW`` query for external catalogs. """ if database == self.default_database: return super().get_tables_in_schema( diff --git a/marimo/_sql/sql_quoting.py b/marimo/_sql/sql_quoting.py index d392f09cc03..b9d952d0bde 100644 --- a/marimo/_sql/sql_quoting.py +++ b/marimo/_sql/sql_quoting.py @@ -12,7 +12,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str: identifier: The raw identifier string (database, schema, or table name). dialect: The SQL dialect. Double-quote style: "duckdb", "redshift", "postgresql"/"postgres". - Backtick style: "clickhouse", "mysql", "bigquery". + Backtick style: "clickhouse", "mysql", "bigquery", "starrocks". Unknown dialects return the identifier unquoted. Returns: diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index 31f84bd5fb6..d90c12d4f50 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any -from unittest.mock import MagicMock, patch # noqa: F401 +from unittest.mock import MagicMock import pytest @@ -12,6 +12,9 @@ ) from marimo._sql.sql_quoting import quote_sql_identifier +# Skip the entire module when sqlalchemy is not installed. +pytestmark = pytest.mark.requires("sqlalchemy") + def _make_mock_engine(dialect_name: str = "starrocks") -> MagicMock: """Return a mock SQLAlchemy Engine with the given dialect name.""" From a4b11d2cda2f38d277b19788ef36ff216cb03e19 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Mon, 30 Mar 2026 10:44:13 -0700 Subject: [PATCH 07/11] merge main branch --- marimo/_sql/engines/starrocks.py | 97 +++++++++----------------- tests/_sql/test_starrocks.py | 114 +++++++++++-------------------- 2 files changed, 73 insertions(+), 138 deletions(-) diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py index 98a1a015534..ecf76233755 100644 --- a/marimo/_sql/engines/starrocks.py +++ b/marimo/_sql/engines/starrocks.py @@ -74,55 +74,42 @@ def get_databases( include_tables: Union[bool, Literal["auto"]], include_table_details: Union[bool, Literal["auto"]], ) -> list[Database]: - """Return all StarRocks catalogs, each containing its databases as schemas. + """Return all StarRocks catalogs as databases.""" + return [ + Database( + name=catalog, + dialect=self.dialect, + schemas=[], + engine=self._engine_name, + ) + for catalog in self._list_catalogs() + ] - Uses the inherited inspector path for the default catalog and explicit - SQL for external catalogs. + def get_schemas( + self, + *, + database: Optional[str], + include_tables: bool, + include_table_details: bool, + ) -> list[Schema]: + """Return schemas for a catalog, lazily fetched on demand. - ``"auto"`` resolution: - - ``include_schemas`` → ``True`` (always show databases) - - ``include_tables`` → ``False`` (StarRocks catalogs can be large) - - ``include_table_details``→ ``False`` + Routes the default catalog through the inherited inspector path; + external catalogs use ``SHOW DATABASES IN ``. """ - should_include_schemas = ( - include_schemas if isinstance(include_schemas, bool) else True - ) - should_include_tables = self._resolve_should_auto_discover( - include_tables - ) - should_include_details = self._resolve_should_auto_discover( - include_table_details - ) - - databases: list[Database] = [] - for catalog in self._list_catalogs(): - if should_include_schemas: - if catalog == self.default_database: - # Inspector-based path (inherited from SQLAlchemyEngine). - schemas = self._get_schemas( - database=catalog, - include_tables=should_include_tables, - include_table_details=should_include_details, - ) - else: - # SQL fallback for external catalogs. - schemas = self._get_external_schemas( - catalog=catalog, - include_tables=should_include_tables, - include_table_details=should_include_details, - ) - else: - schemas = [] - - databases.append( - Database( - name=catalog, - dialect=self.dialect, - schemas=schemas, - engine=self._engine_name, - ) + if database == self.default_database: + return super().get_schemas( + database=database, + include_tables=include_tables, + include_table_details=include_table_details, + ) + if database is not None: + return self._get_external_schemas( + catalog=database, + include_tables=include_tables, + include_table_details=include_table_details, ) - return databases + return [] def get_tables_in_schema( self, *, schema: str, database: str, include_table_details: bool @@ -164,26 +151,6 @@ def get_table_details( database_name=database_name, ) - def _get_schemas( - self, - *, - database: Optional[str], - include_tables: bool, - include_table_details: bool, - ) -> list[Schema]: - """Filter system schemas out of the result entirely. - - The parent implementation keeps meta-schemas in the list but skips - table discovery for them. For StarRocks we want them hidden from the - sidebar completely. - """ - schemas = super()._get_schemas( - database=database, - include_tables=include_tables, - include_table_details=include_table_details, - ) - return [s for s in schemas if s.name.lower() not in _SYSTEM_SCHEMAS] - def _get_meta_schemas(self) -> list[str]: return list(_SYSTEM_SCHEMAS) diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py index d90c12d4f50..8e413bdebe5 100644 --- a/tests/_sql/test_starrocks.py +++ b/tests/_sql/test_starrocks.py @@ -162,26 +162,11 @@ def test_returns_empty_on_error(self) -> None: class TestGetDatabases: - def test_returns_catalog_as_database(self) -> None: + def test_returns_catalogs_with_empty_schemas(self) -> None: + """get_databases() lists catalogs only; schemas are fetched lazily.""" engine = _make_engine() - # Call 1: SHOW CATALOGS - # Call 2: SHOW DATABASES IN `default_catalog` catalogs_rows = [("default_catalog",), ("hive_catalog",)] - db_rows_default = [("tpch",), ("analytics",)] - db_rows_hive = [("lake",)] - - conn_ctx = MagicMock() - conn = MagicMock() - conn_ctx.__enter__ = MagicMock(return_value=conn) - conn_ctx.__exit__ = MagicMock(return_value=False) - engine._connection.connect = MagicMock(return_value=conn_ctx) - - results = [] - for rows in [catalogs_rows, db_rows_default, db_rows_hive]: - r = MagicMock() - r.fetchall = MagicMock(return_value=rows) - results.append(r) - conn.execute = MagicMock(side_effect=results) + _mock_connection_ctx(engine, [catalogs_rows]) databases = engine.get_databases( include_schemas=True, @@ -192,62 +177,60 @@ def test_returns_catalog_as_database(self) -> None: assert len(databases) == 2 assert databases[0].name == "default_catalog" assert databases[1].name == "hive_catalog" - assert [s.name for s in databases[0].schemas] == ["tpch", "analytics"] - assert [s.name for s in databases[1].schemas] == ["lake"] + # Schemas are always empty — lazy loading handles them for db in databases: + assert db.schemas == [] assert db.dialect == "starrocks" assert db.engine == "sr" - def test_no_schemas_when_include_schemas_false(self) -> None: + def test_returns_empty_on_error(self) -> None: engine = _make_engine() - conn_ctx = MagicMock() - conn = MagicMock() - conn_ctx.__enter__ = MagicMock(return_value=conn) - conn_ctx.__exit__ = MagicMock(return_value=False) - engine._connection.connect = MagicMock(return_value=conn_ctx) - - catalogs_result = MagicMock() - catalogs_result.fetchall = MagicMock( - return_value=[("default_catalog",)] - ) - conn.execute = MagicMock(return_value=catalogs_result) - + engine._connection.connect.side_effect = Exception("oops") databases = engine.get_databases( include_schemas=False, include_tables=False, include_table_details=False, ) + assert databases == [] - assert len(databases) == 1 - assert databases[0].name == "default_catalog" - assert databases[0].schemas == [] - def test_auto_includes_schemas_excludes_tables(self) -> None: - """'auto' should resolve to include_schemas=True, include_tables=False.""" +class TestGetSchemas: + def test_external_catalog_returns_schemas(self) -> None: + """get_schemas() for an external catalog uses SHOW DATABASES.""" engine = _make_engine() - conn_ctx = MagicMock() - conn = MagicMock() - conn_ctx.__enter__ = MagicMock(return_value=conn) - conn_ctx.__exit__ = MagicMock(return_value=False) - engine._connection.connect = MagicMock(return_value=conn_ctx) - - # SHOW CATALOGS → 1 catalog; SHOW DATABASES → 1 db - r1 = MagicMock() - r1.fetchall = MagicMock(return_value=[("default_catalog",)]) - r2 = MagicMock() - r2.fetchall = MagicMock(return_value=[("tpch",)]) - conn.execute = MagicMock(side_effect=[r1, r2]) + rows = [ + ("tpch",), + ("analytics",), + ("information_schema",), # excluded + ("sys",), # excluded + ] + _mock_connection_ctx(engine, [rows]) - databases = engine.get_databases( - include_schemas="auto", - include_tables="auto", - include_table_details="auto", + schemas = engine.get_schemas( + database="hive_catalog", + include_tables=False, + include_table_details=False, + ) + assert [s.name for s in schemas] == ["tpch", "analytics"] + + def test_returns_empty_for_none_database(self) -> None: + engine = _make_engine() + schemas = engine.get_schemas( + database=None, + include_tables=False, + include_table_details=False, ) + assert schemas == [] - assert len(databases) == 1 - assert databases[0].schemas[0].name == "tpch" - # Tables should NOT be fetched (auto → False for tables) - assert databases[0].schemas[0].tables == [] + def test_returns_empty_on_error(self) -> None: + engine = _make_engine() + engine._connection.connect.side_effect = Exception("oops") + schemas = engine.get_schemas( + database="hive_catalog", + include_tables=False, + include_table_details=False, + ) + assert schemas == [] class TestGetTablesInSchema: @@ -348,21 +331,6 @@ def test_starrocks_uses_backtick_style(self) -> None: ) -class TestResolveAuto: - def test_true_stays_true(self) -> None: - engine = _make_engine() - assert engine._resolve_should_auto_discover(True) is True - - def test_false_stays_false(self) -> None: - engine = _make_engine() - assert engine._resolve_should_auto_discover(False) is False - - def test_auto_resolves_to_false_for_starrocks(self) -> None: - # StarRocks is not in CHEAP_DISCOVERY_DATABASES, so "auto" → False. - engine = _make_engine() - assert engine._resolve_should_auto_discover("auto") is False - - class TestSystemConstants: def test_system_schemas_excluded(self) -> None: assert "information_schema" in _SYSTEM_SCHEMAS From fe0226c1b85b55e737633abd92b4b303adb27701 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Mon, 30 Mar 2026 11:04:48 -0700 Subject: [PATCH 08/11] fix py-check --- marimo/_sql/engines/starrocks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py index ecf76233755..59c18d09ca2 100644 --- a/marimo/_sql/engines/starrocks.py +++ b/marimo/_sql/engines/starrocks.py @@ -75,6 +75,7 @@ def get_databases( include_table_details: Union[bool, Literal["auto"]], ) -> list[Database]: """Return all StarRocks catalogs as databases.""" + _, _, _ = include_schemas, include_tables, include_table_details return [ Database( name=catalog, From 9de8ec37194a14810f018c88a0155e67ca9dcaa6 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Wed, 1 Apr 2026 12:21:40 -0700 Subject: [PATCH 09/11] merge main and add changes --- marimo/_sql/engines/sqlalchemy.py | 480 +++++++++++++++++++++++------- marimo/_sql/engines/starrocks.py | 315 -------------------- marimo/_sql/get_engines.py | 2 - pyproject.toml | 13 +- tests/_sql/test_get_engines.py | 37 --- tests/_sql/test_starrocks.py | 338 --------------------- 6 files changed, 380 insertions(+), 805 deletions(-) delete mode 100644 marimo/_sql/engines/starrocks.py delete mode 100644 tests/_sql/test_starrocks.py diff --git a/marimo/_sql/engines/sqlalchemy.py b/marimo/_sql/engines/sqlalchemy.py index dd2163e00c3..5cc1216db74 100644 --- a/marimo/_sql/engines/sqlalchemy.py +++ b/marimo/_sql/engines/sqlalchemy.py @@ -1,7 +1,19 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +import functools +import re +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + ParamSpec, + TypeVar, + Union, +) from marimo import _loggers from marimo._data.models import ( @@ -24,12 +36,68 @@ LOGGER = _loggers.marimo_logger() if TYPE_CHECKING: + from collections.abc import Iterator + import pandas as pd import polars as pl from sqlalchemy import Engine, Inspector from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.interfaces import ReflectedColumn, ReflectedIndex from sqlalchemy.sql.type_api import TypeEngine +# Quote if the identifier contains anything other than letters, digits, underscores, or dollar signs. +_SNOWFLAKE_NEEDS_QUOTING_RE = re.compile(r"[^A-Za-z0-9_$]") + + +# ------------------------------------------------------------------ # +# Decorators # +# ------------------------------------------------------------------ # + + +T = TypeVar("T") +P = ParamSpec("P") +F = TypeVar("F") + + +def safe_execute( + *, + fallback: F, + message: str = "Operation failed", + log_level: Literal["debug", "info", "warning", "error"] = "warning", + silent_exceptions: tuple[type[BaseException], ...] = (), +) -> Callable[[Callable[P, T]], Callable[P, T | F]]: + """Catch exceptions, log them, and return a fallback value. + + Args: + fallback: Value returned when the wrapped function raises. + message: Message written to the logger on failure. + log_level: Logger level – must be one of + ``'debug'``, ``'info'``, ``'warning'``, or ``'error'``. + silent_exceptions: Exception types that should return *fallback* + without any logging. Useful for expected control-flow + exceptions like ``NotImplementedError``. + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T | F]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | F: + try: + return func(*args, **kwargs) + except silent_exceptions: + return fallback + except Exception: + getattr(LOGGER, log_level)(message, exc_info=True) + return fallback + + return wrapper + + return decorator + + +# ------------------------------------------------------------------ # +# SQLAlchemyEngine # +# ------------------------------------------------------------------ # + class SQLAlchemyEngine(SQLConnection["Engine"]): """SQLAlchemy engine.""" @@ -52,6 +120,58 @@ def __init__( self.default_database = self.get_default_database() self.default_schema = self.get_default_schema() + def _quote_identifier(self, identifier: str) -> str: + """Quote an identifier based on the SQL dialect's quoting rules.""" + dialect_quoting: dict[str, tuple[re.Pattern[str], str, str]] = { + "snowflake": (_SNOWFLAKE_NEEDS_QUOTING_RE, '"', '"'), + "starrocks": (_SNOWFLAKE_NEEDS_QUOTING_RE, '`', '`'), + } + + if self.dialect not in dialect_quoting: + return identifier + + pattern, open_quote, close_quote = dialect_quoting[self.dialect] + if pattern.search(identifier) or identifier != identifier.lower(): + escaped = identifier.replace( + close_quote, close_quote + close_quote + ) + return f"{open_quote}{escaped}{close_quote}" + return identifier + + @contextmanager + def _get_inspector(self, database: str) -> Iterator[Optional[Inspector]]: + """Yield an appropriate SQLAlchemy Inspector for the given database. + + For dialects that require a USE DATABASE command (e.g. Snowflake), + this opens a connection, executes the command, and yields an + inspector bound to that connection. + + For all other dialects, it yields ``self.inspector`` (which may + be ``None``). + + Usage:: + + with self._get_inspector(database) as inspector: + if inspector is None: + return [] + return inspector.get_schema_names() + """ + + from sqlalchemy import inspect, text + + _use_database_dialect_command: dict[str, str] = { + "snowflake": f"USE DATABASE {self._quote_identifier(database)}", + "starrocks": f"SET CATALOG {self._quote_identifier(database)}", + } + dialect_command = _use_database_dialect_command.get(self.dialect) + + if dialect_command is not None: + with self._connection.connect() as connection: + connection.execute(text(dialect_command)) + yield inspect(connection) + else: + yield self.inspector + @property def source(self) -> str: return "sqlalchemy" @@ -108,7 +228,7 @@ def is_compatible(var: Any) -> bool: @property def inference_config(self) -> InferenceConfig: return InferenceConfig( - auto_discover_schemas=True, + auto_discover_schemas="auto", auto_discover_tables="auto", auto_discover_columns=False, ) @@ -139,6 +259,7 @@ def get_default_database(self) -> Optional[str]: "postgresql": "SELECT current_database()", "mssql": "SELECT DB_NAME()", "timeplus": "SELECT current_database()", + "starrocks": "SELECT CATALOG()", } # Try to get the database name by querying the database directly @@ -166,23 +287,102 @@ def get_default_database(self) -> Optional[str]: return database_name or "" + @safe_execute( + fallback=None, + message="Failed to get default schema name", + log_level="warning", + ) def get_default_schema(self) -> Optional[str]: """Get the default schema name""" if self.inspector is None: return None - try: - default_schema_name = self.inspector.default_schema_name - # https://github.com/marimo-team/marimo/issues/6436. - # Upstream bug where default schema name is not a string. - if default_schema_name is None or not isinstance( - default_schema_name, str - ): - return None - return str(default_schema_name) - except Exception: - LOGGER.warning("Failed to get default schema name", exc_info=True) + default_schema_name = self.inspector.default_schema_name + # https://github.com/marimo-team/marimo/issues/6436. + # Upstream bug where default schema name is not a string. + if default_schema_name is None or not isinstance( + default_schema_name, str + ): return None + return str(default_schema_name) + + # -------------------------------------------------------------- # + # Databases resolution # + # -------------------------------------------------------------- # + + # Get database names for SNOWFLAKE + def _get_snowflake_database_names(self) -> list[str]: + """Get database names for Snowflake via 'SHOW DATABASES'. + + If the default database exists in the results, return only that. + Otherwise, return all discovered databases. + + Unquoted identifiers are normalized to lowercase for consistency. + Identifiers that need quoting are preserved as-is. + """ + from sqlalchemy import text + + with self._connection.connect() as connection: + result = connection.execute(text("SHOW DATABASES")) + columns = list(result.keys()) + + try: + name_col_index = columns.index("name") + except ValueError as err: + raise RuntimeError( + "Unexpected SHOW DATABASES result: " + f"'name' column not found in {columns}" + ) from err + + database_names: list[str] = [] + for row in result.fetchall(): + raw_name = str(row[name_col_index]) + if ( + _SNOWFLAKE_NEEDS_QUOTING_RE.search(raw_name) + or raw_name != raw_name.upper() + ): + database_names.append(raw_name) + else: + database_names.append(raw_name.lower()) + + if self.default_database: + default_lower = self.default_database.lower() + for db in database_names: + if db.lower() == default_lower: + return [db] + + return database_names + + def _get_starrocks_database_names(self) -> list[str]: + """Get catalog names for StarRocks via 'SHOW CATALOGS'. + + StarRocks uses a three-level hierarchy (Catalog → Database → Table) + which maps to marimo's (Database → Schema → Table). + """ + from sqlalchemy import text + + with self._connection.connect() as connection: + result = connection.execute(text("SHOW CATALOGS")) + return [str(row[0]) for row in result.fetchall()] + + @safe_execute( + fallback=[], + message="Failed to get database names", + log_level="warning", + ) + def _get_database_names(self) -> list[str]: + """Get database names using dialect-specific queries. + + Returns a single-element list with the default database when + the dialect has no dedicated discovery mechanism. + """ + dialect = self.dialect.lower() + if dialect == "snowflake": + return self._get_snowflake_database_names() + if dialect == "starrocks": + return self._get_starrocks_database_names() + + return [self.default_database] if self.default_database else [] def get_databases( self, @@ -194,44 +394,65 @@ def get_databases( """Get all databases from the engine. Args: - include_schemas: Whether to include schema information. If False, databases will have empty schemas. - include_tables: Whether to include table information within schemas. If False, schemas will have empty tables. - include_table_details: Whether to include each table's detailed information. If False, tables will have empty columns, PK, indexes. + include_schemas: Include schema information per database. + include_tables: Include table information within each schema. + include_table_details: Include columns, PKs, and indexes + for each table. Returns: List of Database objects representing the database structure. - Note: This operation can be performance intensive when fetching full metadata. + Note: + This operation can be performance-intensive when fetching + full metadata. """ + should_include_schemas = self._resolve_should_auto_discover( + include_schemas + ) + should_include_tables = self._resolve_should_auto_discover( + include_tables + ) + should_include_details = self._resolve_should_auto_discover( + include_table_details + ) + databases: list[Database] = [] - if self.default_database is None: - return databases - database_name = self.default_database - - schemas = ( - self.get_schemas( - database=database_name, - include_tables=self._resolve_should_auto_discover( - include_tables - ), - include_table_details=self._resolve_should_auto_discover( - include_table_details - ), + for database_name in self._get_database_names(): + schemas = ( + self.get_schemas( + database=database_name, + include_tables=should_include_tables, + include_table_details=should_include_details, + ) + if should_include_schemas + else [] ) - if self._resolve_should_auto_discover(include_schemas) - else [] - ) - databases.append( - Database( - name=database_name, - dialect=self.dialect, - schemas=schemas, - engine=self._engine_name, + databases.append( + Database( + name=database_name, + dialect=self.dialect, + schemas=schemas, + engine=self._engine_name, + ) ) - ) + return databases + # -------------------------------------------------------------- # + # Schemas resolution # + # -------------------------------------------------------------- # + + @safe_execute( + fallback=[], message="Failed to get schema names", log_level="warning" + ) + def _get_schema_names(self, database: str) -> list[str]: + + with self._get_inspector(database) as inspector: + if inspector is None: + return [] + return inspector.get_schema_names() + def get_schemas( self, *, @@ -241,13 +462,11 @@ def get_schemas( ) -> list[Schema]: """Get all schemas and optionally their tables. Keys are schema names.""" - if self.inspector is None: - return [] - try: - schema_names = self.inspector.get_schema_names() - except Exception: - LOGGER.warning("Failed to get schema names", exc_info=True) - return [] + if database is None: + schema_names: list[str] = [] + else: + schema_names = self._get_schema_names(database) + schemas: list[Schema] = [] for schema in schema_names: @@ -267,22 +486,38 @@ def _get_meta_schemas(self) -> list[str]: dialect = self.dialect.lower() if dialect == "postgresql": return ["information_schema", "pg_catalog"] + if dialect == "starrocks": + return ["information_schema", "sys", "_statistics_"] return ["information_schema"] + # -------------------------------------------------------------- # + # Tables resolution # + # -------------------------------------------------------------- # + + @safe_execute( + fallback=([], []), + message="Failed to get tables in schema", + log_level="warning", + ) + def _get_table_names( + self, schema: str, database: str + ) -> tuple[list[str], list[str]]: + + with self._get_inspector(database) as inspector: + if inspector is None: + return [], [] + return inspector.get_table_names( + schema=schema + ), inspector.get_view_names(schema=schema) + def get_tables_in_schema( self, *, schema: str, database: str, include_table_details: bool ) -> list[DataTable]: """Return all tables in a schema.""" - _ = database - if self.inspector is None: - return [] - try: - table_names = self.inspector.get_table_names(schema=schema) - view_names = self.inspector.get_view_names(schema=schema) - except Exception: - LOGGER.warning("Failed to get tables in schema", exc_info=True) - return [] + table_names, view_names = self._get_table_names( + schema=schema, database=database + ) tables: list[tuple[DataTableType, str]] = [] for name in table_names: @@ -319,48 +554,77 @@ def get_tables_in_schema( return data_tables + # -------------------------------------------------------------- # + # Table Details resolution # + # -------------------------------------------------------------- # + + @safe_execute( + fallback=None, + message="Failed to get table details", + log_level="warning", + ) + def _get_columns( + self, table_name: str, schema: str, database: str + ) -> Optional[list[ReflectedColumn]]: + + with self._get_inspector(database) as inspector: + if inspector is None: + return None + return inspector.get_columns(table_name, schema=schema) + + @safe_execute(fallback=[], message="Failed to get primary keys") + def _fetch_primary_keys( + self, table_name: str, schema: str, database: str + ) -> list[str]: + + with self._get_inspector(database) as inspector: + if inspector is None: + return [] + return inspector.get_pk_constraint(table_name, schema=schema).get( + "constrained_columns", [] + ) + + @safe_execute(fallback=[], message="Failed to get indexes") + def _fetch_indexes( + self, table_name: str, schema: str, database: str + ) -> list[str]: + + with self._get_inspector(database) as inspector: + if inspector is None: + return [] + indexes = inspector.get_indexes(table_name, schema=schema) + return self._extract_index_columns(indexes) + + @staticmethod + def _extract_index_columns(indexes: list[ReflectedIndex]) -> list[str]: + """Extract and deduplicate column names from a list of index definitions.""" + index_columns: list[str] = [] + seen: set[str] = set() + for index in indexes: + if index_cols := index.get("column_names"): + for col in index_cols: + if col is not None and col not in seen: + seen.add(col) + index_columns.append(col) + return index_columns + def get_table_details( self, *, table_name: str, schema_name: str, database_name: str ) -> Optional[DataTable]: """Get a single table from the engine.""" - _ = database_name - if self.inspector is None: - return None - try: - columns = self.inspector.get_columns( - table_name, schema=schema_name - ) - except Exception: - LOGGER.warning( - f"Failed to get table {table_name} in schema {schema_name}", - exc_info=True, - ) + columns = self._get_columns( + table_name, schema=schema_name, database=database_name + ) + if columns is None: return None - primary_keys: list[str] = [] - index_list: list[str] = [] - - try: - primary_keys = self.inspector.get_pk_constraint( - table_name, schema=schema_name - )["constrained_columns"] - except Exception: - pass - - # TODO: Handle multi column PK and indexes - try: - indexes = self.inspector.get_indexes( - table_name, schema=schema_name - ) - for index in indexes: - if index_cols := index["column_names"]: - index_list.extend( - col for col in index_cols if col is not None - ) - except Exception: - LOGGER.warning("Failed to get indexes", exc_info=True) - pass + primary_keys = self._fetch_primary_keys( + table_name, schema_name, database_name + ) + index_list = self._fetch_indexes( + table_name, schema_name, database_name + ) cols: list[DataTableColumn] = [] for col in columns: @@ -393,29 +657,29 @@ def get_table_details( indexes=index_list, ) + @safe_execute( + fallback=None, + message="Failed to get column type", + log_level="warning", + silent_exceptions=(NotImplementedError,), + ) def _get_python_type( self, engine_type: TypeEngine[Any] ) -> DataType | None: - try: - col_type = engine_type.python_type - return sql_type_to_data_type(str(col_type)) - except NotImplementedError: - return None - except Exception: - LOGGER.debug("Failed to get python type", exc_info=True) - return None - + col_type = engine_type.python_type + return sql_type_to_data_type(str(col_type)) + + @safe_execute( + fallback=None, + message="Failed to get generic type", + log_level="debug", + silent_exceptions=(NotImplementedError,), + ) def _get_generic_type( self, engine_type: TypeEngine[Any] ) -> DataType | None: - try: - col_type = engine_type.as_generic() - return sql_type_to_data_type(str(col_type)) - except NotImplementedError: - return None - except Exception: - LOGGER.debug("Failed to get generic type", exc_info=True) - return None + col_type = engine_type.as_generic() + return sql_type_to_data_type(str(col_type)) def _resolve_should_auto_discover( self, diff --git a/marimo/_sql/engines/starrocks.py b/marimo/_sql/engines/starrocks.py deleted file mode 100644 index 59c18d09ca2..00000000000 --- a/marimo/_sql/engines/starrocks.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2026 Marimo. All rights reserved. -from __future__ import annotations - -from typing import Any, Literal, Optional, Union - -from marimo import _loggers -from marimo._data.models import Database, DataTable, DataTableColumn, Schema -from marimo._dependencies.dependencies import DependencyManager -from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine -from marimo._sql.sql_quoting import quote_sql_identifier -from marimo._sql.utils import sql_type_to_data_type - -LOGGER = _loggers.marimo_logger() - -# StarRocks databases (marimo Schemas) that are internal and not useful to surface. -_SYSTEM_SCHEMAS = frozenset({"information_schema", "sys", "_statistics_"}) - - -def _quote(name: str) -> str: - return quote_sql_identifier(name, dialect="starrocks") - - -class StarRocksEngine(SQLAlchemyEngine): - """StarRocks SQL engine with multi-catalog support. - - Extends :class:`SQLAlchemyEngine`, inheriting the SQLAlchemy inspector - pattern for the connected (default) catalog. External catalogs fall back - to explicit SQL because the inspector is bound to a single catalog. - - StarRocks uses a three-level hierarchy: Catalog → Database → Table. - This maps to marimo's Database → Schema → Table model: - - - marimo ``Database`` ↔ StarRocks Catalog - - marimo ``Schema`` ↔ StarRocks Database - - marimo ``DataTable`` ↔ StarRocks Table - """ - - @property - def source(self) -> str: - return "starrocks" - - @staticmethod - def is_compatible(var: Any) -> bool: - if not DependencyManager.sqlalchemy.imported(): - return False - if not DependencyManager.starrocks.imported(): - return False - - from sqlalchemy.engine import Engine - - return isinstance(var, Engine) and str(var.dialect.name) == "starrocks" - - def get_default_database(self) -> Optional[str]: - """Return the current StarRocks catalog via ``SELECT CATALOG()``. - - Overrides the parent which reads from the SQLAlchemy connection URL, - because StarRocks exposes catalogs rather than a single database. - """ - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - row = conn.execute(text("SELECT CATALOG()")).fetchone() - if row is not None and row[0] is not None: - return str(row[0]) - except Exception: - LOGGER.warning("Failed to get current catalog", exc_info=True) - return None - - def get_databases( - self, - *, - include_schemas: Union[bool, Literal["auto"]], - include_tables: Union[bool, Literal["auto"]], - include_table_details: Union[bool, Literal["auto"]], - ) -> list[Database]: - """Return all StarRocks catalogs as databases.""" - _, _, _ = include_schemas, include_tables, include_table_details - return [ - Database( - name=catalog, - dialect=self.dialect, - schemas=[], - engine=self._engine_name, - ) - for catalog in self._list_catalogs() - ] - - def get_schemas( - self, - *, - database: Optional[str], - include_tables: bool, - include_table_details: bool, - ) -> list[Schema]: - """Return schemas for a catalog, lazily fetched on demand. - - Routes the default catalog through the inherited inspector path; - external catalogs use ``SHOW DATABASES IN ``. - """ - if database == self.default_database: - return super().get_schemas( - database=database, - include_tables=include_tables, - include_table_details=include_table_details, - ) - if database is not None: - return self._get_external_schemas( - catalog=database, - include_tables=include_tables, - include_table_details=include_table_details, - ) - return [] - - def get_tables_in_schema( - self, *, schema: str, database: str, include_table_details: bool - ) -> list[DataTable]: - """Return tables for *schema* inside *database* (a StarRocks catalog). - - Delegates to the inherited inspector path for the default catalog; - falls back to a ``SHOW`` query for external catalogs. - """ - if database == self.default_database: - return super().get_tables_in_schema( - schema=schema, - database=database, - include_table_details=include_table_details, - ) - return self._get_external_tables( - schema=schema, - database=database, - include_table_details=include_table_details, - ) - - def get_table_details( - self, *, table_name: str, schema_name: str, database_name: str - ) -> Optional[DataTable]: - """Return column metadata for a table. - - Delegates to the inherited inspector path for the default catalog; - falls back to an ``DESC`` query for external catalogs. - """ - if database_name == self.default_database: - return super().get_table_details( - table_name=table_name, - schema_name=schema_name, - database_name=database_name, - ) - return self._get_external_table_details( - table_name=table_name, - schema_name=schema_name, - database_name=database_name, - ) - - def _get_meta_schemas(self) -> list[str]: - return list(_SYSTEM_SCHEMAS) - - def _list_catalogs(self) -> list[str]: - """Return all catalog names via ``SHOW CATALOGS``. - - There is no SQLAlchemy inspector equivalent for catalog enumeration. - """ - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - rows = conn.execute(text("SHOW CATALOGS")).fetchall() - return [str(row[0]) for row in rows] - except Exception: - LOGGER.warning("Failed to list catalogs", exc_info=True) - return [] - - def _get_external_schemas( - self, - *, - catalog: str, - include_tables: bool, - include_table_details: bool, - ) -> list[Schema]: - """List databases in an external catalog via ``SHOW DATABASES``.""" - try: - from sqlalchemy import text - - with self._connection.connect() as conn: - rows = conn.execute( - text(f"SHOW DATABASES IN {_quote(catalog)}") - ).fetchall() - db_names = [ - str(row[0]) - for row in rows - if str(row[0]).lower() not in _SYSTEM_SCHEMAS - ] - except Exception: - LOGGER.warning( - "Failed to list databases in catalog %r", - catalog, - exc_info=True, - ) - return [] - - schemas: list[Schema] = [] - for db_name in db_names: - tables: list[DataTable] = [] - if include_tables: - tables = self._get_external_tables( - schema=db_name, - database=catalog, - include_table_details=include_table_details, - ) - schemas.append(Schema(name=db_name, tables=tables)) - return schemas - - def _get_external_tables( - self, *, schema: str, database: str, include_table_details: bool - ) -> list[DataTable]: - """List tables in an external catalog via ``SHOW FULL TABLES``.""" - try: - from sqlalchemy import text - - qualified = f"{_quote(database)}.{_quote(schema)}" - with self._connection.connect() as conn: - rows = conn.execute( - text(f"SHOW FULL TABLES FROM {qualified}") - ).fetchall() - except Exception: - LOGGER.warning( - "Failed to get tables in %r.%r", - database, - schema, - exc_info=True, - ) - return [] - - tables: list[DataTable] = [] - for row in rows: - table_name = str(row[0]) - raw_type = str(row[1]).upper() if row[1] else "BASE TABLE" - table_type: Literal["table", "view"] = ( - "view" if "VIEW" in raw_type else "table" - ) - - if not include_table_details: - tables.append( - DataTable( - source_type="connection", - source=self.dialect, - name=table_name, - num_rows=None, - num_columns=None, - variable_name=None, - engine=self._engine_name, - type=table_type, - columns=[], - primary_keys=[], - indexes=[], - ) - ) - else: - table = self._get_external_table_details( - table_name=table_name, - schema_name=schema, - database_name=database, - ) - if table is not None: - table.type = table_type - tables.append(table) - - return tables - - def _get_external_table_details( - self, *, table_name: str, schema_name: str, database_name: str - ) -> Optional[DataTable]: - """Describe an external-catalog table via ``DESC ..
``.""" - try: - from sqlalchemy import text - - qualified = ( - f"{_quote(database_name)}" - f".{_quote(schema_name)}" - f".{_quote(table_name)}" - ) - with self._connection.connect() as conn: - rows = conn.execute(text(f"DESC {qualified}")).fetchall() - except Exception: - LOGGER.warning( - "Failed to get details for %r.%r.%r", - database_name, - schema_name, - table_name, - exc_info=True, - ) - return None - - columns = [ - DataTableColumn( - name=str(row[0]), - type=sql_type_to_data_type(str(row[1])), - external_type=str(row[1]), - sample_values=[], - ) - for row in rows - ] - - return DataTable( - source_type="connection", - source=self.dialect, - name=table_name, - num_rows=None, - num_columns=len(columns), - variable_name=None, - engine=self._engine_name, - columns=columns, - primary_keys=[], - indexes=[], - ) diff --git a/marimo/_sql/get_engines.py b/marimo/_sql/get_engines.py index 8f263670849..7ddcdb23e90 100644 --- a/marimo/_sql/get_engines.py +++ b/marimo/_sql/get_engines.py @@ -22,7 +22,6 @@ from marimo._sql.engines.pyiceberg import PyIcebergEngine from marimo._sql.engines.redshift import RedshiftEngine from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine -from marimo._sql.engines.starrocks import StarRocksEngine from marimo._sql.engines.types import ( BaseEngine, EngineCatalog, @@ -34,7 +33,6 @@ # TODO: this is O(n) and can be O(1) using similar logic to the # formatters, but order does matter here SUPPORTED_ENGINES: list[type[BaseEngine[Any]]] = [ - StarRocksEngine, SQLAlchemyEngine, IbisEngine, DuckDBEngine, diff --git a/pyproject.toml b/pyproject.toml index 77ba22caa6f..80aa6e39f6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "uv_build" [project] name = "marimo" -version = "0.21.1" +version = "0.22.0" description = "A library for making reactive notebooks and apps" # We try to keep dependencies to a minimum, to avoid conflicts with # user environments; we need a very compelling reason for each dependency added. @@ -52,6 +52,7 @@ dependencies = [ "msgspec>=0.20.0", # for IPC, required for marimo edit --sandbox, marimo run "pyzmq>=27.1.0; python_version < '3.15'", + "sqlalchemy>=2.0.48", ] readme = "README.md" license = { file = "LICENSE" } @@ -101,7 +102,7 @@ recommended = [ "marimo[sql]", "marimo[sandbox]", # For `marimo edit --sandbox DIRECTORY` "altair>=5.4.0", # Plotting in datasource viewer - "pydantic-ai-slim[openai]>=1.39.0", # AI features + "pydantic-ai-slim[openai]>=1.52.0", # AI features "ruff", # Formatting "nbformat>=5.7.0", # Export as IPYNB ] @@ -131,7 +132,7 @@ dev = [ # For linting "ruff>=0.14.0,<0.15.2", # TODO: remove upper bound once we fix the ruff issues # For AI - "pydantic-ai-slim[openai]>=1.52.0", + "pydantic-ai-slim[openai]>=1.71.0", ] test = [ @@ -146,6 +147,7 @@ test = [ "pytest-asyncio~=1.3.0", "pytest-picked>=0.5.1", "pytest-sugar~=1.1.1", + "pytest-xdist~=3.5", # Comparison testing "inline-snapshot~=0.29.0", "hypothesis~=6.102.1", @@ -181,6 +183,7 @@ test-optional = [ "chdb>=3; platform_system != 'Windows'", # there is no suitable wheel for windows "clickhouse-connect>=0.8.18", "redshift-connector[full]>=2.1.7", + # For testing starrocks "starrocks>=1.3.0", "pandas>=1.5.3", "hvplot~=0.11.3", @@ -192,7 +195,7 @@ test-optional = [ "anywidget~=0.9.21", "ipython~=8.12.3", # testing gen ai - "pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.52.0", + "pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.71.0", # - google-auth uses cachetools, and cachetools<5.0.0 uses collections.MutableMapping (removed in Python 3.10) "cachetools>=5.0.0", "boto3>=1.38.46", @@ -229,7 +232,7 @@ typecheck = [ "sqlalchemy>=2.0.40", "obstore>=0.8.2", "fsspec>=2026.2.0", - "pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.52.0", + "pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.71.0", "loro>=1.5.0", "boto3-stubs>=1.38.46", "pandas-stubs>=1.5.3.230321", diff --git a/tests/_sql/test_get_engines.py b/tests/_sql/test_get_engines.py index 6cedef6b583..f5a84a191a1 100644 --- a/tests/_sql/test_get_engines.py +++ b/tests/_sql/test_get_engines.py @@ -19,7 +19,6 @@ from marimo._sql.engines.ibis import IbisEngine from marimo._sql.engines.redshift import RedshiftEngine from marimo._sql.engines.sqlalchemy import SQLAlchemyEngine -from marimo._sql.engines.starrocks import StarRocksEngine from marimo._sql.get_engines import ( engine_to_data_source_connection, get_engines_from_variables, @@ -34,7 +33,6 @@ HAS_REDSHIFT = DependencyManager.redshift_connector.has() HAS_PYARROW = DependencyManager.pyarrow.has() HAS_IBIS = DependencyManager.ibis.has() -HAS_STARROCKS = DependencyManager.starrocks.has() @pytest.mark.skipif(not HAS_SQLALCHEMY, reason="SQLAlchemy not installed") @@ -87,21 +85,6 @@ def test_engine_to_data_source_connection() -> None: assert connection.name == "my_postgres" assert connection.display_name == "postgresql (my_postgres)" - # Test with StarRocks engine - mock_sr_engine = MagicMock() - mock_sr_engine.dialect.name = "starrocks" - sr_engine = StarRocksEngine( - mock_sr_engine, engine_name=VariableName("my_sr") - ) - connection = engine_to_data_source_connection( - VariableName("my_sr"), sr_engine - ) - assert isinstance(connection, DataSourceConnection) - assert connection.source == "starrocks" - assert connection.dialect == "starrocks" - assert connection.name == "my_sr" - assert connection.display_name == "starrocks (my_sr)" - # Test with Ibis engine var_name = "my_ibis" backend_name = "duckdb" @@ -492,23 +475,3 @@ def test_variables_without_datasource_engine() -> None: engines = get_engines_from_variables(variables) assert not engines - -@pytest.mark.skipif( - not (HAS_SQLALCHEMY and HAS_STARROCKS), - reason="SQLAlchemy and starrocks not installed", -) -def test_get_engines_starrocks() -> None: - import sqlalchemy as sa - import starrocks # noqa: F401 - - mock_engine = MagicMock(spec=sa.Engine) - mock_engine.dialect = MagicMock() - mock_engine.dialect.name = "starrocks" - variables: list[tuple[str, object]] = [("sr_engine", mock_engine)] - - engines = get_engines_from_variables(variables) - - assert len(engines) == 1 - var_name, engine = engines[0] - assert var_name == "sr_engine" - assert isinstance(engine, StarRocksEngine) diff --git a/tests/_sql/test_starrocks.py b/tests/_sql/test_starrocks.py deleted file mode 100644 index 8e413bdebe5..00000000000 --- a/tests/_sql/test_starrocks.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2026 Marimo. All rights reserved. -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from marimo._sql.engines.starrocks import ( - _SYSTEM_SCHEMAS, - StarRocksEngine, -) -from marimo._sql.sql_quoting import quote_sql_identifier - -# Skip the entire module when sqlalchemy is not installed. -pytestmark = pytest.mark.requires("sqlalchemy") - - -def _make_mock_engine(dialect_name: str = "starrocks") -> MagicMock: - """Return a mock SQLAlchemy Engine with the given dialect name.""" - mock_engine = MagicMock() - mock_engine.dialect.name = dialect_name - return mock_engine - - -def _make_engine(dialect_name: str = "starrocks") -> StarRocksEngine: - return StarRocksEngine(_make_mock_engine(dialect_name), engine_name="sr") - - -def _mock_connection_ctx(engine: StarRocksEngine, side_effects: list[Any]): - """Patch _connection.connect() so that successive execute() calls return - the given side_effects in order (each item is the rows list for one call). - """ - conn_ctx = MagicMock() - conn = MagicMock() - conn_ctx.__enter__ = MagicMock(return_value=conn) - conn_ctx.__exit__ = MagicMock(return_value=False) - engine._connection.connect = MagicMock(return_value=conn_ctx) - - results = [] - for rows in side_effects: - result = MagicMock() - result.fetchone = MagicMock(return_value=rows[0] if rows else None) - result.fetchall = MagicMock(return_value=rows) - results.append(result) - - conn.execute = MagicMock(side_effect=results) - return conn - - -class TestIsCompatible: - @pytest.mark.requires("sqlalchemy", "starrocks") - def test_compatible_with_starrocks_dialect(self) -> None: - import sqlalchemy as sa - import starrocks # noqa: F401 - - mock_engine = MagicMock(spec=sa.Engine) - mock_engine.dialect = MagicMock() - mock_engine.dialect.name = "starrocks" - assert StarRocksEngine.is_compatible(mock_engine) - - @pytest.mark.requires("sqlalchemy", "starrocks") - def test_not_compatible_with_other_dialects(self) -> None: - import sqlalchemy as sa - - for dialect in ("mysql", "postgresql", "sqlite", "clickhouse"): - mock_engine = MagicMock(spec=sa.Engine) - mock_engine.dialect = MagicMock() - mock_engine.dialect.name = dialect - assert not StarRocksEngine.is_compatible(mock_engine) - - @pytest.mark.requires("sqlalchemy", "starrocks") - def test_not_compatible_with_non_engine(self) -> None: - assert not StarRocksEngine.is_compatible("not_an_engine") - assert not StarRocksEngine.is_compatible(42) - assert not StarRocksEngine.is_compatible(None) - - -class TestSourceAndDialect: - def test_source(self) -> None: - engine = _make_engine() - assert engine.source == "starrocks" - - def test_dialect(self) -> None: - engine = _make_engine() - assert engine.dialect == "starrocks" - - -class TestDefaults: - def test_get_default_database(self) -> None: - engine = _make_engine() - _mock_connection_ctx(engine, [[("default_catalog",)]]) - assert engine.get_default_database() == "default_catalog" - - def test_get_default_database_none_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("connection failed") - assert engine.get_default_database() is None - - def test_get_default_schema(self) -> None: - # get_default_schema() is inherited from SQLAlchemyEngine and tries - # inspector.default_schema_name first. - engine = _make_engine() - engine.inspector = MagicMock() - engine.inspector.default_schema_name = "my_db" - assert engine.get_default_schema() == "my_db" - - def test_get_default_schema_none_on_error(self) -> None: - engine = _make_engine() - engine.inspector = None - engine._connection.connect.side_effect = Exception("connection failed") - assert engine.get_default_schema() is None - - -class TestListCatalogs: - def test_lists_all_catalogs(self) -> None: - engine = _make_engine() - rows = [ - ("default_catalog",), - ("hive_catalog",), - ("iceberg_catalog",), - ] - _mock_connection_ctx(engine, [rows]) - result = engine._list_catalogs() - assert result == ["default_catalog", "hive_catalog", "iceberg_catalog"] - - def test_returns_empty_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("oops") - assert engine._list_catalogs() == [] - - -class TestExternalSchemas: - def test_lists_databases_excluding_system(self) -> None: - engine = _make_engine() - rows = [ - ("tpch",), - ("analytics",), - ("information_schema",), # excluded - ("sys",), # excluded - ("_statistics_",), # excluded - ] - _mock_connection_ctx(engine, [rows]) - schemas = engine._get_external_schemas( - catalog="hive_catalog", - include_tables=False, - include_table_details=False, - ) - assert [s.name for s in schemas] == ["tpch", "analytics"] - - def test_returns_empty_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("oops") - assert ( - engine._get_external_schemas( - catalog="hive_catalog", - include_tables=False, - include_table_details=False, - ) - == [] - ) - - -class TestGetDatabases: - def test_returns_catalogs_with_empty_schemas(self) -> None: - """get_databases() lists catalogs only; schemas are fetched lazily.""" - engine = _make_engine() - catalogs_rows = [("default_catalog",), ("hive_catalog",)] - _mock_connection_ctx(engine, [catalogs_rows]) - - databases = engine.get_databases( - include_schemas=True, - include_tables=False, - include_table_details=False, - ) - - assert len(databases) == 2 - assert databases[0].name == "default_catalog" - assert databases[1].name == "hive_catalog" - # Schemas are always empty — lazy loading handles them - for db in databases: - assert db.schemas == [] - assert db.dialect == "starrocks" - assert db.engine == "sr" - - def test_returns_empty_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("oops") - databases = engine.get_databases( - include_schemas=False, - include_tables=False, - include_table_details=False, - ) - assert databases == [] - - -class TestGetSchemas: - def test_external_catalog_returns_schemas(self) -> None: - """get_schemas() for an external catalog uses SHOW DATABASES.""" - engine = _make_engine() - rows = [ - ("tpch",), - ("analytics",), - ("information_schema",), # excluded - ("sys",), # excluded - ] - _mock_connection_ctx(engine, [rows]) - - schemas = engine.get_schemas( - database="hive_catalog", - include_tables=False, - include_table_details=False, - ) - assert [s.name for s in schemas] == ["tpch", "analytics"] - - def test_returns_empty_for_none_database(self) -> None: - engine = _make_engine() - schemas = engine.get_schemas( - database=None, - include_tables=False, - include_table_details=False, - ) - assert schemas == [] - - def test_returns_empty_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("oops") - schemas = engine.get_schemas( - database="hive_catalog", - include_tables=False, - include_table_details=False, - ) - assert schemas == [] - - -class TestGetTablesInSchema: - def test_returns_tables_and_views(self) -> None: - engine = _make_engine() - # SHOW FULL TABLES returns (Tables_in_, Table_type) - rows = [ - ("orders", "BASE TABLE"), - ("lineitem", "BASE TABLE"), - ("revenue_view", "VIEW"), - ] - _mock_connection_ctx(engine, [rows]) - - tables = engine.get_tables_in_schema( - schema="tpch", - database="default_catalog", - include_table_details=False, - ) - - assert len(tables) == 3 - names = [t.name for t in tables] - assert "orders" in names - assert "lineitem" in names - assert "revenue_view" in names - view = next(t for t in tables if t.name == "revenue_view") - assert view.type == "view" - base = next(t for t in tables if t.name == "orders") - assert base.type == "table" - # No columns without details - assert base.columns == [] - - def test_returns_empty_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("fail") - result = engine.get_tables_in_schema( - schema="tpch", - database="default_catalog", - include_table_details=False, - ) - assert result == [] - - -class TestGetTableDetails: - def test_returns_columns(self) -> None: - engine = _make_engine() - # DESC output: Field, Type, Null, Key, Default, Extra, Comment - rows = [ - ("id", "INT", "YES", "", None, "", ""), - ("name", "VARCHAR(255)", "YES", "", None, "", ""), - ("created_at", "DATETIME", "YES", "", None, "", ""), - ("score", "DOUBLE", "YES", "", None, "", ""), - ("is_active", "BOOLEAN", "YES", "", None, "", ""), - ] - _mock_connection_ctx(engine, [rows]) - - table = engine.get_table_details( - table_name="orders", - schema_name="tpch", - database_name="default_catalog", - ) - - assert table is not None - assert table.name == "orders" - assert table.num_columns == 5 - assert len(table.columns) == 5 - - types = {c.name: c.type for c in table.columns} - assert types["id"] == "integer" - assert types["name"] == "string" - assert types["created_at"] == "datetime" - assert types["score"] == "number" - assert types["is_active"] == "boolean" - - def test_returns_none_on_error(self) -> None: - engine = _make_engine() - engine._connection.connect.side_effect = Exception("fail") - result = engine.get_table_details( - table_name="orders", - schema_name="tpch", - database_name="default_catalog", - ) - assert result is None - - -class TestStarRocksQuoting: - def test_starrocks_uses_backtick_style(self) -> None: - assert ( - quote_sql_identifier("my_catalog", dialect="starrocks") - == "`my_catalog`" - ) - assert ( - quote_sql_identifier("catalog`with`ticks", dialect="starrocks") - == "`catalog``with``ticks`" - ) - assert ( - quote_sql_identifier("catalog with spaces", dialect="starrocks") - == "`catalog with spaces`" - ) - - -class TestSystemConstants: - def test_system_schemas_excluded(self) -> None: - assert "information_schema" in _SYSTEM_SCHEMAS - assert "sys" in _SYSTEM_SCHEMAS - assert "_statistics_" in _SYSTEM_SCHEMAS From df32d7846bd54e77bff199b7f931c480878c18a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Apr 2026 19:31:29 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- marimo/_sql/engines/sqlalchemy.py | 2 +- tests/_sql/test_get_engines.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/marimo/_sql/engines/sqlalchemy.py b/marimo/_sql/engines/sqlalchemy.py index 5cc1216db74..ac2b6c9e1cc 100644 --- a/marimo/_sql/engines/sqlalchemy.py +++ b/marimo/_sql/engines/sqlalchemy.py @@ -124,7 +124,7 @@ def _quote_identifier(self, identifier: str) -> str: """Quote an identifier based on the SQL dialect's quoting rules.""" dialect_quoting: dict[str, tuple[re.Pattern[str], str, str]] = { "snowflake": (_SNOWFLAKE_NEEDS_QUOTING_RE, '"', '"'), - "starrocks": (_SNOWFLAKE_NEEDS_QUOTING_RE, '`', '`'), + "starrocks": (_SNOWFLAKE_NEEDS_QUOTING_RE, "`", "`"), } if self.dialect not in dialect_quoting: diff --git a/tests/_sql/test_get_engines.py b/tests/_sql/test_get_engines.py index f5a84a191a1..62c580d0e2a 100644 --- a/tests/_sql/test_get_engines.py +++ b/tests/_sql/test_get_engines.py @@ -474,4 +474,3 @@ def test_variables_without_datasource_engine() -> None: variables = [("deferred_for_test", deferred_for_test)] engines = get_engines_from_variables(variables) assert not engines - From 2884c497cd5b177c7b347b3e6b74cc79ce11bf10 Mon Sep 17 00:00:00 2001 From: Christopher Sapinski Date: Wed, 1 Apr 2026 12:41:14 -0700 Subject: [PATCH 11/11] remove import --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4bba54b51bf..e829fb3f64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ dependencies = [ "msgspec>=0.20.0", # for IPC, required for marimo edit --sandbox, marimo run "pyzmq>=27.1.0; python_version < '3.15'", - "sqlalchemy>=2.0.48", ] readme = "README.md" license = { file = "LICENSE" }