@@ -39,7 +39,7 @@ def __init__(self, position):
3939
4040 def finalize_checkpoint(self):
4141 # Commit/acknowledge records up to ``position`` upstream, e.g. ack the
42- # consumed messages on a queue. Must be idempotent.
42+ # consumed messages on a queue. Must be idempotent (ack by position).
4343 ...
4444
4545 class MyReader(UnboundedReader):
@@ -83,11 +83,12 @@ def get_checkpoint_mark_coder(self):
8383 p | beam.io.Read(MySource()) | beam.Map(print)
8484"""
8585
86- # pytype: skip-file
87-
8886import dataclasses
8987import logging
88+ import threading
89+ import time
9090from typing import Any
91+ from typing import Callable
9192from typing import Iterable
9293from typing import Optional
9394
@@ -124,6 +125,8 @@ def get_checkpoint_mark_coder(self):
124125
125126_DEFAULT_POLL_INTERVAL_SECONDS = 1.0
126127_DEFAULT_DESIRED_NUM_SPLITS = 20
128+ _DEFAULT_MAX_RECORDS_PER_BUNDLE = 10000
129+ _DEFAULT_MAX_READ_TIME_SECONDS = 10.0
127130
128131# ------------------------------------------------------------------------------
129132# Public abstract base classes.
@@ -143,9 +146,11 @@ def finalize_checkpoint(self) -> None:
143146 Override to acknowledge/commit upstream (for example, ack the consumed
144147 messages on a queue). The default is a no-op.
145148
146- Implementations must be idempotent: the runner may retry the callback on
147- the same mark, and each bundle produces a fresh mark covering the records
148- read so far. An exception raised here is logged but not retried.
149+ The runner calls this at most once for a committed checkpoint mark.
150+ Finalization is best effort; a mark may never be finalized. An exception
151+ raised here is logged. On bundle retry an uncommitted mark may be re-cut
152+ over an overlapping span, so this method must be idempotent (acknowledge by
153+ absolute position).
149154 """
150155 pass
151156
@@ -466,6 +471,8 @@ def _try_split_inner(self, fraction_of_remainder):
466471 if self ._reader is None or not self ._started or self ._restriction .is_done :
467472 return None
468473 checkpoint = self ._reader .get_checkpoint_mark ()
474+ # The residual watermark is advisory; the SDF watermark estimator state is
475+ # the authoritative cross-bundle watermark.
469476 watermark = self ._reader .get_watermark ()
470477 # Keep the two channels independent: the primary carries only the finalize
471478 # hook, the residual only the resume state.
@@ -588,12 +595,46 @@ def truncate(self, element, restriction):
588595_PROVIDER = _UnboundedSourceRestrictionProvider ()
589596
590597
598+ class _FinalizeCheckpointOnce (object ):
599+ def __init__ (self , checkpoint_mark : CheckpointMark ):
600+ self ._checkpoint_mark = checkpoint_mark
601+ # The lock makes finalization run at most once even if a runner invokes the
602+ # callback repeatedly or concurrently; a failed attempt is not retried.
603+ self ._lock = threading .Lock ()
604+ self ._finalized = False
605+
606+ def __call__ (self ) -> None :
607+ with self ._lock :
608+ if self ._finalized :
609+ return
610+ self ._finalized = True
611+ # Finalization is best effort: log and swallow so a failing user override
612+ # does not fail the bundle (matches CheckpointMark.finalize_checkpoint).
613+ try :
614+ self ._checkpoint_mark .finalize_checkpoint ()
615+ except Exception : # pylint: disable=broad-except
616+ _LOGGER .warning (
617+ 'Error finalizing UnboundedSource checkpoint mark.' , exc_info = True )
618+
619+
591620class _ReadFromUnboundedSourceDoFn (core .DoFn ):
592621 """SDF wrapper driving an :class:`UnboundedReader` for one restriction.
593622
594623 Module-level so stdlib pickle and cloudpickle can serialise the DoFn. The
595624 restriction provider is the module-level :data:`_PROVIDER` singleton.
596625 """
626+ def __init__ (
627+ self ,
628+ poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ,
629+ max_records_per_bundle : int = _DEFAULT_MAX_RECORDS_PER_BUNDLE ,
630+ max_read_time_seconds : float = _DEFAULT_MAX_READ_TIME_SECONDS ,
631+ _now : Optional [Callable [[], float ]] = None ):
632+ self ._poll_interval = poll_interval
633+ self ._max_records_per_bundle = max_records_per_bundle
634+ self ._max_read_time_seconds = max_read_time_seconds
635+ # Monotonic clock seam; tests inject a deterministic clock.
636+ self ._now = _now
637+
597638 @core .DoFn .unbounded_per_element ()
598639 def process (
599640 self ,
@@ -606,6 +647,10 @@ def process(
606647 # kwarg-injected ones (tracker, watermark estimator).
607648 assert isinstance (tracker , sdf_utils .RestrictionTrackerView )
608649 initial = tracker .current_restriction ()
650+ now = self ._now or time .monotonic
651+ records_emitted = 0
652+ # Armed on the first emitted record so reader startup is excluded.
653+ read_deadline = None # type: Optional[float]
609654 try :
610655 while True :
611656 holder = [None ]
@@ -617,25 +662,38 @@ def process(
617662 break
618663 record = holder [0 ]
619664 if record is _NO_DATA :
620- # No data is available now: advance the watermark and self-checkpoint
621- # so the runner reschedules us after a short delay .
665+ # No data now: advance the watermark and self-checkpoint with the
666+ # poll delay so an idle source backs off before resuming .
622667 _set_watermark_if_greater (
623668 watermark_estimator , tracker .current_restriction ().watermark )
624- tracker .defer_remainder (
625- Duration (seconds = _DEFAULT_POLL_INTERVAL_SECONDS ))
669+ tracker .defer_remainder (Duration (seconds = self ._poll_interval ))
626670 break
627671 # The third tuple field is the source watermark. The record timestamp
628672 # remains the output event time.
629673 value , record_timestamp , source_watermark = record
630674 _set_watermark_if_greater (watermark_estimator , source_watermark )
631675 yield TimestampedValue (value , record_timestamp )
676+ records_emitted += 1
677+ if read_deadline is None :
678+ read_deadline = now () + self ._max_read_time_seconds
679+ # A busy reader never hits the EOF or no-data branch. Bound the bundle
680+ # by record count and elapsed time so the runner commits the checkpoint
681+ # and runs finalization, then resume with no delay. The deadline is
682+ # checked between records; a reader that blocks inside advance() can
683+ # overrun it, so the record cap is the hard backstop.
684+ reached_record_cap = records_emitted >= self ._max_records_per_bundle
685+ if reached_record_cap or now () >= read_deadline :
686+ tracker .defer_remainder ()
687+ break
632688 finally :
633689 current = tracker .current_restriction ()
634690 try :
635691 # Register finalization only when a checkpoint was cut this bundle.
692+ # The SDK bundle finalizer applies no deadline, so finalization is
693+ # unbounded best effort.
636694 finalize_mark = current .finalization_checkpoint_mark
637695 if current is not initial and finalize_mark is not None :
638- bundle_finalizer .register (finalize_mark . finalize_checkpoint )
696+ bundle_finalizer .register (_FinalizeCheckpointOnce ( finalize_mark ) )
639697 finally :
640698 # Release the reader on downstream-yield errors.
641699 inner_tracker = tracker
@@ -670,12 +728,44 @@ class ReadFromUnboundedSource(PTransform):
670728 ``UnboundedSource`` here automatically::
671729
672730 p | beam.io.Read(MyUnboundedSource())
731+
732+ Args:
733+ source: the :class:`UnboundedSource` to read.
734+ poll_interval: resume delay in seconds applied when the reader has no data,
735+ which bounds how often an idle source is polled. Must be >= 0.
736+ max_records_per_bundle: a busy reader self-checkpoints after emitting this
737+ many records in one bundle. Must be >= 1. Defaults to 10000.
738+ max_read_time_seconds: a busy reader self-checkpoints this many seconds
739+ after its first record in a bundle. Must be > 0. Defaults to 10.0. The
740+ deadline is checked between records, so a reader blocked in ``advance()``
741+ may overrun it; ``max_records_per_bundle`` is the hard backstop.
742+
743+ The bundle self-checkpoints as soon as either cap is reached.
673744 """
674- def __init__ (self , source : UnboundedSource ):
745+ def __init__ (
746+ self ,
747+ source : UnboundedSource ,
748+ poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ,
749+ max_records_per_bundle : int = _DEFAULT_MAX_RECORDS_PER_BUNDLE ,
750+ max_read_time_seconds : float = _DEFAULT_MAX_READ_TIME_SECONDS ):
675751 if not isinstance (source , UnboundedSource ):
676752 raise TypeError ('source must be an UnboundedSource, got %r' % (source , ))
753+ if max_records_per_bundle < 1 :
754+ raise ValueError (
755+ 'max_records_per_bundle must be >= 1, got %r' %
756+ (max_records_per_bundle , ))
757+ if max_read_time_seconds <= 0 :
758+ raise ValueError (
759+ 'max_read_time_seconds must be > 0, got %r' %
760+ (max_read_time_seconds , ))
761+ if poll_interval < 0 :
762+ raise ValueError (
763+ 'poll_interval must be >= 0, got %r' % (poll_interval , ))
677764 super ().__init__ ()
678765 self ._source = source
766+ self ._poll_interval = poll_interval
767+ self ._max_records_per_bundle = max_records_per_bundle
768+ self ._max_read_time_seconds = max_read_time_seconds
679769
680770 def expand (self , pbegin ):
681771 source = self ._source
@@ -686,7 +776,11 @@ def expand(self, pbegin):
686776 output = (
687777 pbegin
688778 | 'Create' >> core .Create ([source ])
689- | 'ReadUnbounded' >> core .ParDo (_ReadFromUnboundedSourceDoFn ()))
779+ | 'ReadUnbounded' >> core .ParDo (
780+ _ReadFromUnboundedSourceDoFn (
781+ self ._poll_interval ,
782+ self ._max_records_per_bundle ,
783+ self ._max_read_time_seconds )))
690784 # Surface an element type only when the global registry already maps it to
691785 # an equivalent coder. Avoid mutating ``coders.registry`` for a
692786 # parameterized coder whose instance state would be lost.
0 commit comments