Skip to content

Commit f15c942

Browse files
authored
feat: add database= parameter to sql engine and editor for managed database scoping (#8)
* feat: add database= parameter to sql engine and editor for managed database scoping Pass database= to client.execute_sql() so queries are scoped to a managed database via the X-Database-Id header (hotdata-runtime>=0.2.1). - HotdataMarimoEngine: add default_database= constructor param, pass to execute() - SqlEditor: add database= constructor param, pass to both execute_sql calls - ManagedDatabaseWriter: use description= kwarg matching ManagedDatabase v0.2.0 API - Fix test_databases_marimo.py syntax error and update assertions * refactor: eliminate flag-based side-effect tracking, fix unregister, remove dead code - table_browser: extract _set_table_pick() replacing duplicate _init/_rebuild methods; _sync_table_catalog returns bool so ui drops _rebuilt_table_pick_this_run flag; standardize _active_connection_id to use 'or None' consistently - sql_engine: unregister now restores original engine_to_data_source_connection and resets sentinel so register/unregister/register round-trip works correctly - sql_editor: remove dead 'or ""' on _cached_sql (already guarded by None check above) - workspace_selector: cache HotdataClient, only reconstruct when workspace_id changes * fix: pass dropdown label key (not value) to mo.ui.dropdown value= init param When options is a dict {label: value}, Marimo validates value= against the dict keys (labels), not the values. _rebuild_database_pick was passing a database ID (dict value) which raised ValueError on startup. Now resolves the label key corresponding to the previously-selected ID instead. --------- Co-authored-by: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com>
1 parent 52c525c commit f15c942

8 files changed

Lines changed: 85 additions & 84 deletions

File tree

hotdata_marimo/databases.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def databases_panel(client: HotdataClient):
4747
gap=1,
4848
)
4949
rows: list[dict[str, object]] = [
50-
{"name": db.name, "id": db.id, "sql_prefix": f"{db.name}.{{schema}}.{{table}}"}
50+
{"description": db.description or db.id, "id": db.id, "sql_prefix": f"{db.id}.{{schema}}.{{table}}"}
5151
for db in dbs
5252
]
5353
return mo.vstack(
@@ -127,13 +127,16 @@ def _rebuild_database_pick(self) -> None:
127127
message="(create one first)",
128128
)
129129
return
130-
options = {db.name: db.name for db in dbs}
131-
value = current if current in options else next(iter(options))
130+
options = {db.description or db.id: db.id for db in dbs}
131+
# current holds the previously selected database ID (.value returns the dict value).
132+
# mo.ui.dropdown validates value= against option keys (labels), not values.
133+
default_key = next(iter(options))
134+
selected_key = next((k for k, v in options.items() if v == current), default_key)
132135
self.database = mo.ui.dropdown(
133136
options=options,
134137
label="Database",
135138
full_width=True,
136-
value=value,
139+
value=selected_key,
137140
)
138141

139142
def _maybe_create(self) -> None:
@@ -153,7 +156,7 @@ def _maybe_create(self) -> None:
153156
tables = _parse_table_names(self.tables.value)
154157
try:
155158
self._create_result = self._client.create_managed_database(
156-
db_name,
159+
description=db_name,
157160
schema=schema,
158161
tables=tables or None,
159162
)
@@ -209,7 +212,7 @@ def result_panel(self):
209212
db = self._create_result
210213
return mo.callout(
211214
mo.md(
212-
f"Created **{db.name}** (`{db.id}`). "
215+
f"Created **{db.description or db.id}** (`{db.id}`). "
213216
"Load parquet into a declared table below."
214217
),
215218
kind="success",

hotdata_marimo/sql_editor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ def __init__(
2020
default_sql: str = "",
2121
label: str = "SQL",
2222
run_label: str = "Run on Hotdata",
23+
database: str | None = None,
2324
) -> None:
2425
self._client = client
26+
self._database = database
2527
self.sql = mo.ui.text_area(default_sql, label=label)
2628
self.run = mo.ui.button(
2729
value=0,
@@ -103,7 +105,7 @@ def _execute_or_cached(self) -> QueryResult | None:
103105
title="Running on Hotdata",
104106
subtitle="Re-running last query and waiting for results…",
105107
):
106-
result = self._client.execute_sql(self._cached_sql or "")
108+
result = self._client.execute_sql(self._cached_sql, database=self._database)
107109
self._result_cache = result
108110
self._last_rerun_n = rerun_n
109111
return result
@@ -113,7 +115,7 @@ def _execute_or_cached(self) -> QueryResult | None:
113115
title="Running on Hotdata",
114116
subtitle="Executing query and waiting for results…",
115117
):
116-
result = self._client.execute_sql(sql_text)
118+
result = self._client.execute_sql(sql_text, database=self._database)
117119
self._result_cache = result
118120
self._cached_sql = sql_text
119121
self._last_run_n = run_n
@@ -195,7 +197,8 @@ def sql_editor(
195197
default_sql: str = "",
196198
label: str = "SQL",
197199
run_label: str = "Run on Hotdata",
200+
database: str | None = None,
198201
) -> SqlEditor:
199202
return SqlEditor(
200-
client, default_sql=default_sql, label=label, run_label=run_label
203+
client, default_sql=default_sql, label=label, run_label=run_label, database=database
201204
)

hotdata_marimo/sql_engine.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ def __init__(
3737
self,
3838
connection: HotdataClient,
3939
engine_name: VariableName | None = None,
40+
*,
41+
default_database: str | None = None,
4042
) -> None:
4143
super().__init__(connection, engine_name)
4244
self._connections_cache: list[Any] | None = None
45+
self._default_database = default_database
4346

4447
@property
4548
def source(self) -> str:
@@ -291,7 +294,7 @@ def get_table_details(
291294
)
292295

293296
def execute(self, query: str) -> Any:
294-
qr = self._connection.execute_sql(query)
297+
qr = self._connection.execute_sql(query, database=self._default_database)
295298
fmt = self.sql_output_format()
296299

297300
def to_polars() -> Any:
@@ -365,7 +368,11 @@ def register_hotdata_sql_engine() -> None:
365368

366369
def unregister_hotdata_sql_engine() -> None:
367370
"""Remove :class:`HotdataMarimoEngine` from Marimo's registry (mostly for tests)."""
371+
global _ORIGINAL_ENGINE_TO_CONNECTION
368372
from marimo._sql.get_engines import SUPPORTED_ENGINES
369373

370374
while HotdataMarimoEngine in SUPPORTED_ENGINES:
371375
SUPPORTED_ENGINES.remove(HotdataMarimoEngine)
376+
if _ORIGINAL_ENGINE_TO_CONNECTION is not None:
377+
_set_engine_to_data_source_connection(_ORIGINAL_ENGINE_TO_CONNECTION)
378+
_ORIGINAL_ENGINE_TO_CONNECTION = None

hotdata_marimo/table_browser.py

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
resolve_connection_picker,
1313
)
1414

15+
__all__ = ["TableBrowser", "connection_picker", "table_browser"]
16+
1517

1618
class TableBrowser:
1719
"""Pick a fully qualified `connection.schema.table` and inspect columns.
@@ -43,58 +45,26 @@ def __init__(
4345
)
4446

4547
self._table_pick_ctx: str | None = None
46-
self._rebuilt_table_pick_this_run = False
4748
self._init_table_pick()
4849

49-
def _init_table_pick(self) -> None:
50-
if self._conn_pick is not None:
51-
self.table_pick = empty_dropdown(
52-
label="Table",
53-
message="(select connection above)",
54-
)
55-
self._empty_catalog = True
56-
self._all_names = []
57-
self._table_pick_ctx = ""
58-
return
59-
60-
names = self._names_for_active_connection()
61-
self._all_names = names
62-
if not names:
63-
self.table_pick = empty_dropdown(
64-
label="Table",
65-
message="(no tables in catalog)",
66-
)
67-
self._empty_catalog = True
68-
else:
69-
self._empty_catalog = False
70-
self.table_pick = mo.ui.dropdown(
71-
options={n: n for n in names},
72-
label="Table",
73-
full_width=True,
74-
searchable=True,
75-
)
76-
self._table_pick_ctx = self._active_connection_id()
77-
7850
def _active_connection_id(self) -> str | None:
7951
if self._override_connection_id is not None:
8052
return self._override_connection_id or None
8153
if self._conn_pick is not None:
82-
v = self._conn_pick.value # type: ignore[attr-defined]
83-
return v if v else None
84-
if self._implicit_connection_id is None:
85-
return None
54+
return self._conn_pick.value or None # type: ignore[attr-defined]
8655
return self._implicit_connection_id or None
8756

8857
def _names_for_active_connection(self) -> list[str]:
8958
cid = self._active_connection_id()
90-
if cid is None or cid == "":
59+
if not cid:
9160
return []
9261
return self._client.list_qualified_table_names(
9362
limit=self._table_limit,
9463
connection_id=cid,
9564
)
9665

97-
def _rebuild_table_pick(self, names: list[str]) -> None:
66+
def _set_table_pick(self, names: list[str]) -> None:
67+
"""Create or replace the table dropdown for the given names list."""
9868
self._all_names = names
9969
if not names:
10070
self._empty_catalog = True
@@ -111,7 +81,32 @@ def _rebuild_table_pick(self, names: list[str]) -> None:
11181
searchable=True,
11282
)
11383
self._table_pick_ctx = self._active_connection_id()
114-
self._rebuilt_table_pick_this_run = True
84+
85+
def _init_table_pick(self) -> None:
86+
if self._conn_pick is not None:
87+
self._all_names = []
88+
self._empty_catalog = True
89+
self.table_pick = empty_dropdown(
90+
label="Table",
91+
message="(select connection above)",
92+
)
93+
self._table_pick_ctx = ""
94+
return
95+
self._set_table_pick(self._names_for_active_connection())
96+
97+
def _sync_table_catalog(self) -> bool:
98+
"""Refresh the table dropdown when the active connection changes.
99+
100+
Returns True if the dropdown was rebuilt (so the caller knows not to
101+
read ``.value`` on the new widget in the same Marimo run).
102+
"""
103+
if self._conn_pick is not None:
104+
_ = self._conn_pick.value # type: ignore[attr-defined]
105+
cid = self._active_connection_id()
106+
if not cid or cid == self._table_pick_ctx:
107+
return False
108+
self._set_table_pick(self._names_for_active_connection())
109+
return True
115110

116111
@property
117112
def selected_connection_id(self) -> str | None:
@@ -122,30 +117,17 @@ def selected_table(self) -> str | None:
122117
v = self.table_pick.value
123118
return v if v else None
124119

125-
def _sync_table_catalog(self) -> None:
126-
"""Refresh the table dropdown when the active connection changes."""
127-
if self._conn_pick is not None:
128-
_ = self._conn_pick.value # type: ignore[attr-defined]
129-
cid = self._active_connection_id()
130-
if not cid:
131-
return
132-
if cid == self._table_pick_ctx:
133-
return
134-
self._rebuild_table_pick(self._names_for_active_connection())
135-
136120
@property
137121
def ui(self):
138-
self._rebuilt_table_pick_this_run = False
139-
self._sync_table_catalog()
140-
141-
if not self._rebuilt_table_pick_this_run:
122+
rebuilt = self._sync_table_catalog()
123+
if not rebuilt:
142124
_ = self.table_pick.value
143125

144-
sel = None if self._rebuilt_table_pick_this_run else self.selected_table
126+
sel = None if rebuilt else self.selected_table
145127
cid = self._active_connection_id()
146128
conn_header = (
147-
mo.md(f"**Connection** `{self._active_connection_id()}`")
148-
if self._active_connection_id()
129+
mo.md(f"**Connection** `{cid}`")
130+
if cid
149131
else None
150132
)
151133
if not sel:

hotdata_marimo/workspace_selector.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def __init__(
2626
self._api_key = api_key
2727
self._host = host or default_host()
2828
self._session_id = session_id
29+
self._client_cache: HotdataClient | None = None
30+
self._client_cache_wid: str | None = None
2931
selection = resolve_workspace_selection(api_key, self._host, session_id)
3032
self._explicit = selection.source == "explicit_env"
3133
if self._explicit:
@@ -64,12 +66,16 @@ def workspace_id(self) -> str:
6466

6567
@property
6668
def client(self) -> HotdataClient:
67-
return HotdataClient(
68-
self._api_key,
69-
self.workspace_id,
70-
host=self._host,
71-
session_id=self._session_id,
72-
)
69+
wid = self.workspace_id
70+
if self._client_cache is None or self._client_cache_wid != wid:
71+
self._client_cache = HotdataClient(
72+
self._api_key,
73+
wid,
74+
host=self._host,
75+
session_id=self._session_id,
76+
)
77+
self._client_cache_wid = wid
78+
return self._client_cache
7379

7480
@property
7581
def ui(self):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010
requires-python = ">=3.10"
1111
license = { text = "MIT" }
1212
dependencies = [
13-
"hotdata-runtime>=0.1.1",
13+
"hotdata-runtime>=0.2.1",
1414
"hotdata>=0.2.0",
1515
"marimo>=0.10.0",
1616
]

tests/test_databases_marimo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_databases_panel_empty_state(mock_client):
1717

1818
def test_databases_panel_lists_managed_databases(mock_client):
1919
mock_client.list_managed_databases.return_value = [
20-
ManagedDatabase(id="c1", name="sales", source_type="managed"),
20+
ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"),
2121
]
2222
with patch("hotdata_marimo.databases.mo.vstack", return_value="panel"), patch(
2323
"hotdata_marimo.databases.mo.md", side_effect=lambda x: x
@@ -30,8 +30,8 @@ def test_managed_database_writer_creates_database(mock_client):
3030
mock_client.list_managed_databases.return_value = []
3131
mock_client.create_managed_database.return_value = ManagedDatabase(
3232
id="conn_new",
33-
name="sales",
34-
source_type="managed",
33+
description="sales",
34+
default_connection_id="conn_c1",
3535
)
3636
create = MagicMock()
3737
create.value = 1
@@ -71,7 +71,7 @@ def test_managed_database_writer_creates_database(mock_client):
7171
panel = writer.result_panel
7272

7373
mock_client.create_managed_database.assert_called_once_with(
74-
"sales",
74+
description="sales",
7575
schema="public",
7676
tables=["orders", "customers"],
7777
)
@@ -80,7 +80,7 @@ def test_managed_database_writer_creates_database(mock_client):
8080

8181
def test_managed_database_writer_loads_parquet(mock_client):
8282
mock_client.list_managed_databases.return_value = [
83-
ManagedDatabase(id="c1", name="sales", source_type="managed"),
83+
ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"),
8484
]
8585
mock_client.upload_parquet.return_value = "upl_1"
8686
mock_client.load_managed_table.return_value = LoadManagedTableResult(

uv.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)