Skip to content

Commit 175259b

Browse files
committed
[Python] Bound UnboundedSource read per bundle by records and time
A continuously busy reader stayed in process() forever: the loop only broke at EOF or when data ran out, so a hot source never checkpointed, finalized, or committed its watermark on the Fn API path. Cap the has-data path at max_records_per_bundle records (default 10000) or max_read_time_seconds seconds (default 10), whichever is reached first, then self-checkpoint via defer_remainder() with no resume delay so the runner commits progress and resumes immediately. The caps are checked after each emitted record, so a bundle that has data always emits at least one record before checkpointing. Both caps are configurable on ReadFromUnboundedSource.
1 parent b61da24 commit 175259b

2 files changed

Lines changed: 96 additions & 8 deletions

File tree

sdks/python/apache_beam/io/unbounded_source.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def get_checkpoint_mark_coder(self):
8686
import dataclasses
8787
import logging
8888
import threading
89+
import time
8990
from typing import Any
9091
from typing import Iterable
9192
from typing import Optional
@@ -123,6 +124,8 @@ def get_checkpoint_mark_coder(self):
123124

124125
_DEFAULT_POLL_INTERVAL_SECONDS = 1.0
125126
_DEFAULT_DESIRED_NUM_SPLITS = 20
127+
_DEFAULT_MAX_RECORDS_PER_BUNDLE = 10000
128+
_DEFAULT_MAX_READ_TIME_SECONDS = 10.0
126129

127130
# ------------------------------------------------------------------------------
128131
# Public abstract base classes.
@@ -613,8 +616,14 @@ class _ReadFromUnboundedSourceDoFn(core.DoFn):
613616
Module-level so stdlib pickle and cloudpickle can serialise the DoFn. The
614617
restriction provider is the module-level :data:`_PROVIDER` singleton.
615618
"""
616-
def __init__(self, poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS):
619+
def __init__(
620+
self,
621+
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS,
622+
max_records_per_bundle: int = _DEFAULT_MAX_RECORDS_PER_BUNDLE,
623+
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS):
617624
self._poll_interval = poll_interval
625+
self._max_records_per_bundle = max_records_per_bundle
626+
self._max_read_time_seconds = max_read_time_seconds
618627

619628
@core.DoFn.unbounded_per_element()
620629
def process(
@@ -628,6 +637,8 @@ def process(
628637
# kwarg-injected ones (tracker, watermark estimator).
629638
assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
630639
initial = tracker.current_restriction()
640+
records_emitted = 0
641+
read_deadline = time.monotonic() + self._max_read_time_seconds
631642
try:
632643
while True:
633644
holder = [None]
@@ -639,10 +650,8 @@ def process(
639650
break
640651
record = holder[0]
641652
if record is _NO_DATA:
642-
# No data now: advance the watermark and self-checkpoint with an
643-
# explicit delay. defer_remainder(None) resumes ASAP and would
644-
# busy-loop an idle source, so (unlike Java's runner-applied backoff)
645-
# we set the resume delay here.
653+
# No data now: advance the watermark and self-checkpoint with the
654+
# poll delay so an idle source backs off before resuming.
646655
_set_watermark_if_greater(
647656
watermark_estimator, tracker.current_restriction().watermark)
648657
tracker.defer_remainder(Duration(seconds=self._poll_interval))
@@ -652,6 +661,15 @@ def process(
652661
value, record_timestamp, source_watermark = record
653662
_set_watermark_if_greater(watermark_estimator, source_watermark)
654663
yield TimestampedValue(value, record_timestamp)
664+
records_emitted += 1
665+
# A busy reader never reaches the EOF or no-data branch, so bound the
666+
# bundle by record count and elapsed time. The checkpoint lets the
667+
# runner commit progress and run finalization. Resume with no delay
668+
# because data is still available.
669+
if (records_emitted >= self._max_records_per_bundle or
670+
time.monotonic() >= read_deadline):
671+
tracker.defer_remainder()
672+
break
655673
finally:
656674
current = tracker.current_restriction()
657675
try:
@@ -695,17 +713,23 @@ class ReadFromUnboundedSource(PTransform):
695713
p | beam.io.Read(MyUnboundedSource())
696714
697715
``poll_interval`` is the resume delay used when the reader has no data, which
698-
bounds how often an idle source is polled.
716+
bounds how often an idle source is polled. ``max_records_per_bundle`` and
717+
``max_read_time_seconds`` bound how much a busy reader produces before the
718+
bundle self-checkpoints.
699719
"""
700720
def __init__(
701721
self,
702722
source: UnboundedSource,
703-
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS):
723+
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS,
724+
max_records_per_bundle: int = _DEFAULT_MAX_RECORDS_PER_BUNDLE,
725+
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS):
704726
if not isinstance(source, UnboundedSource):
705727
raise TypeError('source must be an UnboundedSource, got %r' % (source, ))
706728
super().__init__()
707729
self._source = source
708730
self._poll_interval = poll_interval
731+
self._max_records_per_bundle = max_records_per_bundle
732+
self._max_read_time_seconds = max_read_time_seconds
709733

710734
def expand(self, pbegin):
711735
source = self._source
@@ -717,7 +741,10 @@ def expand(self, pbegin):
717741
pbegin
718742
| 'Create' >> core.Create([source])
719743
| 'ReadUnbounded' >> core.ParDo(
720-
_ReadFromUnboundedSourceDoFn(self._poll_interval)))
744+
_ReadFromUnboundedSourceDoFn(
745+
self._poll_interval,
746+
self._max_records_per_bundle,
747+
self._max_read_time_seconds)))
721748
# Surface an element type only when the global registry already maps it to
722749
# an equivalent coder. Avoid mutating ``coders.registry`` for a
723750
# parameterized coder whose instance state would be lost.

sdks/python/apache_beam/io/unbounded_source_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from apache_beam.io.unbounded_source import UnboundedReader
3636
from apache_beam.io.unbounded_source import UnboundedSource
3737
from apache_beam.io.unbounded_source import _FinalizeCheckpointOnce
38+
from apache_beam.io.unbounded_source import _ReadFromUnboundedSourceDoFn
3839
from apache_beam.io.unbounded_source import _set_watermark_if_greater
3940
from apache_beam.io.unbounded_source import _UnboundedSourceRestriction
4041
from apache_beam.io.unbounded_source import _UnboundedSourceRestrictionCoder
@@ -560,6 +561,66 @@ def test_is_bounded_false(self):
560561
self.assertFalse(_new_tracker(UnboundedCountingSource(3)).is_bounded())
561562

562563

564+
class _RecordingBundleFinalizer:
565+
def __init__(self):
566+
self.registered = []
567+
568+
def register(self, callback):
569+
self.registered.append(callback)
570+
571+
572+
class BundleCapTest(unittest.TestCase):
573+
"""A busy reader self-checkpoints once the per-bundle record or time cap is
574+
reached, so the runner can commit progress and run finalization."""
575+
def _run_process(self, dofn, source):
576+
tracker = _UnboundedSourceRestrictionTracker(
577+
_UnboundedSourceRestriction(source=source))
578+
threadsafe = sdf_utils.ThreadsafeRestrictionTracker(tracker)
579+
view = sdf_utils.RestrictionTrackerView(threadsafe)
580+
outputs = list(
581+
dofn.process(
582+
None,
583+
bundle_finalizer=_RecordingBundleFinalizer(),
584+
tracker=view,
585+
watermark_estimator=ManualWatermarkEstimator(None)))
586+
return outputs, tracker, threadsafe
587+
588+
def test_record_cap_checkpoints_busy_source(self):
589+
dofn = _ReadFromUnboundedSourceDoFn(
590+
poll_interval=0, max_records_per_bundle=5, max_read_time_seconds=1e9)
591+
# 1000 records is effectively unbounded against a cap of 5.
592+
outputs, tracker, threadsafe = self._run_process(
593+
dofn, UnboundedCountingSource(1000))
594+
595+
self.assertEqual([tv.value for tv in outputs], [0, 1, 2, 3, 4])
596+
self.assertTrue(tracker.current_restriction().is_done)
597+
self.assertTrue(tracker.check_done())
598+
residual, _delay = threadsafe.deferred_status()
599+
self.assertEqual(residual.checkpoint_mark.last_index, 4)
600+
601+
def test_time_cap_checkpoints_busy_source(self):
602+
dofn = _ReadFromUnboundedSourceDoFn(
603+
poll_interval=0, max_records_per_bundle=10**9, max_read_time_seconds=0)
604+
outputs, tracker, threadsafe = self._run_process(
605+
dofn, UnboundedCountingSource(1000))
606+
607+
# A zero time budget trips the deadline right after the first record.
608+
self.assertEqual([tv.value for tv in outputs], [0])
609+
self.assertTrue(tracker.current_restriction().is_done)
610+
residual, _delay = threadsafe.deferred_status()
611+
self.assertEqual(residual.checkpoint_mark.last_index, 0)
612+
613+
def test_eof_before_cap_finishes_without_residual(self):
614+
dofn = _ReadFromUnboundedSourceDoFn(
615+
poll_interval=0, max_records_per_bundle=100, max_read_time_seconds=1e9)
616+
outputs, tracker, threadsafe = self._run_process(
617+
dofn, UnboundedCountingSource(3))
618+
619+
self.assertEqual([tv.value for tv in outputs], [0, 1, 2])
620+
self.assertTrue(tracker.current_restriction().is_done)
621+
self.assertIsNone(threadsafe.deferred_status())
622+
623+
563624
class WatermarkTest(unittest.TestCase):
564625
def test_set_watermark_is_monotonic(self):
565626
estimator = ManualWatermarkEstimator(None)

0 commit comments

Comments
 (0)