Skip to content

Commit ac68bbc

Browse files
committed
feat: add HotdataMarimoEngine for mo.sql
Register a Marimo SQL engine backed by HotdataClient, including catalog introspection for the Data Sources panel and a display-name patch so the connection shows as Hotdata in the UI.
1 parent 2183373 commit ac68bbc

3 files changed

Lines changed: 405 additions & 0 deletions

File tree

hotdata_marimo/sql_engine.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
"""Marimo ``mo.sql`` engine integration for :class:`~hotdata_runtime.HotdataClient`."""
2+
3+
from __future__ import annotations
4+
5+
from collections import defaultdict
6+
from typing import Any, Literal
7+
8+
from hotdata.models.table_info import TableInfo
9+
from hotdata_runtime import HotdataClient
10+
11+
from marimo import _loggers
12+
from marimo._data.models import (
13+
Database,
14+
DataSourceConnection,
15+
DataTable,
16+
DataTableColumn,
17+
DataType,
18+
Schema,
19+
)
20+
from marimo._sql.engines.types import InferenceConfig, SQLConnection
21+
from marimo._sql.utils import convert_to_output, sql_type_to_data_type
22+
from marimo._types.ids import VariableName
23+
24+
LOGGER = _loggers.marimo_logger()
25+
26+
27+
def _table_schema_name(t: TableInfo) -> str:
28+
return str(t.var_schema)
29+
30+
31+
class HotdataMarimoEngine(SQLConnection[HotdataClient]):
32+
"""Marimo :class:`~marimo._sql.engines.types.SQLConnection` backed by Hotdata.
33+
34+
Catalog methods support Marimo's Data Sources panel. ``execute()`` only runs SQL
35+
via :meth:`~hotdata_runtime.HotdataClient.execute_sql` (no catalog calls in that path).
36+
"""
37+
38+
@property
39+
def source(self) -> str:
40+
return "hotdata"
41+
42+
@property
43+
def dialect(self) -> str:
44+
# Marimo labels engines as ``{dialect} ({variable_name})``; display_name is patched to "Hotdata".
45+
return "hotdata"
46+
47+
@staticmethod
48+
def is_compatible(var: Any) -> bool:
49+
return isinstance(var, HotdataClient)
50+
51+
@property
52+
def inference_config(self) -> InferenceConfig:
53+
return InferenceConfig(
54+
auto_discover_schemas=True,
55+
auto_discover_tables="auto",
56+
auto_discover_columns="auto",
57+
)
58+
59+
def _resolve_should_auto_discover(
60+
self,
61+
value: bool | Literal["auto"],
62+
) -> bool:
63+
if value == "auto":
64+
return True
65+
return value
66+
67+
def _connection_ids(self) -> dict[str, str]:
68+
out: dict[str, str] = {}
69+
for c in self._connection.connections().list_connections().connections:
70+
out[str(c.name)] = str(c.id)
71+
return out
72+
73+
def _connection_id(self, connection_name: str) -> str | None:
74+
return self._connection_ids().get(connection_name)
75+
76+
def _iter_grouped(
77+
self,
78+
*,
79+
connection_id: str | None,
80+
include_columns: bool,
81+
) -> dict[str, dict[str, list[TableInfo]]]:
82+
grouped: dict[str, dict[str, list[TableInfo]]] = defaultdict(
83+
lambda: defaultdict(list)
84+
)
85+
for t in self._connection.iter_tables(
86+
connection_id=connection_id,
87+
include_columns=include_columns,
88+
):
89+
grouped[str(t.connection)][_table_schema_name(t)].append(t)
90+
return grouped
91+
92+
def get_default_database(self) -> str | None:
93+
listing = self._connection.connections().list_connections().connections
94+
if not listing:
95+
return None
96+
return str(listing[0].name)
97+
98+
def get_default_schema(self) -> str | None:
99+
return None
100+
101+
def get_databases(
102+
self,
103+
*,
104+
include_schemas: bool | Literal["auto"],
105+
include_tables: bool | Literal["auto"],
106+
include_table_details: bool | Literal["auto"],
107+
) -> list[Database]:
108+
databases: list[Database] = []
109+
for c in self._connection.connections().list_connections().connections:
110+
name = str(c.name)
111+
if self._resolve_should_auto_discover(include_schemas):
112+
schemas = self.get_schemas(
113+
database=name,
114+
include_tables=self._resolve_should_auto_discover(
115+
include_tables
116+
),
117+
include_table_details=self._resolve_should_auto_discover(
118+
include_table_details
119+
),
120+
)
121+
else:
122+
schemas = []
123+
databases.append(
124+
Database(
125+
name=name,
126+
dialect=self.dialect,
127+
schemas=schemas,
128+
engine=self._engine_name,
129+
)
130+
)
131+
return databases
132+
133+
def get_schemas(
134+
self,
135+
*,
136+
database: str | None,
137+
include_tables: bool,
138+
include_table_details: bool,
139+
) -> list[Schema]:
140+
if not database:
141+
return []
142+
conn_id = self._connection_id(database)
143+
if conn_id is None:
144+
LOGGER.warning("Unknown Hotdata connection name %r", database)
145+
return []
146+
grouped = self._iter_grouped(
147+
connection_id=conn_id,
148+
include_columns=include_table_details,
149+
)
150+
inner = grouped.get(database, {})
151+
schemas: list[Schema] = []
152+
for schema_name in sorted(inner.keys()):
153+
tables: list[DataTable] = []
154+
if include_tables:
155+
tables = self.get_tables_in_schema(
156+
schema=schema_name,
157+
database=database,
158+
include_table_details=include_table_details,
159+
)
160+
if not tables:
161+
continue
162+
schemas.append(Schema(name=schema_name, tables=tables))
163+
return schemas
164+
165+
def _data_table_from_table_info(self, t: TableInfo) -> DataTable:
166+
cols: list[DataTableColumn] = []
167+
for col in t.columns or []:
168+
cols.append(
169+
DataTableColumn(
170+
name=str(col.name),
171+
type=sql_type_to_data_type(str(col.data_type)),
172+
external_type=str(col.data_type),
173+
sample_values=[],
174+
)
175+
)
176+
return DataTable(
177+
source_type="connection",
178+
source=self.source,
179+
name=str(t.table),
180+
num_rows=None,
181+
num_columns=len(cols) if cols else None,
182+
variable_name=None,
183+
engine=self._engine_name,
184+
type="table",
185+
columns=cols,
186+
primary_keys=None,
187+
indexes=None,
188+
)
189+
190+
def get_tables_in_schema(
191+
self,
192+
*,
193+
schema: str,
194+
database: str,
195+
include_table_details: bool,
196+
) -> list[DataTable]:
197+
conn_id = self._connection_id(database)
198+
if conn_id is None:
199+
return []
200+
grouped = self._iter_grouped(
201+
connection_id=conn_id,
202+
include_columns=include_table_details,
203+
)
204+
tables_info = grouped.get(database, {}).get(schema, [])
205+
out: list[DataTable] = []
206+
for t in sorted(tables_info, key=lambda x: str(x.table)):
207+
if include_table_details:
208+
if t.columns:
209+
out.append(self._data_table_from_table_info(t))
210+
continue
211+
dt = self.get_table_details(
212+
table_name=str(t.table),
213+
schema_name=schema,
214+
database_name=database,
215+
)
216+
if dt is not None:
217+
out.append(dt)
218+
else:
219+
out.append(
220+
DataTable(
221+
source_type="connection",
222+
source=self.source,
223+
name=str(t.table),
224+
num_rows=None,
225+
num_columns=len(t.columns or []) if t.columns else None,
226+
variable_name=None,
227+
engine=self._engine_name,
228+
type="table",
229+
columns=[],
230+
primary_keys=None,
231+
indexes=None,
232+
)
233+
)
234+
return out
235+
236+
def get_table_details(
237+
self,
238+
*,
239+
table_name: str,
240+
schema_name: str,
241+
database_name: str,
242+
) -> DataTable | None:
243+
conn_id = self._connection_id(database_name)
244+
if conn_id is None:
245+
return None
246+
qualified = f"{database_name}.{schema_name}.{table_name}"
247+
try:
248+
cols_raw = self._connection.columns_for_qualified(
249+
qualified, connection_id=conn_id
250+
)
251+
except Exception:
252+
LOGGER.warning(
253+
"Failed to load columns for %s",
254+
qualified,
255+
exc_info=True,
256+
)
257+
return None
258+
cols: list[DataTableColumn] = []
259+
for col in cols_raw:
260+
cols.append(
261+
DataTableColumn(
262+
name=str(col.name),
263+
type=sql_type_to_data_type(str(col.data_type)),
264+
external_type=str(col.data_type),
265+
sample_values=[],
266+
)
267+
)
268+
return DataTable(
269+
source_type="connection",
270+
source=self.source,
271+
name=table_name,
272+
num_rows=None,
273+
num_columns=len(cols),
274+
variable_name=None,
275+
engine=self._engine_name,
276+
type="table",
277+
columns=cols,
278+
primary_keys=None,
279+
indexes=None,
280+
)
281+
282+
def execute(self, query: str) -> Any:
283+
qr = self._connection.execute_sql(query)
284+
fmt = self.sql_output_format()
285+
286+
def to_polars() -> Any:
287+
import polars as pl
288+
289+
if not qr.columns:
290+
return pl.DataFrame()
291+
return pl.DataFrame(qr.rows, schema=qr.columns, orient="row")
292+
293+
return convert_to_output(
294+
sql_output_format=fmt,
295+
to_polars=to_polars,
296+
to_pandas=qr.to_pandas,
297+
to_native=to_polars,
298+
)
299+
300+
301+
_HOTDATA_ENGINE_DISPLAY_NAME = "Hotdata"
302+
_ORIGINAL_ENGINE_TO_CONNECTION = None
303+
304+
305+
def _install_hotdata_engine_display_name() -> None:
306+
"""Show ``Hotdata`` in Marimo's SQL engine / Data Sources UI (not ``sql (client)``)."""
307+
global _ORIGINAL_ENGINE_TO_CONNECTION
308+
if _ORIGINAL_ENGINE_TO_CONNECTION is not None:
309+
return
310+
311+
import marimo._sql.get_engines as ge
312+
313+
_ORIGINAL_ENGINE_TO_CONNECTION = ge.engine_to_data_source_connection
314+
315+
def engine_to_data_source_connection(
316+
variable_name: VariableName, engine: object
317+
) -> DataSourceConnection:
318+
conn = _ORIGINAL_ENGINE_TO_CONNECTION(variable_name, engine) # type: ignore[arg-type]
319+
if not isinstance(engine, HotdataMarimoEngine):
320+
return conn
321+
return DataSourceConnection(
322+
source=conn.source,
323+
dialect=conn.dialect,
324+
name=conn.name,
325+
display_name=_HOTDATA_ENGINE_DISPLAY_NAME,
326+
databases=conn.databases,
327+
default_database=conn.default_database,
328+
default_schema=conn.default_schema,
329+
)
330+
331+
_set_engine_to_data_source_connection(engine_to_data_source_connection)
332+
333+
334+
def _set_engine_to_data_source_connection(fn: object) -> None:
335+
"""Marimo imports this helper in multiple modules; patch all bindings."""
336+
import marimo._runtime.runner.hooks_post_execution as hpe
337+
import marimo._runtime.runtime as rt
338+
import marimo._sql.get_engines as ge
339+
340+
ge.engine_to_data_source_connection = fn # type: ignore[assignment]
341+
hpe.engine_to_data_source_connection = fn # type: ignore[assignment]
342+
rt.engine_to_data_source_connection = fn # type: ignore[assignment]
343+
344+
345+
def register_hotdata_sql_engine() -> None:
346+
"""Register :class:`HotdataMarimoEngine` with Marimo's SQL engine registry (idempotent)."""
347+
_install_hotdata_engine_display_name()
348+
from marimo._sql.get_engines import SUPPORTED_ENGINES
349+
350+
if HotdataMarimoEngine in SUPPORTED_ENGINES:
351+
return
352+
SUPPORTED_ENGINES.insert(0, HotdataMarimoEngine)
353+
354+
355+
def unregister_hotdata_sql_engine() -> None:
356+
"""Remove :class:`HotdataMarimoEngine` from Marimo's registry (mostly for tests)."""
357+
from marimo._sql.get_engines import SUPPORTED_ENGINES
358+
359+
while HotdataMarimoEngine in SUPPORTED_ENGINES:
360+
SUPPORTED_ENGINES.remove(HotdataMarimoEngine)

0 commit comments

Comments
 (0)