Skip to content

Commit 77feaab

Browse files
committed
refactor: tighten Hotdata marimo SQL engine imports
Avoid importing SDK model types through transitive dependencies and cache connection listings during catalog discovery.
1 parent 367a9ff commit 77feaab

1 file changed

Lines changed: 27 additions & 12 deletions

File tree

hotdata_marimo/sql_engine.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections import defaultdict
66
from typing import Any, Literal
77

8-
from hotdata.models.table_info import TableInfo
98
from hotdata_runtime import HotdataClient
109

1110
from marimo import _loggers
@@ -14,7 +13,6 @@
1413
DataSourceConnection,
1514
DataTable,
1615
DataTableColumn,
17-
DataType,
1816
Schema,
1917
)
2018
from marimo._sql.engines.types import InferenceConfig, SQLConnection
@@ -24,7 +22,7 @@
2422
LOGGER = _loggers.marimo_logger()
2523

2624

27-
def _table_schema_name(t: TableInfo) -> str:
25+
def _table_schema_name(t: Any) -> str:
2826
return str(t.var_schema)
2927

3028

@@ -35,6 +33,15 @@ class HotdataMarimoEngine(SQLConnection[HotdataClient]):
3533
via :meth:`~hotdata_runtime.HotdataClient.execute_sql` (no catalog calls in that path).
3634
"""
3735

36+
def __init__(
37+
self,
38+
connection: HotdataClient,
39+
engine_name: VariableName | None = None,
40+
) -> None:
41+
super().__init__(connection, engine_name)
42+
self._connections_cache: list[Any] | None = None
43+
self._connection_id_cache: dict[str, str] | None = None
44+
3845
@property
3946
def source(self) -> str:
4047
return "hotdata"
@@ -65,21 +72,29 @@ def _resolve_should_auto_discover(
6572
return value
6673

6774
def _connection_ids(self) -> dict[str, str]:
68-
out: dict[str, str] = {}
69-
for c in self._connection.connections().list_connections().connections:
70-
out[str(c.name)] = str(c.id)
71-
return out
75+
if self._connection_id_cache is None:
76+
self._connection_id_cache = {
77+
str(c.name): str(c.id) for c in self._connections()
78+
}
79+
return self._connection_id_cache
7280

7381
def _connection_id(self, connection_name: str) -> str | None:
7482
return self._connection_ids().get(connection_name)
7583

84+
def _connections(self) -> list[Any]:
85+
if self._connections_cache is None:
86+
self._connections_cache = list(
87+
self._connection.connections().list_connections().connections
88+
)
89+
return self._connections_cache
90+
7691
def _iter_grouped(
7792
self,
7893
*,
7994
connection_id: str | None,
8095
include_columns: bool,
81-
) -> dict[str, dict[str, list[TableInfo]]]:
82-
grouped: dict[str, dict[str, list[TableInfo]]] = defaultdict(
96+
) -> dict[str, dict[str, list[Any]]]:
97+
grouped: dict[str, dict[str, list[Any]]] = defaultdict(
8398
lambda: defaultdict(list)
8499
)
85100
for t in self._connection.iter_tables(
@@ -90,7 +105,7 @@ def _iter_grouped(
90105
return grouped
91106

92107
def get_default_database(self) -> str | None:
93-
listing = self._connection.connections().list_connections().connections
108+
listing = self._connections()
94109
if not listing:
95110
return None
96111
return str(listing[0].name)
@@ -106,7 +121,7 @@ def get_databases(
106121
include_table_details: bool | Literal["auto"],
107122
) -> list[Database]:
108123
databases: list[Database] = []
109-
for c in self._connection.connections().list_connections().connections:
124+
for c in self._connections():
110125
name = str(c.name)
111126
if self._resolve_should_auto_discover(include_schemas):
112127
schemas = self.get_schemas(
@@ -162,7 +177,7 @@ def get_schemas(
162177
schemas.append(Schema(name=schema_name, tables=tables))
163178
return schemas
164179

165-
def _data_table_from_table_info(self, t: TableInfo) -> DataTable:
180+
def _data_table_from_table_info(self, t: Any) -> DataTable:
166181
cols: list[DataTableColumn] = []
167182
for col in t.columns or []:
168183
cols.append(

0 commit comments

Comments
 (0)