Skip to content

Commit 0aa81e7

Browse files
committed
Lazily load spooled result set segments
When using the spooling protocol, TrinoQuery.fetch() materializes all segments into a single list, causing out-of-memory errors on large result sets. Return SegmentIterator directly so rows are decoded on demand, keeping memory usage constant regardless of result set size.
1 parent ede0893 commit 0aa81e7

3 files changed

Lines changed: 98 additions & 6 deletions

File tree

tests/development_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
5151
network = Network().create()
5252
supports_spooling_protocol = TRINO_VERSION == "latest" or int(TRINO_VERSION) >= 466
5353
if supports_spooling_protocol:
54-
localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \
54+
localstack = LocalStackContainer(image="localstack/localstack:4.14.0", region_name="us-east-1") \
5555
.with_name("localstack") \
5656
.with_network(network) \
5757
.with_bind_ports(4566, 4566) \

tests/integration/test_dbapi_integration.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,6 +1900,79 @@ def test_segments_cursor(trino_connection):
19001900
assert total == 300875, f"Expected total rows 300875, got {total}"
19011901

19021902

1903+
@pytest.mark.skipif(
1904+
trino_version() <= 466,
1905+
reason="spooling protocol was introduced in version 466"
1906+
)
1907+
def test_spooled_segments_lazy_fetchone(trino_connection):
1908+
"""Verify that spooled results can be consumed row-by-row via fetchone()
1909+
without materializing the entire result set in memory."""
1910+
if trino_connection._client_session.encoding is None:
1911+
pytest.skip("spooling requires an encoding")
1912+
1913+
cur = trino_connection.cursor()
1914+
cur.execute("""SELECT l.*
1915+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1916+
start => 1,
1917+
stop => 5,
1918+
step => 1)) n""")
1919+
1920+
# The underlying result rows should be an iterator, not a list
1921+
result_rows = cur._query._result._rows
1922+
assert not isinstance(result_rows, list), (
1923+
f"Expected lazy iterator for spooled results, got {type(result_rows)}"
1924+
)
1925+
1926+
# Consume rows one by one and count them
1927+
count = 0
1928+
while cur.fetchone() is not None:
1929+
count += 1
1930+
assert count == 300875, f"Expected 300875 rows, got {count}"
1931+
1932+
1933+
@pytest.mark.skipif(
1934+
trino_version() <= 466,
1935+
reason="spooling protocol was introduced in version 466"
1936+
)
1937+
def test_spooled_segments_fetchmany(trino_connection):
1938+
"""Verify that fetchmany() works correctly with lazily loaded spooled segments."""
1939+
if trino_connection._client_session.encoding is None:
1940+
pytest.skip("spooling requires an encoding")
1941+
1942+
cur = trino_connection.cursor()
1943+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1944+
1945+
batch = cur.fetchmany(100)
1946+
assert len(batch) == 100
1947+
1948+
total = len(batch)
1949+
while True:
1950+
batch = cur.fetchmany(1000)
1951+
if not batch:
1952+
break
1953+
total += len(batch)
1954+
assert total == 60175, f"Expected 60175 rows, got {total}"
1955+
1956+
1957+
@pytest.mark.skipif(
1958+
trino_version() <= 466,
1959+
reason="spooling protocol was introduced in version 466"
1960+
)
1961+
def test_spooled_segments_iterator_protocol(trino_connection):
1962+
"""Verify that cursor iteration works correctly with spooled segments."""
1963+
if trino_connection._client_session.encoding is None:
1964+
pytest.skip("spooling requires an encoding")
1965+
1966+
cur = trino_connection.cursor()
1967+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1968+
1969+
count = 0
1970+
for row in cur:
1971+
count += 1
1972+
assert isinstance(row, list)
1973+
assert count == 60175, f"Expected 60175 rows, got {count}"
1974+
1975+
19031976
def get_cursor(legacy_prepared_statements, run_trino):
19041977
host, port = run_trino
19051978

trino/client.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import base64
4040
import copy
4141
import functools
42+
import itertools
4243
import os
4344
import random
4445
import re
@@ -904,9 +905,26 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
904905
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
905906
self._result = TrinoResult(self, rows)
906907

907-
# Execute should block until at least one row is received or query is finished or cancelled
908-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
909-
self._result.rows += self.fetch()
908+
# Block until rows are available, the query finishes, or it is canceled.
909+
# Rows start as an empty list. Early responses often contain only stats,
910+
# so we keep fetching until actual data arrives.
911+
#
912+
# Two protocols produce rows differently:
913+
# - Direct: fetch() returns a list - accumulate into the existing list.
914+
# - Spooling: fetch() returns a lazy iterator - replace rows and stop,
915+
# because we cannot cheaply check iterator length.
916+
while not self.finished and not self.cancelled and self._result.rows == []:
917+
new_rows = self.fetch()
918+
if isinstance(new_rows, list):
919+
self._result.rows += new_rows
920+
else:
921+
try:
922+
first_row = next(new_rows)
923+
self._result.rows = itertools.chain([first_row], new_rows)
924+
break
925+
except StopIteration:
926+
self._result.rows = []
927+
910928
return self._result
911929

912930
def _update_state(self, status):
@@ -920,7 +938,7 @@ def _update_state(self, status):
920938
if status.columns:
921939
self._columns = status.columns
922940

923-
def fetch(self) -> List[Union[List[Any]], Any]:
941+
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
924942
"""Continue fetching data for the current query_id"""
925943
try:
926944
response = self._request.get(self._request.next_uri)
@@ -941,7 +959,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
941959
spooled = self._to_segments(rows)
942960
if self._fetch_mode == "segments":
943961
return spooled
944-
return list(SegmentIterator(spooled, self._row_mapper))
962+
# Return iterator directly, do NOT materialize with list()
963+
return SegmentIterator(spooled, self._row_mapper)
945964
elif isinstance(status.rows, list):
946965
return self._row_mapper.map(rows)
947966
else:

0 commit comments

Comments
 (0)