55from collections import defaultdict
66from typing import Any , Literal
77
8- from hotdata .models .table_info import TableInfo
98from hotdata_runtime import HotdataClient
109
1110from marimo import _loggers
1413 DataSourceConnection ,
1514 DataTable ,
1615 DataTableColumn ,
17- DataType ,
1816 Schema ,
1917)
2018from marimo ._sql .engines .types import InferenceConfig , SQLConnection
2422LOGGER = _loggers .marimo_logger ()
2523
2624
27- def _table_schema_name (t : TableInfo ) -> str :
25+ def _table_schema_name (t : Any ) -> str :
2826 return str (t .var_schema )
2927
3028
@@ -35,6 +33,15 @@ class HotdataMarimoEngine(SQLConnection[HotdataClient]):
3533 via :meth:`~hotdata_runtime.HotdataClient.execute_sql` (no catalog calls in that path).
3634 """
3735
36+ def __init__ (
37+ self ,
38+ connection : HotdataClient ,
39+ engine_name : VariableName | None = None ,
40+ ) -> None :
41+ super ().__init__ (connection , engine_name )
42+ self ._connections_cache : list [Any ] | None = None
43+ self ._connection_id_cache : dict [str , str ] | None = None
44+
3845 @property
3946 def source (self ) -> str :
4047 return "hotdata"
@@ -65,21 +72,29 @@ def _resolve_should_auto_discover(
6572 return value
6673
6774 def _connection_ids (self ) -> dict [str , str ]:
68- out : dict [str , str ] = {}
69- for c in self ._connection .connections ().list_connections ().connections :
70- out [str (c .name )] = str (c .id )
71- return out
75+ if self ._connection_id_cache is None :
76+ self ._connection_id_cache = {
77+ str (c .name ): str (c .id ) for c in self ._connections ()
78+ }
79+ return self ._connection_id_cache
7280
7381 def _connection_id (self , connection_name : str ) -> str | None :
7482 return self ._connection_ids ().get (connection_name )
7583
84+ def _connections (self ) -> list [Any ]:
85+ if self ._connections_cache is None :
86+ self ._connections_cache = list (
87+ self ._connection .connections ().list_connections ().connections
88+ )
89+ return self ._connections_cache
90+
7691 def _iter_grouped (
7792 self ,
7893 * ,
7994 connection_id : str | None ,
8095 include_columns : bool ,
81- ) -> dict [str , dict [str , list [TableInfo ]]]:
82- grouped : dict [str , dict [str , list [TableInfo ]]] = defaultdict (
96+ ) -> dict [str , dict [str , list [Any ]]]:
97+ grouped : dict [str , dict [str , list [Any ]]] = defaultdict (
8398 lambda : defaultdict (list )
8499 )
85100 for t in self ._connection .iter_tables (
@@ -90,7 +105,7 @@ def _iter_grouped(
90105 return grouped
91106
92107 def get_default_database (self ) -> str | None :
93- listing = self ._connection . connections (). list_connections (). connections
108+ listing = self ._connections ()
94109 if not listing :
95110 return None
96111 return str (listing [0 ].name )
@@ -106,7 +121,7 @@ def get_databases(
106121 include_table_details : bool | Literal ["auto" ],
107122 ) -> list [Database ]:
108123 databases : list [Database ] = []
109- for c in self ._connection . connections (). list_connections (). connections :
124+ for c in self ._connections () :
110125 name = str (c .name )
111126 if self ._resolve_should_auto_discover (include_schemas ):
112127 schemas = self .get_schemas (
@@ -162,7 +177,7 @@ def get_schemas(
162177 schemas .append (Schema (name = schema_name , tables = tables ))
163178 return schemas
164179
165- def _data_table_from_table_info (self , t : TableInfo ) -> DataTable :
180+ def _data_table_from_table_info (self , t : Any ) -> DataTable :
166181 cols : list [DataTableColumn ] = []
167182 for col in t .columns or []:
168183 cols .append (
0 commit comments