Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions hotdata_marimo/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def databases_panel(client: HotdataClient):
gap=1,
)
rows: list[dict[str, object]] = [
{"name": db.name, "id": db.id, "sql_prefix": f"{db.name}.{{schema}}.{{table}}"}
{"description": db.description or db.id, "id": db.id, "sql_prefix": f"{db.id}.{{schema}}.{{table}}"}
for db in dbs
]
return mo.vstack(
Expand Down Expand Up @@ -127,13 +127,16 @@ def _rebuild_database_pick(self) -> None:
message="(create one first)",
)
return
options = {db.name: db.name for db in dbs}
value = current if current in options else next(iter(options))
options = {db.description or db.id: db.id for db in dbs}
# current holds the previously selected database ID (.value returns the dict value).
# mo.ui.dropdown validates value= against option keys (labels), not values.
default_key = next(iter(options))
selected_key = next((k for k, v in options.items() if v == current), default_key)
self.database = mo.ui.dropdown(
options=options,
label="Database",
full_width=True,
value=value,
value=selected_key,
)

def _maybe_create(self) -> None:
Expand All @@ -153,7 +156,7 @@ def _maybe_create(self) -> None:
tables = _parse_table_names(self.tables.value)
try:
self._create_result = self._client.create_managed_database(
db_name,
description=db_name,
schema=schema,
tables=tables or None,
)
Expand Down Expand Up @@ -209,7 +212,7 @@ def result_panel(self):
db = self._create_result
return mo.callout(
mo.md(
f"Created **{db.name}** (`{db.id}`). "
f"Created **{db.description or db.id}** (`{db.id}`). "
"Load parquet into a declared table below."
),
kind="success",
Expand Down
9 changes: 6 additions & 3 deletions hotdata_marimo/sql_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def __init__(
default_sql: str = "",
label: str = "SQL",
run_label: str = "Run on Hotdata",
database: str | None = None,
) -> None:
self._client = client
self._database = database
self.sql = mo.ui.text_area(default_sql, label=label)
self.run = mo.ui.button(
value=0,
Expand Down Expand Up @@ -103,7 +105,7 @@ def _execute_or_cached(self) -> QueryResult | None:
title="Running on Hotdata",
subtitle="Re-running last query and waiting for results…",
):
result = self._client.execute_sql(self._cached_sql or "")
result = self._client.execute_sql(self._cached_sql, database=self._database)
self._result_cache = result
self._last_rerun_n = rerun_n
return result
Expand All @@ -113,7 +115,7 @@ def _execute_or_cached(self) -> QueryResult | None:
title="Running on Hotdata",
subtitle="Executing query and waiting for results…",
):
result = self._client.execute_sql(sql_text)
result = self._client.execute_sql(sql_text, database=self._database)
self._result_cache = result
self._cached_sql = sql_text
self._last_run_n = run_n
Expand Down Expand Up @@ -195,7 +197,8 @@ def sql_editor(
default_sql: str = "",
label: str = "SQL",
run_label: str = "Run on Hotdata",
database: str | None = None,
) -> SqlEditor:
return SqlEditor(
client, default_sql=default_sql, label=label, run_label=run_label
client, default_sql=default_sql, label=label, run_label=run_label, database=database
)
9 changes: 8 additions & 1 deletion hotdata_marimo/sql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ def __init__(
self,
connection: HotdataClient,
engine_name: VariableName | None = None,
*,
default_database: str | None = None,
) -> None:
super().__init__(connection, engine_name)
self._connections_cache: list[Any] | None = None
self._default_database = default_database

@property
def source(self) -> str:
Expand Down Expand Up @@ -291,7 +294,7 @@ def get_table_details(
)

def execute(self, query: str) -> Any:
qr = self._connection.execute_sql(query)
qr = self._connection.execute_sql(query, database=self._default_database)
fmt = self.sql_output_format()

def to_polars() -> Any:
Expand Down Expand Up @@ -365,7 +368,11 @@ def register_hotdata_sql_engine() -> None:

def unregister_hotdata_sql_engine() -> None:
"""Remove :class:`HotdataMarimoEngine` from Marimo's registry (mostly for tests)."""
global _ORIGINAL_ENGINE_TO_CONNECTION
from marimo._sql.get_engines import SUPPORTED_ENGINES

while HotdataMarimoEngine in SUPPORTED_ENGINES:
SUPPORTED_ENGINES.remove(HotdataMarimoEngine)
if _ORIGINAL_ENGINE_TO_CONNECTION is not None:
_set_engine_to_data_source_connection(_ORIGINAL_ENGINE_TO_CONNECTION)
_ORIGINAL_ENGINE_TO_CONNECTION = None
92 changes: 37 additions & 55 deletions hotdata_marimo/table_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
resolve_connection_picker,
)

__all__ = ["TableBrowser", "connection_picker", "table_browser"]


class TableBrowser:
"""Pick a fully qualified `connection.schema.table` and inspect columns.
Expand Down Expand Up @@ -43,58 +45,26 @@ def __init__(
)

self._table_pick_ctx: str | None = None
self._rebuilt_table_pick_this_run = False
self._init_table_pick()

def _init_table_pick(self) -> None:
if self._conn_pick is not None:
self.table_pick = empty_dropdown(
label="Table",
message="(select connection above)",
)
self._empty_catalog = True
self._all_names = []
self._table_pick_ctx = ""
return

names = self._names_for_active_connection()
self._all_names = names
if not names:
self.table_pick = empty_dropdown(
label="Table",
message="(no tables in catalog)",
)
self._empty_catalog = True
else:
self._empty_catalog = False
self.table_pick = mo.ui.dropdown(
options={n: n for n in names},
label="Table",
full_width=True,
searchable=True,
)
self._table_pick_ctx = self._active_connection_id()

def _active_connection_id(self) -> str | None:
if self._override_connection_id is not None:
return self._override_connection_id or None
if self._conn_pick is not None:
v = self._conn_pick.value # type: ignore[attr-defined]
return v if v else None
if self._implicit_connection_id is None:
return None
return self._conn_pick.value or None # type: ignore[attr-defined]
return self._implicit_connection_id or None

def _names_for_active_connection(self) -> list[str]:
cid = self._active_connection_id()
if cid is None or cid == "":
if not cid:
return []
return self._client.list_qualified_table_names(
limit=self._table_limit,
connection_id=cid,
)

def _rebuild_table_pick(self, names: list[str]) -> None:
def _set_table_pick(self, names: list[str]) -> None:
"""Create or replace the table dropdown for the given names list."""
self._all_names = names
if not names:
self._empty_catalog = True
Expand All @@ -111,7 +81,32 @@ def _rebuild_table_pick(self, names: list[str]) -> None:
searchable=True,
)
self._table_pick_ctx = self._active_connection_id()
self._rebuilt_table_pick_this_run = True

def _init_table_pick(self) -> None:
if self._conn_pick is not None:
self._all_names = []
self._empty_catalog = True
self.table_pick = empty_dropdown(
label="Table",
message="(select connection above)",
)
self._table_pick_ctx = ""
return
self._set_table_pick(self._names_for_active_connection())

def _sync_table_catalog(self) -> bool:
"""Refresh the table dropdown when the active connection changes.

Returns True if the dropdown was rebuilt (so the caller knows not to
read ``.value`` on the new widget in the same Marimo run).
"""
if self._conn_pick is not None:
_ = self._conn_pick.value # type: ignore[attr-defined]
cid = self._active_connection_id()
if not cid or cid == self._table_pick_ctx:
return False
self._set_table_pick(self._names_for_active_connection())
return True

@property
def selected_connection_id(self) -> str | None:
Expand All @@ -122,30 +117,17 @@ def selected_table(self) -> str | None:
v = self.table_pick.value
return v if v else None

def _sync_table_catalog(self) -> None:
"""Refresh the table dropdown when the active connection changes."""
if self._conn_pick is not None:
_ = self._conn_pick.value # type: ignore[attr-defined]
cid = self._active_connection_id()
if not cid:
return
if cid == self._table_pick_ctx:
return
self._rebuild_table_pick(self._names_for_active_connection())

@property
def ui(self):
self._rebuilt_table_pick_this_run = False
self._sync_table_catalog()

if not self._rebuilt_table_pick_this_run:
rebuilt = self._sync_table_catalog()
if not rebuilt:
_ = self.table_pick.value

sel = None if self._rebuilt_table_pick_this_run else self.selected_table
sel = None if rebuilt else self.selected_table
cid = self._active_connection_id()
conn_header = (
mo.md(f"**Connection** `{self._active_connection_id()}`")
if self._active_connection_id()
mo.md(f"**Connection** `{cid}`")
if cid
else None
)
if not sel:
Expand Down
18 changes: 12 additions & 6 deletions hotdata_marimo/workspace_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(
self._api_key = api_key
self._host = host or default_host()
self._session_id = session_id
self._client_cache: HotdataClient | None = None
self._client_cache_wid: str | None = None
selection = resolve_workspace_selection(api_key, self._host, session_id)
self._explicit = selection.source == "explicit_env"
if self._explicit:
Expand Down Expand Up @@ -64,12 +66,16 @@ def workspace_id(self) -> str:

@property
def client(self) -> HotdataClient:
return HotdataClient(
self._api_key,
self.workspace_id,
host=self._host,
session_id=self._session_id,
)
wid = self.workspace_id
if self._client_cache is None or self._client_cache_wid != wid:
self._client_cache = HotdataClient(
self._api_key,
wid,
host=self._host,
session_id=self._session_id,
)
self._client_cache_wid = wid
return self._client_cache

@property
def ui(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.10"
license = { text = "MIT" }
dependencies = [
"hotdata-runtime>=0.1.1",
"hotdata-runtime>=0.2.1",
"hotdata>=0.2.0",
"marimo>=0.10.0",
]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_databases_marimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_databases_panel_empty_state(mock_client):

def test_databases_panel_lists_managed_databases(mock_client):
mock_client.list_managed_databases.return_value = [
ManagedDatabase(id="c1", name="sales", source_type="managed"),
ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"),
]
with patch("hotdata_marimo.databases.mo.vstack", return_value="panel"), patch(
"hotdata_marimo.databases.mo.md", side_effect=lambda x: x
Expand All @@ -30,8 +30,8 @@ def test_managed_database_writer_creates_database(mock_client):
mock_client.list_managed_databases.return_value = []
mock_client.create_managed_database.return_value = ManagedDatabase(
id="conn_new",
name="sales",
source_type="managed",
description="sales",
default_connection_id="conn_c1",
)
create = MagicMock()
create.value = 1
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_managed_database_writer_creates_database(mock_client):
panel = writer.result_panel

mock_client.create_managed_database.assert_called_once_with(
"sales",
description="sales",
schema="public",
tables=["orders", "customers"],
)
Expand All @@ -80,7 +80,7 @@ def test_managed_database_writer_creates_database(mock_client):

def test_managed_database_writer_loads_parquet(mock_client):
mock_client.list_managed_databases.return_value = [
ManagedDatabase(id="c1", name="sales", source_type="managed"),
ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"),
]
mock_client.upload_parquet.return_value = "upl_1"
mock_client.load_managed_table.return_value = LoadManagedTableResult(
Expand Down
14 changes: 7 additions & 7 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading