Skip to content

Commit 20c817b

Browse files
gopinathnelluri authored and wendigo committed
Fix OOM on large spooled result sets
Previously, TrinoQuery.fetch() eagerly materialized spooled results, so all segments were loaded into memory at once when using fault-tolerant execution. This led to OOM errors on large datasets.

Changes:
- Enable lazy loading by returning the SegmentIterator directly from fetch().
- Update execute() to handle result rows as iterators instead of requiring lists.
- Add a unit test to verify the lazy fetching implementation.
1 parent ede0893 commit 20c817b

4 files changed

Lines changed: 187 additions & 6 deletions

File tree

tests/development_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
5151
network = Network().create()
5252
supports_spooling_protocol = TRINO_VERSION == "latest" or int(TRINO_VERSION) >= 466
5353
if supports_spooling_protocol:
54-
localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \
54+
localstack = LocalStackContainer(image="localstack/localstack:4.14.0", region_name="us-east-1") \
5555
.with_name("localstack") \
5656
.with_network(network) \
5757
.with_bind_ports(4566, 4566) \

tests/integration/test_dbapi_integration.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,6 +1900,79 @@ def test_segments_cursor(trino_connection):
19001900
assert total == 300875, f"Expected total rows 300875, got {total}"
19011901

19021902

1903+
@pytest.mark.skipif(
1904+
trino_version() <= 466,
1905+
reason="spooling protocol was introduced in version 466"
1906+
)
1907+
def test_spooled_segments_lazy_fetchone(trino_connection):
1908+
"""Verify that spooled results can be consumed row-by-row via fetchone()
1909+
without materializing the entire result set in memory."""
1910+
if trino_connection._client_session.encoding is None:
1911+
pytest.skip("spooling requires an encoding")
1912+
1913+
cur = trino_connection.cursor()
1914+
cur.execute("""SELECT l.*
1915+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1916+
start => 1,
1917+
stop => 5,
1918+
step => 1)) n""")
1919+
1920+
# The underlying result rows should be an iterator, not a list
1921+
result_rows = cur._query._result._rows
1922+
assert not isinstance(result_rows, list), (
1923+
f"Expected lazy iterator for spooled results, got {type(result_rows)}"
1924+
)
1925+
1926+
# Consume rows one by one and count them
1927+
count = 0
1928+
while cur.fetchone() is not None:
1929+
count += 1
1930+
assert count == 300875, f"Expected 300875 rows, got {count}"
1931+
1932+
1933+
@pytest.mark.skipif(
1934+
trino_version() <= 466,
1935+
reason="spooling protocol was introduced in version 466"
1936+
)
1937+
def test_spooled_segments_fetchmany(trino_connection):
1938+
"""Verify that fetchmany() works correctly with lazily loaded spooled segments."""
1939+
if trino_connection._client_session.encoding is None:
1940+
pytest.skip("spooling requires an encoding")
1941+
1942+
cur = trino_connection.cursor()
1943+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1944+
1945+
batch = cur.fetchmany(100)
1946+
assert len(batch) == 100
1947+
1948+
total = len(batch)
1949+
while True:
1950+
batch = cur.fetchmany(1000)
1951+
if not batch:
1952+
break
1953+
total += len(batch)
1954+
assert total == 60175, f"Expected 60175 rows, got {total}"
1955+
1956+
1957+
@pytest.mark.skipif(
1958+
trino_version() <= 466,
1959+
reason="spooling protocol was introduced in version 466"
1960+
)
1961+
def test_spooled_segments_iterator_protocol(trino_connection):
1962+
"""Verify that cursor iteration works correctly with spooled segments."""
1963+
if trino_connection._client_session.encoding is None:
1964+
pytest.skip("spooling requires an encoding")
1965+
1966+
cur = trino_connection.cursor()
1967+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1968+
1969+
count = 0
1970+
for row in cur:
1971+
count += 1
1972+
assert isinstance(row, list)
1973+
assert count == 60175, f"Expected 60175 rows, got {count}"
1974+
1975+
19031976
def get_cursor(legacy_prepared_statements, run_trino):
19041977
host, port = run_trino
19051978

tests/unit/test_spooling.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
2+
import unittest
3+
from unittest.mock import MagicMock, patch
4+
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoResult
5+
from trino.client import SegmentIterator
6+
7+
class TestTrinoQueryLazy(unittest.TestCase):
8+
def setUp(self):
9+
self.mock_request = MagicMock(spec=TrinoRequest)
10+
self.client_session = ClientSession("user")
11+
self.mock_request.client_session = self.client_session
12+
13+
def test_fetch_returns_iterator_for_spooled_segments(self):
14+
# Mock the initial POST response
15+
post_response = MagicMock()
16+
post_response.id = "query_1"
17+
post_response.stats = {}
18+
post_response.info_uri = "info"
19+
post_response.next_uri = "next_1"
20+
post_response.rows = [] # No rows initially
21+
22+
self.mock_request.process.return_value = post_response
23+
self.mock_request.post.return_value = MagicMock()
24+
25+
query = TrinoQuery(self.mock_request, "SELECT 1")
26+
27+
# Execute should return empty result initially but try to fetch
28+
# We need to mock fetch behavior too since execute calls it if rows are empty
29+
30+
# Mock the GET response for fetch()
31+
get_response_status = MagicMock()
32+
get_response_status.next_uri = None # Finished
33+
get_response_status.stats = {}
34+
# Status rows as dict indicates spooling protocol
35+
get_response_status.rows = {
36+
"encoding": "json",
37+
"segments": [
38+
{"type": "spooled", "uri": "u1", "ackUri": "a1", "metadata": {"segmentSize": "10", "uncompressedSize": "10"}}
39+
],
40+
"metadata": {}
41+
}
42+
43+
# When execute calls fetch(), it calls request.get -> process -> returns get_response_status
44+
self.mock_request.process.side_effect = [post_response, get_response_status]
45+
self.mock_request.get.return_value = MagicMock()
46+
47+
# Mock _to_segments to return a list of decodable segments
48+
# We can just verify that fetch returns a SegmentIterator
49+
# But _to_segments is internal.
50+
51+
# We need to patch SegmentIterator or check the return type
52+
53+
result = query.execute()
54+
55+
# Verify result.rows is a SegmentIterator, NOT a list
56+
self.assertIsInstance(result.rows, SegmentIterator)
57+
self.assertNotIsInstance(result.rows, list)
58+
59+
def test_fetch_returns_list_for_normal_segments(self):
60+
# Mock the initial POST response
61+
post_response = MagicMock()
62+
post_response.id = "query_1"
63+
post_response.stats = {}
64+
post_response.info_uri = "info"
65+
post_response.next_uri = "next_1"
66+
post_response.rows = []
67+
68+
# Mock the GET response for fetch()
69+
get_response_status = MagicMock()
70+
get_response_status.next_uri = None
71+
get_response_status.stats = {}
72+
get_response_status.rows = [[1], [2]] # Normal list rows
73+
74+
self.mock_request.process.side_effect = [post_response, get_response_status]
75+
76+
query = TrinoQuery(self.mock_request, "SELECT 1")
77+
result = query.execute()
78+
79+
# Verify result.rows is a list (appended)
80+
self.assertIsInstance(result.rows, list)
81+
self.assertEqual(result.rows, [[1], [2]])
82+
83+
if __name__ == '__main__':
84+
unittest.main()

trino/client.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -904,9 +904,32 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
904904
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
905905
self._result = TrinoResult(self, rows)
906906

907-
# Execute should block until at least one row is received or query is finished or cancelled
908-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
909-
self._result.rows += self.fetch()
907+
"""
908+
Execute should block until at least one row is received or query is finished or cancelled
909+
910+
For Standard Execution, rows is a list, we can check len. the first response usually contains no rows (just stats),
911+
so we need to continue fetching until we get some rows or query is finished or cancelled.
912+
913+
For Spooled Execution, rows start as empty list and eventually fetch returns the rows as iterator,
914+
we can't check len of an iterator easily without peeking.
915+
916+
So, if we get rows as non empty list or iterator, we stop blocking and return it to the caller to consume it.
917+
"""
918+
919+
while not self.finished and not self.cancelled:
920+
if isinstance(self._result.rows, list) and len(self._result.rows) == 0:
921+
new_rows = self.fetch()
922+
if isinstance(new_rows, list):
923+
self._result.rows += new_rows
924+
else:
925+
# It's an iterator (spooled segments), replace rows with it
926+
self._result.rows = new_rows
927+
# We have an iterator now, so we can return result to user
928+
break
929+
else:
930+
# We have data (list with items or an iterator), so return
931+
break
932+
910933
return self._result
911934

912935
def _update_state(self, status):
@@ -920,7 +943,7 @@ def _update_state(self, status):
920943
if status.columns:
921944
self._columns = status.columns
922945

923-
def fetch(self) -> List[Union[List[Any]], Any]:
946+
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
924947
"""Continue fetching data for the current query_id"""
925948
try:
926949
response = self._request.get(self._request.next_uri)
@@ -941,7 +964,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
941964
spooled = self._to_segments(rows)
942965
if self._fetch_mode == "segments":
943966
return spooled
944-
return list(SegmentIterator(spooled, self._row_mapper))
967+
# Return iterator directly, do NOT materialize with list()
968+
return SegmentIterator(spooled, self._row_mapper)
945969
elif isinstance(status.rows, list):
946970
return self._row_mapper.map(rows)
947971
else:

0 commit comments

Comments
 (0)