diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index 169733de5a3..d0e5f951b94 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -248,6 +248,7 @@ class DependencyManager: boto3 = Dependency("boto3") redshift_connector = Dependency("redshift_connector") + starrocks = Dependency("starrocks") mcp = Dependency("mcp") pydantic_ai = Dependency( "pydantic_ai", pkg_name_to_install="pydantic-ai-slim" diff --git a/marimo/_sql/engines/sqlalchemy.py b/marimo/_sql/engines/sqlalchemy.py index 638e7ab186f..ac2b6c9e1cc 100644 --- a/marimo/_sql/engines/sqlalchemy.py +++ b/marimo/_sql/engines/sqlalchemy.py @@ -124,6 +124,7 @@ def _quote_identifier(self, identifier: str) -> str: """Quote an identifier based on the SQL dialect's quoting rules.""" dialect_quoting: dict[str, tuple[re.Pattern[str], str, str]] = { "snowflake": (_SNOWFLAKE_NEEDS_QUOTING_RE, '"', '"'), + "starrocks": (_SNOWFLAKE_NEEDS_QUOTING_RE, "`", "`"), } if self.dialect not in dialect_quoting: @@ -160,6 +161,7 @@ def _get_inspector(self, database: str) -> Iterator[Optional[Inspector]]: _use_database_dialect_command: dict[str, str] = { "snowflake": f"USE DATABASE {self._quote_identifier(database)}", + "starrocks": f"SET CATALOG {self._quote_identifier(database)}", } dialect_command = _use_database_dialect_command.get(self.dialect) @@ -257,6 +259,7 @@ def get_default_database(self) -> Optional[str]: "postgresql": "SELECT current_database()", "mssql": "SELECT DB_NAME()", "timeplus": "SELECT current_database()", + "starrocks": "SELECT CATALOG()", } # Try to get the database name by querying the database directly @@ -350,6 +353,18 @@ def _get_snowflake_database_names(self) -> list[str]: return database_names + def _get_starrocks_database_names(self) -> list[str]: + """Get catalog names for StarRocks via 'SHOW CATALOGS'. + + StarRocks uses a three-level hierarchy (Catalog → Database → Table) + which maps to marimo's (Database → Schema → Table). + """ + from sqlalchemy import text + + with self._connection.connect() as connection: + result = connection.execute(text("SHOW CATALOGS")) + return [str(row[0]) for row in result.fetchall()] + @safe_execute( fallback=[], message="Failed to get database names", @@ -361,8 +376,11 @@ def _get_database_names(self) -> list[str]: Returns a single-element list with the default database when the dialect has no dedicated discovery mechanism. """ - if self.dialect.lower() == "snowflake": + dialect = self.dialect.lower() + if dialect == "snowflake": return self._get_snowflake_database_names() + if dialect == "starrocks": + return self._get_starrocks_database_names() return [self.default_database] if self.default_database else [] @@ -468,6 +486,8 @@ def _get_meta_schemas(self) -> list[str]: dialect = self.dialect.lower() if dialect == "postgresql": return ["information_schema", "pg_catalog"] + if dialect == "starrocks": + return ["information_schema", "sys", "_statistics_"] return ["information_schema"] # -------------------------------------------------------------- # diff --git a/marimo/_sql/sql.py b/marimo/_sql/sql.py index 6d8d3b284c3..19d74e68381 100644 --- a/marimo/_sql/sql.py +++ b/marimo/_sql/sql.py @@ -77,7 +77,7 @@ def sql( break else: raise ValueError( - "Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift or DBAPI 2.0 compatible engine." + "Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift, StarRocks or DBAPI 2.0 compatible engine." ) try: diff --git a/marimo/_sql/sql_quoting.py b/marimo/_sql/sql_quoting.py index 1d20977b5db..b9d952d0bde 100644 --- a/marimo/_sql/sql_quoting.py +++ b/marimo/_sql/sql_quoting.py @@ -12,7 +12,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str: identifier: The raw identifier string (database, schema, or table name). dialect: The SQL dialect. Double-quote style: "duckdb", "redshift", "postgresql"/"postgres". - Backtick style: "clickhouse", "mysql", "bigquery". + Backtick style: "clickhouse", "mysql", "bigquery", "starrocks". Unknown dialects return the identifier unquoted. Returns: @@ -22,7 +22,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str: # Double-quote style: escape embedded " as "" escaped = identifier.replace('"', '""') return f'"{escaped}"' - elif dialect in ("clickhouse", "mysql", "bigquery"): + elif dialect in ("clickhouse", "mysql", "bigquery", "starrocks"): # Backtick style: escape embedded ` as `` escaped = identifier.replace("`", "``") return f"`{escaped}`" diff --git a/pyproject.toml b/pyproject.toml index 2cd03b1f388..e829fb3f64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,8 @@ test-optional = [ "chdb>=3; platform_system != 'Windows'", # there is no suitable wheel for windows "clickhouse-connect>=0.8.18", "redshift-connector[full]>=2.1.7", + # For testing starrocks + "starrocks>=1.3.0", "pandas>=1.5.3", "hvplot~=0.11.3", "geopandas>=1.1.0", @@ -440,6 +442,7 @@ banned-module-level-imports = [ "typing_extensions", "pyiceberg", "redshift_connector", + "starrocks", "pydantic_ai", # a top-level import of zmq may break WASM "zmq" diff --git a/tests/_sql/test_sql_quoting.py b/tests/_sql/test_sql_quoting.py index a5be8661b01..7e60ac28be7 100644 --- a/tests/_sql/test_sql_quoting.py +++ b/tests/_sql/test_sql_quoting.py @@ -46,6 +46,12 @@ class TestQuoteSqlIdentifier: ("table", "bigquery", "`table`"), ("my table", "bigquery", "`my table`"), ("has`backtick", "bigquery", "`has``backtick`"), + # StarRocks uses backtick style + ("table", "starrocks", "`table`"), + ("my table", "starrocks", "`my table`"), + ("nested.namespace", "starrocks", "`nested.namespace`"), + ("has`backtick", "starrocks", "`has``backtick`"), + ('has"quotes', "starrocks", '`has"quotes`'), # Unknown dialect returns unquoted ("table", "sqlite", "table"), ("my table", "unknown", "my table"), @@ -98,6 +104,24 @@ def test_clickhouse_roundtrip_safe(self, identifier: str) -> None: inner = quoted[1:-1] assert inner.replace("``", "`") == identifier + @pytest.mark.parametrize( + "identifier", + [ + "simple", + "with spaces", + "with.dots", + "with`backticks", + 'with"quotes', + ], + ) + def test_starrocks_roundtrip_safe(self, identifier: str) -> None: + """Verify that quoting an identifier produces valid StarRocks syntax.""" + quoted = quote_sql_identifier(identifier, dialect="starrocks") + assert quoted.startswith("`") + assert quoted.endswith("`") + inner = quoted[1:-1] + assert inner.replace("``", "`") == identifier + class TestQuoteQualifiedName: @pytest.mark.parametrize( @@ -127,6 +151,12 @@ class TestQuoteQualifiedName: "redshift", '"catalog"."schema"."table"', ), + # StarRocks 3-part name (catalog.database.table) + ( + ("iceberg_catalog", "tpch", "orders"), + "starrocks", + "`iceberg_catalog`.`tpch`.`orders`", + ), # Single part ( ("just_table",),