Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions marimo/_dependencies/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class DependencyManager:
boto3 = Dependency("boto3")

redshift_connector = Dependency("redshift_connector")
starrocks = Dependency("starrocks")
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DependencyManager.starrocks is introduced but appears unused (no references in the repo). If StarRocks isn’t being explicitly dependency-gated anywhere, consider removing this entry; otherwise, add the corresponding require_many/feature checks so this dependency record has an effect.

Suggested change
starrocks = Dependency("starrocks")

Copilot uses AI. Check for mistakes.
mcp = Dependency("mcp")
pydantic_ai = Dependency(
"pydantic_ai", pkg_name_to_install="pydantic-ai-slim"
Expand Down
22 changes: 21 additions & 1 deletion marimo/_sql/engines/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _quote_identifier(self, identifier: str) -> str:
"""Quote an identifier based on the SQL dialect's quoting rules."""
dialect_quoting: dict[str, tuple[re.Pattern[str], str, str]] = {
"snowflake": (_SNOWFLAKE_NEEDS_QUOTING_RE, '"', '"'),
"starrocks": (_SNOWFLAKE_NEEDS_QUOTING_RE, "`", "`"),
}

if self.dialect not in dialect_quoting:
Expand Down Expand Up @@ -160,6 +161,7 @@ def _get_inspector(self, database: str) -> Iterator[Optional[Inspector]]:

_use_database_dialect_command: dict[str, str] = {
"snowflake": f"USE DATABASE {self._quote_identifier(database)}",
"starrocks": f"SET CATALOG {self._quote_identifier(database)}",
}
dialect_command = _use_database_dialect_command.get(self.dialect)

Expand Down Expand Up @@ -257,6 +259,7 @@ def get_default_database(self) -> Optional[str]:
"postgresql": "SELECT current_database()",
"mssql": "SELECT DB_NAME()",
"timeplus": "SELECT current_database()",
"starrocks": "SELECT CATALOG()",
}

# Try to get the database name by querying the database directly
Expand Down Expand Up @@ -350,6 +353,18 @@ def _get_snowflake_database_names(self) -> list[str]:

return database_names

def _get_starrocks_database_names(self) -> list[str]:
"""Get catalog names for StarRocks via 'SHOW CATALOGS'.

StarRocks uses a three-level hierarchy (Catalog → Database → Table)
which maps to marimo's (Database → Schema → Table).
"""
Comment on lines +356 to +361
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description says it “introduces a new engine which extends SQLAlchemyEngine”, but there doesn’t appear to be a StarRocksEngine implementation in the codebase; instead the StarRocks behavior is implemented directly in SQLAlchemyEngine. Please either update the PR description to match the implementation or add the intended engine subclass + registration (if that’s still the plan).

Copilot uses AI. Check for mistakes.
from sqlalchemy import text

with self._connection.connect() as connection:
result = connection.execute(text("SHOW CATALOGS"))
return [str(row[0]) for row in result.fetchall()]

Comment on lines +366 to +367
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_starrocks_database_names() assumes the catalog name is always in row[0]. To make this more robust across dialect/driver variations, consider selecting the column by name (using result.keys() like the Snowflake implementation) and raising a clear error if the expected column is missing.

Suggested change
return [str(row[0]) for row in result.fetchall()]
columns = list(result.keys())
# StarRocks generally returns a "Catalog" column for SHOW CATALOGS,
# but we defensively support a few plausible variants.
candidate_columns = ("Catalog", "catalog_name", "name")
name_col_index: Optional[int] = None
for col_name in candidate_columns:
if col_name in columns:
name_col_index = columns.index(col_name)
break
if name_col_index is None:
raise RuntimeError(
"Unexpected SHOW CATALOGS result: expected one of "
f"{candidate_columns!r} columns, but got {columns!r}"
)
return [str(row[name_col_index]) for row in result.fetchall()]

Copilot uses AI. Check for mistakes.
@safe_execute(
fallback=[],
message="Failed to get database names",
Expand All @@ -361,8 +376,11 @@ def _get_database_names(self) -> list[str]:
Returns a single-element list with the default database when
the dialect has no dedicated discovery mechanism.
"""
if self.dialect.lower() == "snowflake":
dialect = self.dialect.lower()
if dialect == "snowflake":
return self._get_snowflake_database_names()
if dialect == "starrocks":
return self._get_starrocks_database_names()
Comment on lines 356 to +383
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The StarRocks-specific paths (SET CATALOG in _get_inspector, SHOW CATALOGS discovery, and the starrocks branch in _get_database_names) aren’t covered by existing SQLAlchemyEngine tests. Since this file already has unit tests (including Snowflake mocking), consider adding a mocked StarRocks test to assert the emitted SQL and that the discovery branch is exercised.

Copilot uses AI. Check for mistakes.

return [self.default_database] if self.default_database else []

Expand Down Expand Up @@ -468,6 +486,8 @@ def _get_meta_schemas(self) -> list[str]:
dialect = self.dialect.lower()
if dialect == "postgresql":
return ["information_schema", "pg_catalog"]
if dialect == "starrocks":
return ["information_schema", "sys", "_statistics_"]
return ["information_schema"]

# -------------------------------------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion marimo/_sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def sql(
break
else:
raise ValueError(
"Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift or DBAPI 2.0 compatible engine."
"Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift, StarRocks or DBAPI 2.0 compatible engine."
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The updated error message lists StarRocks as a separate supported engine type, but StarRocks connections are still surfaced via the generic SQLAlchemy engine in SUPPORTED_ENGINES. Consider rewording to avoid implying a distinct engine class (e.g., “SQLAlchemy (including StarRocks dialect)” or similar).

Suggested change
"Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, DuckDB, Redshift, StarRocks or DBAPI 2.0 compatible engine."
"Unsupported engine. Must be a SQLAlchemy (including StarRocks dialect), Ibis, Clickhouse, DuckDB, Redshift, or DBAPI 2.0 compatible engine."

Copilot uses AI. Check for mistakes.
)

try:
Expand Down
4 changes: 2 additions & 2 deletions marimo/_sql/sql_quoting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str:
identifier: The raw identifier string (database, schema, or table name).
dialect: The SQL dialect.
Double-quote style: "duckdb", "redshift", "postgresql"/"postgres".
Backtick style: "clickhouse", "mysql", "bigquery".
Backtick style: "clickhouse", "mysql", "bigquery", "starrocks".
Unknown dialects return the identifier unquoted.

Returns:
Expand All @@ -22,7 +22,7 @@ def quote_sql_identifier(identifier: str, *, dialect: str = "duckdb") -> str:
# Double-quote style: escape embedded " as ""
escaped = identifier.replace('"', '""')
return f'"{escaped}"'
elif dialect in ("clickhouse", "mysql", "bigquery"):
elif dialect in ("clickhouse", "mysql", "bigquery", "starrocks"):
# Backtick style: escape embedded ` as ``
escaped = identifier.replace("`", "``")
return f"`{escaped}`"
Comment thread
chris-celerdata marked this conversation as resolved.
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ test-optional = [
"chdb>=3; platform_system != 'Windows'", # there is no suitable wheel for windows
"clickhouse-connect>=0.8.18",
"redshift-connector[full]>=2.1.7",
# For testing starrocks
"starrocks>=1.3.0",
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

starrocks>=1.3.0 is added to the test-optional dependency group, but the current test changes don’t import or require the StarRocks package (quoting tests are pure string logic). If there aren’t additional StarRocks tests elsewhere in the PR, consider removing this optional dependency until it’s needed, to keep the optional test environment lighter.

Suggested change
"starrocks>=1.3.0",

Copilot uses AI. Check for mistakes.
"pandas>=1.5.3",
"hvplot~=0.11.3",
"geopandas>=1.1.0",
Expand Down Expand Up @@ -440,6 +442,7 @@ banned-module-level-imports = [
"typing_extensions",
"pyiceberg",
"redshift_connector",
"starrocks",
"pydantic_ai",
# a top-level import of zmq may break WASM
"zmq"
Expand Down
30 changes: 30 additions & 0 deletions tests/_sql/test_sql_quoting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class TestQuoteSqlIdentifier:
("table", "bigquery", "`table`"),
("my table", "bigquery", "`my table`"),
("has`backtick", "bigquery", "`has``backtick`"),
# StarRocks uses backtick style
("table", "starrocks", "`table`"),
("my table", "starrocks", "`my table`"),
("nested.namespace", "starrocks", "`nested.namespace`"),
("has`backtick", "starrocks", "`has``backtick`"),
('has"quotes', "starrocks", '`has"quotes`'),
# Unknown dialect returns unquoted
("table", "sqlite", "table"),
("my table", "unknown", "my table"),
Expand Down Expand Up @@ -98,6 +104,24 @@ def test_clickhouse_roundtrip_safe(self, identifier: str) -> None:
inner = quoted[1:-1]
assert inner.replace("``", "`") == identifier

@pytest.mark.parametrize(
"identifier",
[
"simple",
"with spaces",
"with.dots",
"with`backticks",
'with"quotes',
],
)
def test_starrocks_roundtrip_safe(self, identifier: str) -> None:
"""Verify that quoting an identifier produces valid StarRocks syntax."""
quoted = quote_sql_identifier(identifier, dialect="starrocks")
assert quoted.startswith("`")
assert quoted.endswith("`")
inner = quoted[1:-1]
assert inner.replace("``", "`") == identifier


class TestQuoteQualifiedName:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -127,6 +151,12 @@ class TestQuoteQualifiedName:
"redshift",
'"catalog"."schema"."table"',
),
# StarRocks 3-part name (catalog.database.table)
(
("iceberg_catalog", "tpch", "orders"),
"starrocks",
"`iceberg_catalog`.`tpch`.`orders`",
),
# Single part
(
("just_table",),
Expand Down
Loading