Skip to content

Commit 8c23aae

Browse files
Implement heartbeat when downloading spooled segments
1 parent 2adcec1 commit 8c23aae

1 file changed

Lines changed: 35 additions & 19 deletions

File tree

trino/client.py

Lines changed: 35 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import abc
3838
import atexit
3939
import base64
40+
import contextlib
4041
import copy
4142
import functools
4243
import itertools
@@ -853,10 +854,10 @@ def __iter__(self):
853854
self._rows = next_rows
854855

855856

856-
class QueryHeartbeat:
857-
"""Periodically sends HEAD requests to the current nextUri to prevent the
858-
coordinator from abandoning a query while the client is downloading spooled
859-
result segments from external storage."""
857+
class _QueryHeartbeat:
858+
"""Periodically sends HEAD requests to the current nextUri to prevent the coordinator
859+
from abandoning a query if the client is silent for a longer period of time, for example
860+
while the client is downloading spooled result segments from external storage."""
860861

861862
_MAX_FAILURES = 3
862863

@@ -881,6 +882,7 @@ def _run(self) -> None:
881882
try:
882883
response = self._request.head(uri)
883884
if response.status_code in (404, 405):
885+
# 404/405 means the server does not support heartbeat calls
884886
return
885887
if not response.ok:
886888
self._failures += 1
@@ -894,6 +896,16 @@ def _run(self) -> None:
894896
return
895897

896898

899+
@contextlib.contextmanager
900+
def query_heartbeat(request: TrinoRequest, interval: float) -> Iterator[None]:
901+
heartbeat = _QueryHeartbeat(request, interval)
902+
heartbeat.start()
903+
try:
904+
yield
905+
finally:
906+
heartbeat.stop()
907+
908+
897909
class TrinoQuery:
898910
"""Represent the execution of a SQL statement by Trino."""
899911

@@ -920,7 +932,6 @@ def __init__(
920932
self._legacy_primitive_types = legacy_primitive_types
921933
self._row_mapper: Optional[RowMapper] = None
922934
self._fetch_mode = fetch_mode
923-
self._heartbeat: Optional[QueryHeartbeat] = None
924935

925936
@property
926937
def query_id(self) -> Optional[str]:
@@ -998,13 +1009,8 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
9981009
self._stats.update({"queryId": self.query_id})
9991010
self._update_state(status)
10001011
self._warnings = getattr(status, "warnings", [])
1001-
interval = self._request._client_session.heartbeat_interval
1002-
if interval is not None:
1003-
self._heartbeat = QueryHeartbeat(self._request, interval)
1004-
self._heartbeat.start()
10051012
if status.next_uri is None:
10061013
self._finished = True
1007-
self._stop_heartbeat()
10081014

10091015
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
10101016
self._result = TrinoResult(self, rows)
@@ -1052,7 +1058,6 @@ def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
10521058
self._update_state(status)
10531059
if status.next_uri is None:
10541060
self._finished = True
1055-
self._stop_heartbeat()
10561061

10571062
if not self._row_mapper:
10581063
return []
@@ -1065,7 +1070,9 @@ def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
10651070
if self._fetch_mode == "segments":
10661071
return spooled
10671072
# Return iterator directly, do NOT materialize with list()
1068-
return SegmentIterator(spooled, self._row_mapper)
1073+
return SegmentIterator(
1074+
spooled, self._row_mapper, self._request, self._request._client_session.heartbeat_interval
1075+
)
10691076
elif isinstance(status.rows, list):
10701077
return self._row_mapper.map(rows)
10711078
else:
@@ -1100,16 +1107,11 @@ def cancel(self) -> None:
11001107
raise trino.exceptions.TrinoConnectionError("failed to cancel query: {}".format(e))
11011108
if response.status_code == requests.codes.no_content:
11021109
self._cancelled = True
1103-
self._stop_heartbeat()
11041110
logger.debug("query cancelled: %s", self.query_id)
11051111
return
11061112

11071113
self._request.raise_response_error(response)
11081114

1109-
def _stop_heartbeat(self) -> None:
1110-
if self._heartbeat is not None:
1111-
self._heartbeat.stop()
1112-
11131115
def is_finished(self) -> bool:
11141116
import warnings
11151117
warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)
@@ -1342,13 +1344,21 @@ def __repr__(self):
13421344

13431345

13441346
class SegmentIterator:
1345-
def __init__(self, segments: Union[DecodableSegment, List[DecodableSegment]], mapper: RowMapper) -> None:
1347+
def __init__(
1348+
self,
1349+
segments: Union[DecodableSegment, List[DecodableSegment]],
1350+
mapper: RowMapper,
1351+
request: TrinoRequest,
1352+
heartbeat_interval: Optional[float] = None,
1353+
) -> None:
13461354
self._segments = iter(segments if isinstance(segments, List) else [segments])
13471355
self._mapper = mapper
13481356
self._decoder = None
13491357
self._rows: Iterator[List[List[Any]]] = iter([])
13501358
self._finished = False
13511359
self._current_segment: Optional[DecodableSegment] = None
1360+
self._request = request
1361+
self._heartbeat_interval = heartbeat_interval
13521362

13531363
def __iter__(self) -> Iterator[List[Any]]:
13541364
return self
@@ -1374,7 +1384,13 @@ def _load_next_segment(self):
13741384
if self._decoder is None:
13751385
self._decoder = SegmentDecoder(CompressedQueryDataDecoderFactory(self._mapper)
13761386
.create(self._current_segment.encoding))
1377-
self._rows = iter(self._decoder.decode(self._current_segment.segment))
1387+
if isinstance(self._current_segment.segment, SpooledSegment) and self._heartbeat_interval:
1388+
# Downloading a spooled segment may take some time. In the meantime, we send heartbeat
1389+
# requests so the server doesn't think we lost interest and close the connection.
1390+
with query_heartbeat(self._request, self._heartbeat_interval):
1391+
self._rows = iter(self._decoder.decode(self._current_segment.segment))
1392+
else:
1393+
self._rows = iter(self._decoder.decode(self._current_segment.segment))
13781394
except StopIteration:
13791395
self._finished = True
13801396

0 commit comments

Comments
 (0)