Skip to content

Commit 629b018

Browse files
authored
feat: add database parameter to execute_sql for managed database scoping (#11)
Adds an optional keyword-only `database` parameter to `execute_sql`. When provided, the database name is resolved to an ID once before the retry loop and passed as the `X-Database-Id` header on every query attempt. Inside a managed database the built-in catalog is always "default", so callers should reference tables as "default"."<schema>"."<table>". Without the parameter, behaviour is unchanged (no header sent). All downstream integration libraries (hotdata-langchain, hotdata-llamaindex, hotdata-langgraph, hotdata-jupyter, hotdata-marimo, hotdata-streamlit) were calling execute_sql without database scoping, causing 400 errors when querying managed database tables. They can now pass database="<name>" to fix the issue without any further changes. Co-authored-by: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com>
1 parent e868b3c commit 629b018

2 files changed

Lines changed: 77 additions & 4 deletions

File tree

hotdata_runtime/client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,22 +475,34 @@ def _wait_result_ready(
475475
f"(last status: {getattr(last, 'status', None)})"
476476
)
477477

478-
def execute_sql(self, sql: str) -> QueryResult:
478+
def execute_sql(self, sql: str, *, database: str | None = None) -> QueryResult:
479+
"""Execute SQL and return a :class:`QueryResult`.
480+
481+
Pass ``database`` to scope the query to a managed database. The name
482+
is resolved to a database ID once before the retry loop, and the
483+
``X-Database-Id`` header is sent with every attempt. Inside a managed
484+
database the built-in catalog is always ``"default"``, so table
485+
references should use ``"default"."<schema>"."<table>"``.
486+
"""
487+
database_id = self.resolve_managed_database(database).id if database else None
479488
last_err: BaseException | None = None
480489
for attempt in range(3):
481490
try:
482-
return self._execute_sql_once(sql)
491+
return self._execute_sql_once(sql, database_id=database_id)
483492
except (ProtocolError, ConnectionResetError, Urllib3HTTPError) as e:
484493
last_err = e
485494
if attempt == 2:
486495
raise
487496
time.sleep(0.2 * (2**attempt))
488497
raise last_err # pragma: no cover
489498

490-
def _execute_sql_once(self, sql: str) -> QueryResult:
499+
def _execute_sql_once(self, sql: str, *, database_id: str | None = None) -> QueryResult:
491500
q = self._query_api()
492501
try:
493-
raw = q.query(QueryRequest(sql=sql))
502+
if database_id:
503+
raw = q.query(QueryRequest(sql=sql), x_database_id=database_id)
504+
else:
505+
raw = q.query(QueryRequest(sql=sql))
494506
except ApiException as e:
495507
raise RuntimeError(e.reason or str(e)) from e
496508

tests/test_client.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,67 @@ def list_results(self, *, limit: int, offset: int):
200200
assert out[0].to_dict()["created_at"] == "2026-01-01T00:00:00Z"
201201

202202

203+
def test_execute_sql_sends_no_database_id_by_default():
204+
from hotdata.models.query_response import QueryResponse as _QR
205+
206+
client = HotdataClient("k", "ws", host="https://api.hotdata.dev")
207+
208+
class FakeQueryApi:
209+
def __init__(self):
210+
self.calls: list[dict] = []
211+
212+
def query(self, request, **kwargs):
213+
self.calls.append(kwargs)
214+
return _QR(
215+
columns=["n"],
216+
rows=[[1]],
217+
row_count=1,
218+
nullable=[False],
219+
result_id="res_1",
220+
query_run_id="qrun_1",
221+
execution_time_ms=1,
222+
)
223+
224+
fake_q = FakeQueryApi()
225+
with patch.object(client, "_query_api", return_value=fake_q):
226+
client.execute_sql("SELECT 1")
227+
228+
assert fake_q.calls == [{}]
229+
230+
231+
def test_execute_sql_resolves_database_and_sends_x_database_id():
232+
from hotdata.models.query_response import QueryResponse as _QR
233+
from types import SimpleNamespace
234+
235+
client = HotdataClient("k", "ws", host="https://api.hotdata.dev")
236+
237+
class FakeQueryApi:
238+
def __init__(self):
239+
self.calls: list[dict] = []
240+
241+
def query(self, request, **kwargs):
242+
self.calls.append(kwargs)
243+
return _QR(
244+
columns=["n"],
245+
rows=[[1]],
246+
row_count=1,
247+
nullable=[False],
248+
result_id="res_1",
249+
query_run_id="qrun_1",
250+
execution_time_ms=1,
251+
)
252+
253+
fake_q = FakeQueryApi()
254+
fake_db = SimpleNamespace(id="db_abc")
255+
256+
with patch.object(client, "_query_api", return_value=fake_q), \
257+
patch.object(client, "resolve_managed_database", return_value=fake_db) as resolve:
258+
client.execute_sql('SELECT * FROM "default"."public"."orders"', database="my_db")
259+
260+
resolve.assert_called_once_with("my_db")
261+
assert fake_q.calls == [{"x_database_id": "db_abc"}]
262+
263+
203264
def test_list_run_history_returns_normalized_items():
204265
client = HotdataClient("k", "ws", host="https://api.hotdata.dev")
205266
listing = SimpleNamespace(

0 commit comments

Comments
 (0)