Skip to content

Commit 83ca2f2

Browse files
authored
fix(managed_client): wait for result ready on sync query path (#28)
* fix(managed_client): wait for result ready on sync query path

  fetch_table fetched the persisted result as Arrow using the result_id from a synchronous QueryResponse without waiting for it to reach 'ready'. Against the live async backend the result is often still 'processing', so Arrow fetches failed on every read-modify-write (merge/append) and state read. The async path already waited; the sync path now does too.

  Adds a regression test driving the sync path with a 'processing' -> 'ready' result, asserting Arrow is fetched only after readiness.

* refactor(managed_client): dedupe run/result polling, group constants

  Extract a single typed _poll() helper shared by the query-run and result-ready waits (removing two near-identical loops and the magic 0.5/0.3 sleep intervals), flatten _query_database_scoped via _await_query_run, and hoist the class constants to the top. No behavior change; verified by the unit suite and a full prod e2e re-run.

---------

Co-authored-by: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com>
1 parent 5ec8803 commit 83ca2f2

2 files changed

Lines changed: 150 additions & 32 deletions

File tree

hotdata_framework/managed_client.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import time
1111
from collections.abc import Callable
12-
from typing import Any, TypeVar
12+
from typing import Any, Protocol, TypeVar
1313

1414
import pyarrow as pa
1515
from hotdata.api.query_api import QueryApi
@@ -30,6 +30,16 @@
3030
T = TypeVar("T")
3131

3232

33+
class _StatusResponse(Protocol):
34+
"""Async resources (query runs, results) expose a status and error message."""
35+
36+
status: str
37+
error_message: str | None
38+
39+
40+
S = TypeVar("S", bound=_StatusResponse)
41+
42+
3343
class ManagedDatabaseClient:
3444
"""Managed-database client with bounded retries over hotdata-framework.
3545
@@ -39,6 +49,10 @@ class ManagedDatabaseClient:
3949
database lifecycle.
4050
"""
4151

52+
_QUERY_TIMEOUT_SECONDS = 300.0
53+
_POLL_INTERVAL_SECONDS = 0.4
54+
_MAX_BACKOFF_SECONDS = 30.0
55+
4256
def __init__(
4357
self,
4458
*,
@@ -100,49 +114,62 @@ def operation() -> pa.Table | None:
100114

101115
return self._request_with_retry(operation)
102116

103-
_QUERY_TIMEOUT_SECONDS = 300.0
117+
def _poll(
118+
self,
119+
fetch: Callable[[], S],
120+
*,
121+
is_ready: Callable[[S], bool],
122+
describe: str,
123+
) -> S:
124+
"""Poll ``fetch`` until ``is_ready`` is satisfied, or raise on failure/timeout.
125+
126+
``failed``/``cancelled`` statuses raise ``RuntimeError``; exceeding
127+
:attr:`_QUERY_TIMEOUT_SECONDS` raises ``TimeoutError``.
128+
"""
129+
deadline = time.monotonic() + self._QUERY_TIMEOUT_SECONDS
130+
while time.monotonic() < deadline:
131+
obj = fetch()
132+
if obj.status in ("failed", "cancelled"):
133+
raise RuntimeError(obj.error_message or f"{describe} {obj.status}")
134+
if is_ready(obj):
135+
return obj
136+
time.sleep(self._POLL_INTERVAL_SECONDS)
137+
raise TimeoutError(f"{describe} timed out after {self._QUERY_TIMEOUT_SECONDS}s")
104138

105139
def _query_database_scoped(self, sql: str, *, database_id: str) -> str | None:
    """Run ``sql`` scoped to ``database_id`` and return the id of a ready result.

    Both synchronous and asynchronous responses are funnelled through
    ``_wait_result_ready`` so the caller never receives an id whose
    persisted result is still being produced. Returns ``None`` when the
    response is of neither known type or carries no result id.
    """
    query_api = QueryApi(self._runtime.api)
    response = query_api.query(QueryRequest(sql=sql), x_database_id=database_id)

    if isinstance(response, QueryResponse):
        # Even a synchronous response stores its full result out-of-band
        # under ``result_id``; it can still be ``processing`` when the inline
        # preview comes back, so block until ``ready`` before handing the id
        # to a caller that will fetch it as Arrow.
        pending_result_id = response.result_id
    elif isinstance(response, AsyncQueryResponse):
        pending_result_id = self._await_query_run(response.query_run_id)
    else:
        return None

    return self._wait_result_ready(pending_result_id)
132153

154+
def _await_query_run(self, query_run_id: str) -> str | None:
    """Wait for a query run to reach ``succeeded`` and return its ``result_id``."""
    runs_api = QueryRunsApi(self._runtime.api)

    def fetch_run():
        return runs_api.get_query_run(query_run_id)

    def has_succeeded(run) -> bool:
        return run.status == "succeeded"

    finished_run = self._poll(fetch_run, is_ready=has_succeeded, describe="Query")
    return finished_run.result_id
162+
133163
def _wait_result_ready(self, result_id: str | None) -> str | None:
134164
if result_id is None:
135165
return None
136166
results = ResultsApi(self._runtime.api)
137-
deadline = time.monotonic() + self._QUERY_TIMEOUT_SECONDS
138-
while time.monotonic() < deadline:
139-
r = results.get_result(result_id)
140-
if r.status == "ready":
141-
return result_id
142-
if r.status in ("failed", "cancelled"):
143-
raise RuntimeError(r.error_message or f"Result {r.status}")
144-
time.sleep(0.3)
145-
raise TimeoutError(f"Result {result_id} not ready after {self._QUERY_TIMEOUT_SECONDS}s")
167+
self._poll(
168+
lambda: results.get_result(result_id),
169+
is_ready=lambda r: r.status == "ready",
170+
describe=f"Result {result_id}",
171+
)
172+
return result_id
146173

147174
def fetch_table_rows(self, *, database: str, schema: str, table: str) -> list[dict[str, Any]]:
148175
result = self.fetch_table(database=database, schema=schema, table=table)
@@ -168,8 +195,6 @@ def load_managed_table(
168195
)
169196
)
170197

171-
_MAX_BACKOFF_SECONDS = 30.0
172-
173198
def _request_with_retry(self, operation: Callable[[], T]) -> T:
174199
for attempt in range(1, self._max_retries + 1):
175200
try:

tests/test_managed_client.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Regression tests for ManagedDatabaseClient result handling."""
2+
3+
from __future__ import annotations
4+
5+
from types import SimpleNamespace
6+
from typing import Any
7+
8+
import pyarrow as pa
9+
import pytest
10+
from hotdata.models.query_response import QueryResponse
11+
12+
import hotdata_framework.managed_client as mc
13+
14+
15+
def _query_response(result_id: str) -> QueryResponse:
    """Build a minimal, empty synchronous ``QueryResponse`` for ``result_id``."""
    fields = {
        "columns": [],
        "rows": [],
        "row_count": 0,
        "preview_row_count": 0,
        "truncated": False,
        "nullable": [],
        "result_id": result_id,
        "query_run_id": "qr",
        "execution_time_ms": 1,
    }
    return QueryResponse(**fields)
27+
28+
29+
def test_fetch_table_waits_for_ready_before_arrow(monkeypatch: pytest.MonkeyPatch) -> None:
    """``fetch_table`` must poll a sync result to ``ready`` before fetching Arrow.

    A synchronous ``QueryResponse`` persists its full result out-of-band, and
    that result can still be ``processing`` when the inline preview returns.
    The earlier bug returned the ``result_id`` immediately on the sync path,
    so Arrow was fetched against a ``processing`` result and failed.
    """
    calls: list[str] = []
    # Two not-yet-ready polls followed by the terminal ``ready`` state.
    statuses = iter(["processing", "processing", "ready"])

    class FakeQueryApi:
        def __init__(self, api: object) -> None:
            pass

        def query(self, request: object, *, x_database_id: str) -> QueryResponse:
            calls.append("query")
            return _query_response("rslt1")

    class FakeResultsApi:
        def __init__(self, api: object) -> None:
            pass

        def get_result(self, result_id: str) -> Any:
            current = next(statuses)
            calls.append(f"get_result:{current}")
            return SimpleNamespace(status=current, result_id=result_id, error_message=None)

    class FakeArrowResultsApi:
        def __init__(self, api: object) -> None:
            pass

        def get_result_arrow(self, result_id: str) -> pa.Table:
            calls.append("arrow")
            return pa.table({"id": [1, 2]})

    # Swap every API class the client instantiates for its recording fake.
    replacements = {
        "QueryApi": FakeQueryApi,
        "ResultsApi": FakeResultsApi,
        "ArrowResultsApi": FakeArrowResultsApi,
    }
    for attr_name, fake in replacements.items():
        monkeypatch.setattr(mc, attr_name, fake)
    # Keep the poll loop instantaneous.
    monkeypatch.setattr(mc.time, "sleep", lambda _seconds: None)

    client = mc.ManagedDatabaseClient(
        api_key="k",
        workspace_id="w",
        api_base_url="https://example.test",
        max_retries=1,
        retry_backoff_seconds=0.0,
    )
    client._runtime = SimpleNamespace(  # type: ignore[assignment]
        api=object(),
        resolve_managed_database=lambda name: SimpleNamespace(id="db1", default_connection_id="c"),
        list_managed_tables=lambda database, schema=None: [
            SimpleNamespace(table="orders", synced=True)
        ],
    )

    table = client.fetch_table(database="mydb", schema="public", table="orders")

    assert table is not None
    assert table.num_rows == 2
    # The result was polled to readiness, and Arrow was fetched only afterwards.
    assert "get_result:processing" in calls
    assert "get_result:ready" in calls
    ready_position = calls.index("get_result:ready")
    assert calls.index("arrow") > ready_position

0 commit comments

Comments
 (0)