Skip to content

Commit fbe61ea

Browse files
authored
[Python] Python] Bound the memory used for fnapi outbound data messages and receiving messages. (#38407)
* [Python] Bound the memory used for fnapi outbound data messages. Previously an unbounded queue was used for pending data outputs to be sent over the fnapi to the runner. If outputs were being generated faster than the runner was consuming them, this would lead to memory growth and possible OOMs. This PR introduces a byte-limited queue data structure that is used instead to limit the # of bytes in the queue. This was preferred to just using a queue with max number of elements because the size of elements can vary greatly. For batch pipelines they are likely large while for stremaing pipelines there may be more small outputs. * monotonic and not shutdown restriction * change to not subclass queue.Queue and to be fair * fixups * add missing pxd file, fixup test * use 64-bit for size in pxd * address comments * add condition caching
1 parent 4d683c0 commit fbe61ea

5 files changed

Lines changed: 547 additions & 29 deletions

File tree

sdks/python/apache_beam/runners/worker/data_plane.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from apache_beam.portability.api import beam_fn_api_pb2_grpc
5050
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
5151
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
52+
from apache_beam.utils.byte_limited_queue import ByteLimitedQueue
5253

5354
if TYPE_CHECKING:
5455
import apache_beam.coders.slow_stream
@@ -455,11 +456,14 @@ class _GrpcDataChannel(DataChannel):
455456

456457
def __init__(self, data_buffer_time_limit_ms=0):
457458
# type: (int) -> None
459+
458460
self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
459-
self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers]
461+
self._to_send = ByteLimitedQueue(
462+
maxsize=10000,
463+
maxbytes=100 << 20) # type: ByteLimitedQueue[DataOrTimers]
460464
self._received = collections.defaultdict(
461-
lambda: queue.Queue(maxsize=5)
462-
) # type: DefaultDict[str, queue.Queue[DataOrTimers]]
465+
lambda: ByteLimitedQueue(maxsize=5, maxbytes=100 << 20)
466+
) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]]
463467

464468
# Keep a cache of completed instructions. Data for completed instructions
465469
# must be discarded. See input_elements() and _clean_receiving_queue().
@@ -474,15 +478,15 @@ def __init__(self, data_buffer_time_limit_ms=0):
474478

475479
def close(self):
476480
# type: () -> None
477-
self._to_send.put(self._WRITES_FINISHED)
481+
self._to_send.put(self._WRITES_FINISHED, 0)
478482
self._closed = True
479483

480484
def wait(self, timeout=None):
481485
# type: (Optional[int]) -> None
482486
self._reads_finished.wait(timeout)
483487

484488
def _receiving_queue(self, instruction_id):
485-
# type: (str) -> Optional[queue.Queue[DataOrTimers]]
489+
# type: (str) -> Optional[ByteLimitedQueue[DataOrTimers]]
486490

487491
"""
488492
Gets or creates queue for a instruction_id. Or, returns None if the
@@ -585,21 +589,19 @@ def output_stream(self, instruction_id, transform_id):
585589
def add_to_send_queue(data):
586590
# type: (bytes) -> None
587591
if data:
588-
self._to_send.put(
589-
beam_fn_api_pb2.Elements.Data(
590-
instruction_id=instruction_id,
591-
transform_id=transform_id,
592-
data=data))
592+
elem = beam_fn_api_pb2.Elements.Data(
593+
instruction_id=instruction_id, transform_id=transform_id, data=data)
594+
self._to_send.put(elem, self._get_element_size_bytes(elem))
593595

594596
def close_callback(data):
595597
# type: (bytes) -> None
596598
add_to_send_queue(data)
597599
# End of stream marker.
598-
self._to_send.put(
599-
beam_fn_api_pb2.Elements.Data(
600-
instruction_id=instruction_id,
601-
transform_id=transform_id,
602-
is_last=True))
600+
elem = beam_fn_api_pb2.Elements.Data(
601+
instruction_id=instruction_id,
602+
transform_id=transform_id,
603+
is_last=True)
604+
self._to_send.put(elem, self._get_element_size_bytes(elem))
603605

604606
return ClosableOutputStream.create(
605607
close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)
@@ -614,23 +616,23 @@ def output_timer_stream(
614616
def add_to_send_queue(timer):
615617
# type: (bytes) -> None
616618
if timer:
617-
self._to_send.put(
618-
beam_fn_api_pb2.Elements.Timers(
619-
instruction_id=instruction_id,
620-
transform_id=transform_id,
621-
timer_family_id=timer_family_id,
622-
timers=timer,
623-
is_last=False))
619+
elem = beam_fn_api_pb2.Elements.Timers(
620+
instruction_id=instruction_id,
621+
transform_id=transform_id,
622+
timer_family_id=timer_family_id,
623+
timers=timer,
624+
is_last=False)
625+
self._to_send.put(elem, self._get_element_size_bytes(elem))
624626

625627
def close_callback(timer):
626628
# type: (bytes) -> None
627629
add_to_send_queue(timer)
628-
self._to_send.put(
629-
beam_fn_api_pb2.Elements.Timers(
630-
instruction_id=instruction_id,
631-
transform_id=transform_id,
632-
timer_family_id=timer_family_id,
633-
is_last=True))
630+
elem = beam_fn_api_pb2.Elements.Timers(
631+
instruction_id=instruction_id,
632+
transform_id=transform_id,
633+
timer_family_id=timer_family_id,
634+
is_last=True)
635+
self._to_send.put(elem, self._get_element_size_bytes(elem))
634636

635637
return ClosableOutputStream.create(
636638
close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)
@@ -665,6 +667,15 @@ def _write_outputs(self):
665667
raise ValueError('Unexpected output element type %s' % type(stream))
666668
yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream)
667669

670+
def _get_element_size_bytes(self, element):
671+
# type: (Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> int
672+
if isinstance(element, beam_fn_api_pb2.Elements.Data):
673+
return len(element.data)
674+
elif isinstance(element, beam_fn_api_pb2.Elements.Timers):
675+
return len(element.timers)
676+
else:
677+
return 0
678+
668679
def _read_inputs(self, elements_iterator):
669680
# type: (Iterable[beam_fn_api_pb2.Elements]) -> None
670681

@@ -691,7 +702,8 @@ def _put_queue(instruction_id, element):
691702
next_discard_log_time = current_time + 10
692703
return
693704
try:
694-
input_queue.put(element, timeout=1)
705+
input_queue.put(
706+
element, self._get_element_size_bytes(element), timeout=1)
695707
return
696708
except queue.Full:
697709
current_time = time.time()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# cython: overflowcheck=True
19+
20+
cdef class ByteLimitedQueue(object):
21+
cdef readonly Py_ssize_t max_elements
22+
cdef readonly Py_ssize_t max_bytes
23+
cdef readonly Py_ssize_t _byte_size
24+
cdef readonly object _mutex
25+
cdef readonly object _not_empty
26+
cdef readonly object _waiting_writers
27+
cdef readonly list _condition_pool
28+
cdef readonly object _queue
29+
cdef readonly Py_ssize_t _blocked_bytes
30+
31+
cpdef bint _can_fit(self, Py_ssize_t item_bytes) except -1
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A thread-safe queue that limits capacity by total byte size."""
19+
20+
import collections
21+
import queue
22+
import threading
23+
import time
24+
import types
25+
26+
27+
class ByteLimitedQueue(object):
28+
"""A fair queue that limits by both element count and total byte size.
29+
30+
A single element is allowed to exceed the maxbytes to avoid deadlock.
31+
"""
32+
__class_getitem__ = classmethod(types.GenericAlias)
33+
34+
def __init__(
35+
self,
36+
maxsize=0, # type: int
37+
maxbytes=0, # type: int
38+
):
39+
# type: (...) -> None
40+
41+
"""Initializes a ByteLimitedQueue.
42+
43+
Args:
44+
maxsize: The maximum number of items allowed in the queue. If 0 or
45+
negative, there is no limit on the number of elements.
46+
maxbytes: The maximum accumulated bytes allowed in the queue. If 0 or
47+
negative, there is no limit on the total bytes of the elements.
48+
"""
49+
self.max_elements = maxsize
50+
self.max_bytes = maxbytes
51+
52+
self._byte_size = 0
53+
self._blocked_bytes = 0
54+
self._mutex = threading.Lock()
55+
self._not_empty = threading.Condition(self._mutex)
56+
57+
self._waiting_writers = collections.deque()
58+
self._condition_pool = []
59+
self._queue = collections.deque()
60+
61+
def put(self, item, item_bytes, *, block=True, timeout=None):
62+
"""Put an item into the queue.
63+
64+
If the queue is full, block until a free slot is available, unless `block`
65+
is false or a timeout occurs.
66+
67+
Args:
68+
item: The item to put into the queue.
69+
item_bytes: The size of the item.
70+
block: If True, block until space is available. If False, raise queue.Full
71+
immediately if the queue is full.
72+
timeout: If block is True, wait for at most `timeout` seconds. If None,
73+
block indefinitely.
74+
75+
Raises:
76+
ValueError: If timeout or item_bytes is negative.
77+
queue.Full: If the queue is full and block is False or the timeout occurs.
78+
"""
79+
if timeout is not None and timeout < 0:
80+
raise ValueError("'timeout' must be a non-negative number")
81+
if item_bytes < 0:
82+
raise ValueError("'item_bytes' must be a non-negative number")
83+
84+
with self._mutex:
85+
if not self._waiting_writers and self._can_fit(item_bytes):
86+
self._queue.append((item, item_bytes))
87+
self._byte_size += item_bytes
88+
self._not_empty.notify()
89+
return
90+
91+
if not block:
92+
raise queue.Full
93+
94+
# Reuse or create a condition
95+
my_cond = (
96+
self._condition_pool.pop()
97+
if self._condition_pool else threading.Condition(self._mutex))
98+
99+
endtime = time.monotonic() + timeout if timeout is not None else None
100+
101+
try:
102+
self._blocked_bytes += item_bytes
103+
self._waiting_writers.append(my_cond)
104+
while True:
105+
if timeout is None:
106+
my_cond.wait()
107+
else:
108+
remaining = endtime - time.monotonic()
109+
if remaining <= 0.0:
110+
raise queue.Full
111+
my_cond.wait(remaining)
112+
113+
if self._waiting_writers[0] is my_cond and self._can_fit(item_bytes):
114+
break
115+
116+
self._queue.append((item, item_bytes))
117+
self._byte_size += item_bytes
118+
self._not_empty.notify()
119+
finally:
120+
self._blocked_bytes -= item_bytes
121+
if self._waiting_writers:
122+
was_first = (self._waiting_writers[0] is my_cond)
123+
if was_first:
124+
self._waiting_writers.popleft()
125+
else:
126+
self._waiting_writers.remove(my_cond)
127+
self._condition_pool.append(my_cond)
128+
if was_first and self._waiting_writers:
129+
self._waiting_writers[0].notify()
130+
131+
def get(self, *, block=True, timeout=None):
132+
"""Remove and return an item from the queue.
133+
134+
If the queue is empty, block until an item is available, unless `block`
135+
is false or a timeout occurs.
136+
137+
Args:
138+
block: If True, block until an item is available. If False, raise
139+
queue.Empty immediately if the queue is empty.
140+
timeout: If block is True, wait for at most `timeout` seconds. If None,
141+
block indefinitely.
142+
143+
Returns:
144+
The item removed from the queue.
145+
146+
Raises:
147+
ValueError: If timeout is negative.
148+
queue.Empty: If the queue is empty and block is False or the timeout
149+
occurs.
150+
"""
151+
if timeout is not None and timeout < 0:
152+
raise ValueError("'timeout' must be a non-negative number")
153+
154+
with self._mutex:
155+
if not block:
156+
if not self._queue:
157+
raise queue.Empty
158+
elif timeout is None:
159+
while not self._queue:
160+
self._not_empty.wait()
161+
else:
162+
endtime = time.monotonic() + timeout
163+
while not self._queue:
164+
remaining = endtime - time.monotonic()
165+
if remaining <= 0.0:
166+
raise queue.Empty
167+
self._not_empty.wait(remaining)
168+
169+
item, item_bytes = self._queue.popleft()
170+
self._byte_size -= item_bytes
171+
172+
if self._waiting_writers:
173+
self._waiting_writers[0].notify()
174+
175+
return item
176+
177+
def get_nowait(self):
178+
"""Remove and return an item from the queue without blocking."""
179+
return self.get(block=False)
180+
181+
def byte_size(self):
182+
"""Return the total byte size of elements in the queue."""
183+
with self._mutex:
184+
return self._byte_size
185+
186+
def blocked_byte_size(self):
187+
"""Return the total byte size of elements in the queue that are blocked."""
188+
with self._mutex:
189+
return self._blocked_bytes
190+
191+
def qsize(self):
192+
"""Return the total number of elements in the queue."""
193+
with self._mutex:
194+
return len(self._queue)
195+
196+
def _can_fit(self, item_bytes):
197+
# Always let in a single element, regardless of size.
198+
if not self._queue:
199+
return True
200+
if self.max_elements > 0 and len(self._queue) >= self.max_elements:
201+
return False
202+
if self.max_bytes > 0 and self._byte_size + item_bytes > self.max_bytes:
203+
return False
204+
return True

0 commit comments

Comments
 (0)