3737import abc
3838import atexit
3939import base64
40+ import contextlib
4041import copy
4142import functools
4243import itertools
@@ -853,10 +854,10 @@ def __iter__(self):
853854 self ._rows = next_rows
854855
855856
856- class QueryHeartbeat :
857- """Periodically sends HEAD requests to the current nextUri to prevent the
858- coordinator from abandoning a query while the client is downloading spooled
859- result segments from external storage."""
857+ class _QueryHeartbeat :
858+ """Periodically sends HEAD requests to the current nextUri to prevent the coordinator
859+ from abandoning a query if the client is silent for a longer period of time, for example
860+ while the client is downloading spooled result segments from external storage."""
860861
861862 _MAX_FAILURES = 3
862863
@@ -881,6 +882,7 @@ def _run(self) -> None:
881882 try :
882883 response = self ._request .head (uri )
883884 if response .status_code in (404 , 405 ):
885+ # 404/405 means the server does not support heartbeat calls
884886 return
885887 if not response .ok :
886888 self ._failures += 1
@@ -894,6 +896,16 @@ def _run(self) -> None:
894896 return
895897
896898
@contextlib.contextmanager
def query_heartbeat(request: TrinoRequest, interval: float) -> Iterator[None]:
    """Keep a query heartbeat running for the duration of a ``with`` block.

    A ``_QueryHeartbeat`` is created from ``request`` and ``interval`` and
    started before control is handed to the body of the ``with`` statement.
    It is stopped when the body is left, whether it completes normally or
    raises.

    Args:
        request: The request object handed to ``_QueryHeartbeat``.
        interval: The heartbeat interval handed to ``_QueryHeartbeat``
            (presumably seconds -- confirm against ``_QueryHeartbeat``).

    Yields:
        None. The context manager exposes no value to the ``with`` body.
    """
    beat = _QueryHeartbeat(request, interval)
    beat.start()
    # Registering stop() as an exit callback guarantees it runs on every way
    # out of the block, including exceptions raised by the caller's body.
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(beat.stop)
        yield
907+
908+
897909class TrinoQuery :
898910 """Represent the execution of a SQL statement by Trino."""
899911
@@ -920,7 +932,6 @@ def __init__(
920932 self ._legacy_primitive_types = legacy_primitive_types
921933 self ._row_mapper : Optional [RowMapper ] = None
922934 self ._fetch_mode = fetch_mode
923- self ._heartbeat : Optional [QueryHeartbeat ] = None
924935
925936 @property
926937 def query_id (self ) -> Optional [str ]:
@@ -998,13 +1009,8 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
9981009 self ._stats .update ({"queryId" : self .query_id })
9991010 self ._update_state (status )
10001011 self ._warnings = getattr (status , "warnings" , [])
1001- interval = self ._request ._client_session .heartbeat_interval
1002- if interval is not None :
1003- self ._heartbeat = QueryHeartbeat (self ._request , interval )
1004- self ._heartbeat .start ()
10051012 if status .next_uri is None :
10061013 self ._finished = True
1007- self ._stop_heartbeat ()
10081014
10091015 rows = self ._row_mapper .map (status .rows ) if self ._row_mapper else status .rows
10101016 self ._result = TrinoResult (self , rows )
@@ -1052,7 +1058,6 @@ def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
10521058 self ._update_state (status )
10531059 if status .next_uri is None :
10541060 self ._finished = True
1055- self ._stop_heartbeat ()
10561061
10571062 if not self ._row_mapper :
10581063 return []
@@ -1065,7 +1070,9 @@ def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
10651070 if self ._fetch_mode == "segments" :
10661071 return spooled
10671072 # Return iterator directly, do NOT materialize with list()
1068- return SegmentIterator (spooled , self ._row_mapper )
1073+ return SegmentIterator (
1074+ spooled , self ._row_mapper , self ._request , self ._request ._client_session .heartbeat_interval
1075+ )
10691076 elif isinstance (status .rows , list ):
10701077 return self ._row_mapper .map (rows )
10711078 else :
@@ -1100,16 +1107,11 @@ def cancel(self) -> None:
11001107 raise trino .exceptions .TrinoConnectionError ("failed to cancel query: {}" .format (e ))
11011108 if response .status_code == requests .codes .no_content :
11021109 self ._cancelled = True
1103- self ._stop_heartbeat ()
11041110 logger .debug ("query cancelled: %s" , self .query_id )
11051111 return
11061112
11071113 self ._request .raise_response_error (response )
11081114
1109- def _stop_heartbeat (self ) -> None :
1110- if self ._heartbeat is not None :
1111- self ._heartbeat .stop ()
1112-
11131115 def is_finished (self ) -> bool :
11141116 import warnings
11151117 warnings .warn ("is_finished is deprecated, use finished instead" , DeprecationWarning )
@@ -1342,13 +1344,21 @@ def __repr__(self):
13421344
13431345
13441346class SegmentIterator :
1345- def __init__ (self , segments : Union [DecodableSegment , List [DecodableSegment ]], mapper : RowMapper ) -> None :
1347+ def __init__ (
1348+ self ,
1349+ segments : Union [DecodableSegment , List [DecodableSegment ]],
1350+ mapper : RowMapper ,
1351+ request : TrinoRequest ,
1352+ heartbeat_interval : Optional [float ] = None ,
1353+ ) -> None :
13461354 self ._segments = iter (segments if isinstance (segments , List ) else [segments ])
13471355 self ._mapper = mapper
13481356 self ._decoder = None
13491357 self ._rows : Iterator [List [List [Any ]]] = iter ([])
13501358 self ._finished = False
13511359 self ._current_segment : Optional [DecodableSegment ] = None
1360+ self ._request = request
1361+ self ._heartbeat_interval = heartbeat_interval
13521362
13531363 def __iter__ (self ) -> Iterator [List [Any ]]:
13541364 return self
@@ -1374,7 +1384,13 @@ def _load_next_segment(self):
13741384 if self ._decoder is None :
13751385 self ._decoder = SegmentDecoder (CompressedQueryDataDecoderFactory (self ._mapper )
13761386 .create (self ._current_segment .encoding ))
1377- self ._rows = iter (self ._decoder .decode (self ._current_segment .segment ))
1387+ if isinstance (self ._current_segment .segment , SpooledSegment ) and self ._heartbeat_interval :
1388+ # Downloading a spooled segment may take some time. In the meantime, we send heartbeat
1389+ # requests so the server doesn't think we lost interest and close the connection.
1390+ with query_heartbeat (self ._request , self ._heartbeat_interval ):
1391+ self ._rows = iter (self ._decoder .decode (self ._current_segment .segment ))
1392+ else :
1393+ self ._rows = iter (self ._decoder .decode (self ._current_segment .segment ))
13781394 except StopIteration :
13791395 self ._finished = True
13801396
0 commit comments