Skip to content

Commit a913889

Browse files
committed
Lazily load spooled result set segments
When using the spooling protocol, TrinoQuery.fetch() materializes all segments into a single list, causing out-of-memory errors on large result sets. Return SegmentIterator directly so rows are decoded on demand, keeping memory usage constant regardless of result set size.
1 parent ede0893 commit a913889

3 files changed

Lines changed: 101 additions & 6 deletions

File tree

tests/development_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
5151
network = Network().create()
5252
supports_spooling_protocol = TRINO_VERSION == "latest" or int(TRINO_VERSION) >= 466
5353
if supports_spooling_protocol:
54-
localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \
54+
localstack = LocalStackContainer(image="localstack/localstack:4.14.0", region_name="us-east-1") \
5555
.with_name("localstack") \
5656
.with_network(network) \
5757
.with_bind_ports(4566, 4566) \

tests/integration/test_dbapi_integration.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,6 +1900,79 @@ def test_segments_cursor(trino_connection):
19001900
assert total == 300875, f"Expected total rows 300875, got {total}"
19011901

19021902

1903+
@pytest.mark.skipif(
1904+
trino_version() <= 466,
1905+
reason="spooling protocol was introduced in version 466"
1906+
)
1907+
def test_spooled_segments_lazy_fetchone(trino_connection):
1908+
"""Verify that spooled results can be consumed row-by-row via fetchone()
1909+
without materializing the entire result set in memory."""
1910+
if trino_connection._client_session.encoding is None:
1911+
pytest.skip("spooling requires an encoding")
1912+
1913+
cur = trino_connection.cursor()
1914+
cur.execute("""SELECT l.*
1915+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1916+
start => 1,
1917+
stop => 5,
1918+
step => 1)) n""")
1919+
1920+
# The underlying result rows should be an iterator, not a list
1921+
result_rows = cur._query._result._rows
1922+
assert not isinstance(result_rows, list), (
1923+
f"Expected lazy iterator for spooled results, got {type(result_rows)}"
1924+
)
1925+
1926+
# Consume rows one by one and count them
1927+
count = 0
1928+
while cur.fetchone() is not None:
1929+
count += 1
1930+
assert count == 300875, f"Expected 300875 rows, got {count}"
1931+
1932+
1933+
@pytest.mark.skipif(
1934+
trino_version() <= 466,
1935+
reason="spooling protocol was introduced in version 466"
1936+
)
1937+
def test_spooled_segments_fetchmany(trino_connection):
1938+
"""Verify that fetchmany() works correctly with lazily loaded spooled segments."""
1939+
if trino_connection._client_session.encoding is None:
1940+
pytest.skip("spooling requires an encoding")
1941+
1942+
cur = trino_connection.cursor()
1943+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1944+
1945+
batch = cur.fetchmany(100)
1946+
assert len(batch) == 100
1947+
1948+
total = len(batch)
1949+
while True:
1950+
batch = cur.fetchmany(1000)
1951+
if not batch:
1952+
break
1953+
total += len(batch)
1954+
assert total == 60175, f"Expected 60175 rows, got {total}"
1955+
1956+
1957+
@pytest.mark.skipif(
1958+
trino_version() <= 466,
1959+
reason="spooling protocol was introduced in version 466"
1960+
)
1961+
def test_spooled_segments_iterator_protocol(trino_connection):
1962+
"""Verify that cursor iteration works correctly with spooled segments."""
1963+
if trino_connection._client_session.encoding is None:
1964+
pytest.skip("spooling requires an encoding")
1965+
1966+
cur = trino_connection.cursor()
1967+
cur.execute("SELECT * FROM tpch.tiny.lineitem")
1968+
1969+
count = 0
1970+
for row in cur:
1971+
count += 1
1972+
assert isinstance(row, list)
1973+
assert count == 60175, f"Expected 60175 rows, got {count}"
1974+
1975+
19031976
def get_cursor(legacy_prepared_statements, run_trino):
19041977
host, port = run_trino
19051978

trino/client.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -904,9 +904,30 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
904904
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
905905
self._result = TrinoResult(self, rows)
906906

907-
# Execute should block until at least one row is received or query is finished or cancelled
908-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
909-
self._result.rows += self.fetch()
907+
"""
908+
Execute should block until at least one row is received or query is finished or cancelled.
909+
910+
In the direct protocol, rows is a list that we can check for length. Usually the first response contains just
911+
stats but no rows, so we need to keep fetching the next URI until we get some rows or the query is finished or cancelled.
912+
913+
In the spooling protocol, rows starts as an empty list and fetch eventually returns the rows as a lazy iterator,
914+
we can't check the length of an iterator easily without peeking.
915+
916+
So, once rows is a non-empty list or an iterator, we stop blocking and return the result to the caller to consume.
917+
"""
918+
while not self.finished and not self.cancelled:
919+
if isinstance(self._result.rows, list) and len(self._result.rows) == 0:
920+
new_rows = self.fetch()
921+
if isinstance(new_rows, list):
922+
# Direct protocol - append rows to a list
923+
self._result.rows += new_rows
924+
else:
925+
# Spooling protocol - replace rows with a lazy iterator
926+
self._result.rows = new_rows
927+
break
928+
else:
929+
break
930+
910931
return self._result
911932

912933
def _update_state(self, status):
@@ -920,7 +941,7 @@ def _update_state(self, status):
920941
if status.columns:
921942
self._columns = status.columns
922943

923-
def fetch(self) -> List[Union[List[Any]], Any]:
944+
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
924945
"""Continue fetching data for the current query_id"""
925946
try:
926947
response = self._request.get(self._request.next_uri)
@@ -941,7 +962,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
941962
spooled = self._to_segments(rows)
942963
if self._fetch_mode == "segments":
943964
return spooled
944-
return list(SegmentIterator(spooled, self._row_mapper))
965+
# Return iterator directly, do NOT materialize with list()
966+
return SegmentIterator(spooled, self._row_mapper)
945967
elif isinstance(status.rows, list):
946968
return self._row_mapper.map(rows)
947969
else:

0 commit comments

Comments
 (0)