Skip to content

Commit 79add18

Browse files
committed
[Python] Refine UnboundedSource read-bound caps after review
Validate the caps at the transform boundary (max_records_per_bundle >= 1, max_read_time_seconds > 0, poll_interval >= 0). Arm the read deadline on the first emitted record so reader startup is excluded, and add an injectable clock seam for deterministic time-cap tests. Document the between-records time-cap limitation, the finalize idempotency-on-retry contract, and the advisory residual watermark. Expand BundleCapTest with cross-bundle resume, EOF-exactly-at-cap, finalize registration, watermark, drain-truncate, and cap-validation coverage.
1 parent 175259b commit 79add18

2 files changed

Lines changed: 153 additions & 33 deletions

File tree

sdks/python/apache_beam/io/unbounded_source.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def get_checkpoint_mark_coder(self):
8888
import threading
8989
import time
9090
from typing import Any
91+
from typing import Callable
9192
from typing import Iterable
9293
from typing import Optional
9394

@@ -147,7 +148,9 @@ def finalize_checkpoint(self) -> None:
147148
148149
The runner calls this at most once for a committed checkpoint mark.
149150
Finalization is best effort; a mark may never be finalized. An exception
150-
raised here is logged.
151+
raised here is logged. On bundle retry an uncommitted mark may be re-cut
152+
over an overlapping span, so this method must be idempotent (acknowledge by
153+
absolute position).
151154
"""
152155
pass
153156

@@ -468,6 +471,8 @@ def _try_split_inner(self, fraction_of_remainder):
468471
if self._reader is None or not self._started or self._restriction.is_done:
469472
return None
470473
checkpoint = self._reader.get_checkpoint_mark()
474+
# The residual watermark is advisory; the SDF watermark estimator state is
475+
# the authoritative cross-bundle watermark.
471476
watermark = self._reader.get_watermark()
472477
# Keep the two channels independent: the primary carries only the finalize
473478
# hook, the residual only the resume state.
@@ -593,6 +598,8 @@ def truncate(self, element, restriction):
593598
class _FinalizeCheckpointOnce(object):
594599
def __init__(self, checkpoint_mark: CheckpointMark):
595600
self._checkpoint_mark = checkpoint_mark
601+
# The lock keeps finalization idempotent if a runner ever invokes the
602+
# callback more than once.
596603
self._lock = threading.Lock()
597604
self._finalized = False
598605

@@ -620,10 +627,13 @@ def __init__(
620627
self,
621628
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS,
622629
max_records_per_bundle: int = _DEFAULT_MAX_RECORDS_PER_BUNDLE,
623-
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS):
630+
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS,
631+
_now: Optional[Callable[[], float]] = None):
624632
self._poll_interval = poll_interval
625633
self._max_records_per_bundle = max_records_per_bundle
626634
self._max_read_time_seconds = max_read_time_seconds
635+
# Monotonic clock seam; tests inject a deterministic clock.
636+
self._now = _now
627637

628638
@core.DoFn.unbounded_per_element()
629639
def process(
@@ -637,8 +647,10 @@ def process(
637647
# kwarg-injected ones (tracker, watermark estimator).
638648
assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
639649
initial = tracker.current_restriction()
650+
now = self._now or time.monotonic
640651
records_emitted = 0
641-
read_deadline = time.monotonic() + self._max_read_time_seconds
652+
# Armed on the first emitted record so reader startup is excluded.
653+
read_deadline = None # type: Optional[float]
642654
try:
643655
while True:
644656
holder = [None]
@@ -662,18 +674,23 @@ def process(
662674
_set_watermark_if_greater(watermark_estimator, source_watermark)
663675
yield TimestampedValue(value, record_timestamp)
664676
records_emitted += 1
665-
# A busy reader never reaches the EOF or no-data branch, so bound the
666-
# bundle by record count and elapsed time. The checkpoint lets the
667-
# runner commit progress and run finalization. Resume with no delay
668-
# because data is still available.
669-
if (records_emitted >= self._max_records_per_bundle or
670-
time.monotonic() >= read_deadline):
677+
if read_deadline is None:
678+
read_deadline = now() + self._max_read_time_seconds
679+
# A busy reader never hits the EOF or no-data branch. Bound the bundle
680+
# by record count and elapsed time so the runner commits the checkpoint
681+
# and runs finalization, then resume with no delay. The deadline is
682+
# checked between records; a reader that blocks inside advance() can
683+
# overrun it, so the record cap is the hard backstop.
684+
reached_record_cap = records_emitted >= self._max_records_per_bundle
685+
if reached_record_cap or now() >= read_deadline:
671686
tracker.defer_remainder()
672687
break
673688
finally:
674689
current = tracker.current_restriction()
675690
try:
676691
# Register finalization only when a checkpoint was cut this bundle.
692+
# The SDK bundle finalizer applies no deadline, so finalization is
693+
# unbounded best effort.
677694
finalize_mark = current.finalization_checkpoint_mark
678695
if current is not initial and finalize_mark is not None:
679696
bundle_finalizer.register(_FinalizeCheckpointOnce(finalize_mark))
@@ -712,10 +729,18 @@ class ReadFromUnboundedSource(PTransform):
712729
713730
p | beam.io.Read(MyUnboundedSource())
714731
715-
``poll_interval`` is the resume delay used when the reader has no data, which
716-
bounds how often an idle source is polled. ``max_records_per_bundle`` and
717-
``max_read_time_seconds`` bound how much a busy reader produces before the
718-
bundle self-checkpoints.
732+
Args:
733+
source: the :class:`UnboundedSource` to read.
734+
poll_interval: resume delay in seconds applied when the reader has no data,
735+
which bounds how often an idle source is polled. Must be >= 0.
736+
max_records_per_bundle: a busy reader self-checkpoints after emitting this
737+
many records in one bundle. Must be >= 1. Defaults to 10000.
738+
max_read_time_seconds: a busy reader self-checkpoints after this many
739+
seconds in one bundle. Must be > 0. Defaults to 10.0. The deadline is
740+
checked between records, so a reader that blocks inside ``advance()`` may
741+
overrun it; ``max_records_per_bundle`` is the hard backstop.
742+
743+
The bundle self-checkpoints as soon as either cap is reached.
719744
"""
720745
def __init__(
721746
self,
@@ -725,6 +750,17 @@ def __init__(
725750
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS):
726751
if not isinstance(source, UnboundedSource):
727752
raise TypeError('source must be an UnboundedSource, got %r' % (source, ))
753+
if max_records_per_bundle < 1:
754+
raise ValueError(
755+
'max_records_per_bundle must be >= 1, got %r' %
756+
(max_records_per_bundle, ))
757+
if max_read_time_seconds <= 0:
758+
raise ValueError(
759+
'max_read_time_seconds must be > 0, got %r' %
760+
(max_read_time_seconds, ))
761+
if poll_interval < 0:
762+
raise ValueError(
763+
'poll_interval must be >= 0, got %r' % (poll_interval, ))
728764
super().__init__()
729765
self._source = source
730766
self._poll_interval = poll_interval

sdks/python/apache_beam/io/unbounded_source_test.py

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,13 @@ def split(self, desired_num_splits, options=None):
419419

420420
self.assertEqual(list(provider.split(source, restriction)), [restriction])
421421

422+
def test_truncate_returns_none_for_drain(self):
423+
# On drain the SDF stops emitting; truncate yields no residual.
424+
provider = _UnboundedSourceRestrictionProvider()
425+
source = UnboundedCountingSource(5)
426+
restriction = _UnboundedSourceRestriction(source=source)
427+
self.assertIsNone(provider.truncate(source, restriction))
428+
422429
def test_splittable_source_partitions_into_independent_subsources(self):
423430
# A splittable source fans out into two sub-sources; reading each in
424431
# isolation yields the even and the odd integers, and their union is the
@@ -569,54 +576,122 @@ def register(self, callback):
569576
self.registered.append(callback)
570577

571578

579+
class _ManualClock:
580+
"""A deterministic monotonic clock for the time-cap tests."""
581+
def __init__(self, now=0.0):
582+
self.now = now
583+
584+
def __call__(self):
585+
return self.now
586+
587+
572588
class BundleCapTest(unittest.TestCase):
573589
"""A busy reader self-checkpoints once the per-bundle record or time cap is
574590
reached, so the runner can commit progress and run finalization."""
575-
def _run_process(self, dofn, source):
591+
def _bundle(self, dofn, source, checkpoint=None, estimator=None):
592+
"""Builds the SDF tracker chain and returns the process() generator plus the
593+
tracker, threadsafe tracker, finalizer, and watermark estimator."""
576594
tracker = _UnboundedSourceRestrictionTracker(
577-
_UnboundedSourceRestriction(source=source))
595+
_UnboundedSourceRestriction(source=source, checkpoint_mark=checkpoint))
578596
threadsafe = sdf_utils.ThreadsafeRestrictionTracker(tracker)
579597
view = sdf_utils.RestrictionTrackerView(threadsafe)
580-
outputs = list(
581-
dofn.process(
582-
None,
583-
bundle_finalizer=_RecordingBundleFinalizer(),
584-
tracker=view,
585-
watermark_estimator=ManualWatermarkEstimator(None)))
586-
return outputs, tracker, threadsafe
598+
finalizer = _RecordingBundleFinalizer()
599+
estimator = estimator or ManualWatermarkEstimator(None)
600+
gen = dofn.process(
601+
None,
602+
bundle_finalizer=finalizer,
603+
tracker=view,
604+
watermark_estimator=estimator)
605+
return gen, tracker, threadsafe, finalizer, estimator
587606

588607
def test_record_cap_checkpoints_busy_source(self):
608+
finalize_log = []
589609
dofn = _ReadFromUnboundedSourceDoFn(
590610
poll_interval=0, max_records_per_bundle=5, max_read_time_seconds=1e9)
591611
# 1000 records is effectively unbounded against a cap of 5.
592-
outputs, tracker, threadsafe = self._run_process(
593-
dofn, UnboundedCountingSource(1000))
612+
gen, tracker, threadsafe, finalizer, estimator = self._bundle(
613+
dofn, UnboundedCountingSource(1000, finalize_log=finalize_log))
614+
outputs = list(gen)
594615

595616
self.assertEqual([tv.value for tv in outputs], [0, 1, 2, 3, 4])
596617
self.assertTrue(tracker.current_restriction().is_done)
597618
self.assertTrue(tracker.check_done())
598-
residual, _delay = threadsafe.deferred_status()
619+
# The estimator holds the last emitted record's source watermark.
620+
self.assertEqual(estimator.current_watermark(), _EVENT_TIME_BASE + 4)
621+
# Residual resumes after the cut and carries no finalize hook.
622+
residual, _ = threadsafe.deferred_status()
599623
self.assertEqual(residual.checkpoint_mark.last_index, 4)
624+
self.assertIsNone(residual.finalization_checkpoint_mark)
625+
# Exactly one finalizer is registered; firing it commits the cut index once.
626+
self.assertEqual(len(finalizer.registered), 1)
627+
finalizer.registered[0]()
628+
finalizer.registered[0]()
629+
self.assertEqual(finalize_log, [4])
600630

601631
def test_time_cap_checkpoints_busy_source(self):
632+
clock = _ManualClock(1000.0)
602633
dofn = _ReadFromUnboundedSourceDoFn(
603-
poll_interval=0, max_records_per_bundle=10**9, max_read_time_seconds=0)
604-
outputs, tracker, threadsafe = self._run_process(
634+
poll_interval=0,
635+
max_records_per_bundle=10**9,
636+
max_read_time_seconds=5.0,
637+
_now=clock)
638+
gen, tracker, threadsafe, _, _ = self._bundle(
605639
dofn, UnboundedCountingSource(1000))
606640

607-
# A zero time budget trips the deadline right after the first record.
608-
self.assertEqual([tv.value for tv in outputs], [0])
641+
# The deadline arms at 1000 + 5 after the first record and is checked
642+
# between records, so records keep flowing until the clock passes it.
643+
self.assertEqual(next(gen).value, 0)
644+
self.assertEqual(next(gen).value, 1)
645+
self.assertEqual(next(gen).value, 2)
646+
clock.now = 1006.0
647+
with self.assertRaises(StopIteration):
648+
next(gen)
649+
609650
self.assertTrue(tracker.current_restriction().is_done)
610-
residual, _delay = threadsafe.deferred_status()
611-
self.assertEqual(residual.checkpoint_mark.last_index, 0)
651+
residual, _ = threadsafe.deferred_status()
652+
self.assertEqual(residual.checkpoint_mark.last_index, 2)
653+
654+
def test_cap_residual_resumes_in_next_bundle(self):
655+
dofn = _ReadFromUnboundedSourceDoFn(
656+
poll_interval=0, max_records_per_bundle=5, max_read_time_seconds=1e9)
657+
source = UnboundedCountingSource(1000)
658+
# Bundle 1 emits 0-4 and cuts a residual at index 4.
659+
gen1, _, threadsafe1, _, _ = self._bundle(dofn, source)
660+
self.assertEqual([tv.value for tv in gen1], [0, 1, 2, 3, 4])
661+
residual1, _ = threadsafe1.deferred_status()
662+
663+
# Bundle 2 rebuilds the reader from the residual and emits 5-9.
664+
gen2, _, threadsafe2, _, _ = self._bundle(
665+
dofn, source, checkpoint=residual1.checkpoint_mark)
666+
self.assertEqual([tv.value for tv in gen2], [5, 6, 7, 8, 9])
667+
residual2, _ = threadsafe2.deferred_status()
668+
self.assertEqual(residual2.checkpoint_mark.last_index, 9)
669+
670+
def test_eof_exactly_at_cap_resumes_then_finishes(self):
671+
dofn = _ReadFromUnboundedSourceDoFn(
672+
poll_interval=0, max_records_per_bundle=5, max_read_time_seconds=1e9)
673+
source = UnboundedCountingSource(5) # exactly cap records
674+
# Bundle 1 hits the cap on the last record before observing EOF.
675+
gen1, t1, threadsafe1, _, _ = self._bundle(dofn, source)
676+
self.assertEqual([tv.value for tv in gen1], [0, 1, 2, 3, 4])
677+
self.assertTrue(t1.current_restriction().is_done)
678+
residual1, _ = threadsafe1.deferred_status()
679+
self.assertEqual(residual1.checkpoint_mark.last_index, 4)
680+
681+
# Bundle 2 resumes at index 5, finds EOF, and finishes with no output.
682+
gen2, t2, threadsafe2, _, _ = self._bundle(
683+
dofn, source, checkpoint=residual1.checkpoint_mark)
684+
self.assertEqual(list(gen2), [])
685+
self.assertTrue(t2.current_restriction().is_done)
686+
self.assertIsNone(threadsafe2.deferred_status())
612687

613688
def test_eof_before_cap_finishes_without_residual(self):
614689
dofn = _ReadFromUnboundedSourceDoFn(
615690
poll_interval=0, max_records_per_bundle=100, max_read_time_seconds=1e9)
616-
outputs, tracker, threadsafe = self._run_process(
691+
gen, tracker, threadsafe, _, _ = self._bundle(
617692
dofn, UnboundedCountingSource(3))
618693

619-
self.assertEqual([tv.value for tv in outputs], [0, 1, 2])
694+
self.assertEqual([tv.value for tv in gen], [0, 1, 2])
620695
self.assertTrue(tracker.current_restriction().is_done)
621696
self.assertIsNone(threadsafe.deferred_status())
622697

@@ -1022,6 +1097,15 @@ def test_non_source_argument_raises(self):
10221097
with self.assertRaises(TypeError):
10231098
ReadFromUnboundedSource('not-a-source') # type: ignore[arg-type]
10241099

1100+
def test_invalid_caps_raise(self):
1101+
source = UnboundedCountingSource(1)
1102+
with self.assertRaises(ValueError):
1103+
ReadFromUnboundedSource(source, max_records_per_bundle=0)
1104+
with self.assertRaises(ValueError):
1105+
ReadFromUnboundedSource(source, max_read_time_seconds=0)
1106+
with self.assertRaises(ValueError):
1107+
ReadFromUnboundedSource(source, poll_interval=-1)
1108+
10251109

10261110
if __name__ == '__main__':
10271111
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)