Skip to content

Commit f178dc3

Browse files
Fix OOM on large spooled result sets
Previously, TrinoQuery.fetch() eagerly materialized its result, which caused all segments to be loaded into memory at once when using fault-tolerant execution. This led to OOM errors on large datasets. Changes: (1) Enable lazy loading by returning the SegmentIterator directly from fetch(). (2) Update execute() to handle result rows as an iterator instead of requiring a list. (3) Add unit tests to verify the lazy fetching implementation.
1 parent ede0893 commit f178dc3

2 files changed

Lines changed: 120 additions & 4 deletions

File tree

tests/unit/test_spooling.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
2+
import unittest
3+
from unittest.mock import MagicMock, patch
4+
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoResult
5+
from trino.client import SegmentIterator
6+
7+
class TestTrinoQueryLazy(unittest.TestCase):
    """Check the type of ``result.rows`` returned by ``TrinoQuery.execute()``.

    A dict-shaped ``rows`` payload (spooling protocol) must surface as a
    ``SegmentIterator``; a plain list payload must surface as a list.
    """

    def setUp(self):
        """Create a ``TrinoRequest`` mock bound to a minimal client session."""
        self.client_session = ClientSession("user")
        self.mock_request = MagicMock(spec=TrinoRequest)
        self.mock_request.client_session = self.client_session

    @staticmethod
    def _initial_status():
        """Return the status for the initial POST: no rows yet, more to fetch."""
        return MagicMock(
            id="query_1",
            stats={},
            info_uri="info",
            next_uri="next_1",
            rows=[],
        )

    @staticmethod
    def _final_status(rows):
        """Return the status for the follow-up GET; ``next_uri=None`` ends the query."""
        return MagicMock(next_uri=None, stats={}, rows=rows)

    def test_fetch_returns_iterator_for_spooled_segments(self):
        """Spooled segments must reach the caller as a ``SegmentIterator``."""
        first_status = self._initial_status()
        self.mock_request.process.return_value = first_status
        self.mock_request.post.return_value = MagicMock()

        query = TrinoQuery(self.mock_request, "SELECT 1")

        # A dict (rather than a list) in ``rows`` indicates the spooling protocol.
        spooled_payload = {
            "encoding": "json",
            "segments": [
                {
                    "type": "spooled",
                    "uri": "u1",
                    "ackUri": "a1",
                    "metadata": {"segmentSize": "10", "uncompressedSize": "10"},
                },
            ],
            "metadata": {},
        }
        second_status = self._final_status(spooled_payload)

        # execute() processes the POST response first; because that response
        # carries no rows, it then calls fetch(), which processes the GET.
        self.mock_request.process.side_effect = [first_status, second_status]
        self.mock_request.get.return_value = MagicMock()

        result = query.execute()

        # The rows must stay lazy: an iterator, never a materialized list.
        self.assertIsInstance(result.rows, SegmentIterator)
        self.assertNotIsInstance(result.rows, list)

    def test_fetch_returns_list_for_normal_segments(self):
        """Inline (non-spooled) rows must still be accumulated into a list."""
        first_status = self._initial_status()
        second_status = self._final_status([[1], [2]])
        self.mock_request.process.side_effect = [first_status, second_status]

        query = TrinoQuery(self.mock_request, "SELECT 1")
        result = query.execute()

        # Standard execution appends fetched rows to the initial empty list.
        self.assertIsInstance(result.rows, list)
        self.assertEqual(result.rows, [[1], [2]])
83+
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

trino/client.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -905,8 +905,39 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
905905
self._result = TrinoResult(self, rows)
906906

907907
# Execute should block until at least one row is received or query is finished or cancelled
908-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
909-
self._result.rows += self.fetch()
908+
# If rows is a list (standard execution), we can check len.
909+
# If rows is an iterator (spooled), we can't check len easily without peeking.
910+
# However, for standard execution, the first response usually contains no rows (just stats),
911+
# so we need to fetch.
912+
913+
# If it's a list and empty, or if we haven't finished and haven't cancelled, try to fetch more.
914+
# The issue with spooled execution is that fetch() returns an iterator.
915+
916+
# For standard execution: rows starts as list. fetch() returns list.
917+
# For spooled execution: rows starts as list (empty). fetch() returns Iterator.
918+
919+
# We need to detect if we have data.
920+
# For spooled, we might get an iterator that *yields* nothing if segments are empty?
921+
# But we want to block until we have *something* to return or finished.
922+
923+
# Modified logic:
924+
# If _result.rows is empty list, we fetch.
925+
# If _result.rows becomes an iterator, we stop blocking and let the user consume it.
926+
927+
while not self.finished and not self.cancelled:
928+
if isinstance(self._result.rows, list) and len(self._result.rows) == 0:
929+
new_rows = self.fetch()
930+
if isinstance(new_rows, list):
931+
self._result.rows += new_rows
932+
else:
933+
# It's an iterator (spooled segments), replace rows with it
934+
self._result.rows = new_rows
935+
# We have an iterator now, so we can return result to user
936+
break
937+
else:
938+
# We have data (list with items or an iterator), so return
939+
break
940+
910941
return self._result
911942

912943
def _update_state(self, status):
@@ -920,7 +951,7 @@ def _update_state(self, status):
920951
if status.columns:
921952
self._columns = status.columns
922953

923-
def fetch(self) -> List[Union[List[Any]], Any]:
954+
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
924955
"""Continue fetching data for the current query_id"""
925956
try:
926957
response = self._request.get(self._request.next_uri)
@@ -941,7 +972,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
941972
spooled = self._to_segments(rows)
942973
if self._fetch_mode == "segments":
943974
return spooled
944-
return list(SegmentIterator(spooled, self._row_mapper))
975+
# Return iterator directly, do NOT materialize with list()
976+
return SegmentIterator(spooled, self._row_mapper)
945977
elif isinstance(status.rows, list):
946978
return self._row_mapper.map(rows)
947979
else:

0 commit comments

Comments
 (0)