Skip to content

Commit 26de166

Browse files
committed
feat: add database= parameter to execute_sql for managed database scoping
Pass database= to client.execute_sql() so queries are scoped to a managed database via the X-Database-Id header (hotdata-runtime>=0.2.1). Also updates ManagedDatabase constructor calls to use description= and default_connection_id= fields introduced in hotdata-runtime v0.2.0.
1 parent 1074c8b commit 26de166

5 files changed

Lines changed: 36 additions & 24 deletions

File tree

hotdata_llamaindex/databases.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
def list_managed_databases_json(client: HotdataClient) -> str:
1717
rows = [
1818
{
19-
"name": db.name,
19+
"description": db.description,
2020
"id": db.id,
21-
"sql_prefix": f"{db.name}.{{schema}}.{{table}}",
21+
"sql_prefix": f"{db.id}.{{schema}}.{{table}}",
2222
}
2323
for db in client.list_managed_databases()
2424
]
@@ -32,7 +32,7 @@ def create_managed_database(
3232
schema: str = DEFAULT_SCHEMA,
3333
tables: list[str] | None = None,
3434
) -> ManagedDatabase:
35-
return client.create_managed_database(name, schema=schema, tables=tables)
35+
return client.create_managed_database(description=name, schema=schema, tables=tables)
3636

3737

3838
def load_managed_table(
@@ -47,7 +47,7 @@ def load_managed_table(
4747

4848

4949
def managed_database_summary(db: ManagedDatabase) -> dict[str, str]:
50-
return {"id": db.id, "name": db.name, "source_type": db.source_type}
50+
return {"id": db.id, "description": db.description or db.id}
5151

5252

5353
def load_result_summary(result: LoadManagedTableResult) -> dict[str, Any]:

hotdata_llamaindex/tools.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,14 @@ def result_rows_for_llm(result: QueryResult, *, max_rows: int = 20) -> list[dict
2222
return result.to_records(max_rows=max_rows)
2323

2424

25-
def execute_sql_json(client: HotdataClient, sql: str, *, max_rows: int = 100) -> str:
26-
result = client.execute_sql(sql)
25+
def execute_sql_json(
26+
client: HotdataClient,
27+
sql: str,
28+
*,
29+
max_rows: int = 100,
30+
database: str | None = None,
31+
) -> str:
32+
result = client.execute_sql(sql, database=database)
2733
payload = {
2834
"metadata": result.metadata_dict(),
2935
"rows": result.to_records(max_rows=max_rows),
@@ -35,12 +41,13 @@ def make_hotdata_tools(
3541
client: HotdataClient,
3642
*,
3743
max_rows: int = 100,
44+
database: str | None = None,
3845
) -> list[FunctionTool]:
3946
"""Return LlamaIndex tools for SQL and managed database workflows."""
4047

4148
def hotdata_execute_sql(sql: str) -> str:
4249
"""Run SQL against the Hotdata workspace and return JSON rows."""
43-
return execute_sql_json(client, sql, max_rows=max_rows)
50+
return execute_sql_json(client, sql, max_rows=max_rows, database=database)
4451

4552
def hotdata_list_managed_databases() -> str:
4653
"""List Hotdata-managed databases in the workspace."""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010
requires-python = ">=3.10"
1111
license = { text = "MIT" }
1212
dependencies = [
13-
"hotdata-runtime>=0.1.1",
13+
"hotdata-runtime>=0.2.0",
1414
"hotdata>=0.2.0",
1515
"llama-index-core>=0.12.0",
1616
]

tests/test_tools.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,35 @@ def test_execute_sql_json(mock_client, sample_result):
2121
payload = json.loads(execute_sql_json(mock_client, "select 1"))
2222
assert payload["metadata"]["row_count"] == 2
2323
assert payload["rows"] == [{"n": 1}, {"n": 2}]
24-
mock_client.execute_sql.assert_called_once_with("select 1")
24+
mock_client.execute_sql.assert_called_once_with("select 1", database=None)
25+
26+
27+
def test_execute_sql_json_with_database(mock_client, sample_result):
28+
execute_sql_json(mock_client, "select 1", database="my_db")
29+
mock_client.execute_sql.assert_called_once_with("select 1", database="my_db")
2530

2631

2732
def test_list_managed_databases_json(mock_client):
2833
mock_client.list_managed_databases.return_value = [
29-
ManagedDatabase(id="c1", name="sales", source_type="managed"),
34+
ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"),
3035
]
3136
payload = json.loads(list_managed_databases_json(mock_client))
32-
assert payload[0]["name"] == "sales"
37+
assert payload[0]["description"] == "sales"
3338

3439

3540
def test_create_managed_database_delegates(mock_client):
3641
mock_client.create_managed_database.return_value = ManagedDatabase(
3742
id="c1",
38-
name="sales",
39-
source_type="managed",
43+
description="sales",
44+
default_connection_id="conn_c1",
4045
)
4146
db = create_managed_database(mock_client, name="sales", tables=["orders"])
4247
mock_client.create_managed_database.assert_called_once_with(
43-
"sales",
48+
description="sales",
4449
schema="public",
4550
tables=["orders"],
4651
)
47-
assert db.name == "sales"
52+
assert db.description == "sales"
4853

4954

5055
def test_load_managed_table_delegates(mock_client):
@@ -73,8 +78,8 @@ def test_load_managed_table_delegates(mock_client):
7378
def test_make_hotdata_tools(mock_client, sample_result):
7479
mock_client.create_managed_database.return_value = ManagedDatabase(
7580
id="c1",
76-
name="sales",
77-
source_type="managed",
81+
description="sales",
82+
default_connection_id="conn_c1",
7883
)
7984
mock_client.load_managed_table.return_value = LoadManagedTableResult(
8085
connection_id="c1",

uv.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)