@@ -86,6 +86,7 @@ def get_checkpoint_mark_coder(self):
8686import dataclasses
8787import logging
8888import threading
89+ import time
8990from typing import Any
9091from typing import Iterable
9192from typing import Optional
@@ -123,6 +124,8 @@ def get_checkpoint_mark_coder(self):
123124
124125_DEFAULT_POLL_INTERVAL_SECONDS = 1.0
125126_DEFAULT_DESIRED_NUM_SPLITS = 20
127+ _DEFAULT_MAX_RECORDS_PER_BUNDLE = 10000
128+ _DEFAULT_MAX_READ_TIME_SECONDS = 10.0
126129
127130# ------------------------------------------------------------------------------
128131# Public abstract base classes.
@@ -613,8 +616,14 @@ class _ReadFromUnboundedSourceDoFn(core.DoFn):
613616 Module-level so stdlib pickle and cloudpickle can serialise the DoFn. The
614617 restriction provider is the module-level :data:`_PROVIDER` singleton.
615618 """
616- def __init__ (self , poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ):
619+ def __init__ (
620+ self ,
621+ poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ,
622+ max_records_per_bundle : int = _DEFAULT_MAX_RECORDS_PER_BUNDLE ,
623+ max_read_time_seconds : float = _DEFAULT_MAX_READ_TIME_SECONDS ):
617624 self ._poll_interval = poll_interval
625+ self ._max_records_per_bundle = max_records_per_bundle
626+ self ._max_read_time_seconds = max_read_time_seconds
618627
619628 @core .DoFn .unbounded_per_element ()
620629 def process (
@@ -628,6 +637,8 @@ def process(
628637 # kwarg-injected ones (tracker, watermark estimator).
629638 assert isinstance (tracker , sdf_utils .RestrictionTrackerView )
630639 initial = tracker .current_restriction ()
640+ records_emitted = 0
641+ read_deadline = time .monotonic () + self ._max_read_time_seconds
631642 try :
632643 while True :
633644 holder = [None ]
@@ -639,10 +650,8 @@ def process(
639650 break
640651 record = holder [0 ]
641652 if record is _NO_DATA :
642- # No data now: advance the watermark and self-checkpoint with an
643- # explicit delay. defer_remainder(None) resumes ASAP and would
644- # busy-loop an idle source, so (unlike Java's runner-applied backoff)
645- # we set the resume delay here.
653+ # No data now: advance the watermark and self-checkpoint with the
654+ # poll delay so an idle source backs off before resuming.
646655 _set_watermark_if_greater (
647656 watermark_estimator , tracker .current_restriction ().watermark )
648657 tracker .defer_remainder (Duration (seconds = self ._poll_interval ))
@@ -652,6 +661,15 @@ def process(
652661 value , record_timestamp , source_watermark = record
653662 _set_watermark_if_greater (watermark_estimator , source_watermark )
654663 yield TimestampedValue (value , record_timestamp )
664+ records_emitted += 1
665+ # A busy reader never reaches the EOF or no-data branch, so bound the
666+ # bundle by record count and elapsed time. The checkpoint lets the
667+ # runner commit progress and run finalization. Resume with no delay
668+ # because data is still available.
669+ if (records_emitted >= self ._max_records_per_bundle or
670+ time .monotonic () >= read_deadline ):
671+ tracker .defer_remainder ()
672+ break
655673 finally :
656674 current = tracker .current_restriction ()
657675 try :
@@ -695,17 +713,23 @@ class ReadFromUnboundedSource(PTransform):
695713 p | beam.io.Read(MyUnboundedSource())
696714
697715 ``poll_interval`` is the resume delay used when the reader has no data, which
698- bounds how often an idle source is polled.
716+ bounds how often an idle source is polled. ``max_records_per_bundle`` and
717+ ``max_read_time_seconds`` cap a busy reader's bundle: whichever limit is
718+ reached first makes it self-checkpoint and resume without delay.
699719 """
700720 def __init__ (
701721 self ,
702722 source : UnboundedSource ,
703- poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ):
723+ poll_interval : float = _DEFAULT_POLL_INTERVAL_SECONDS ,
724+ max_records_per_bundle : int = _DEFAULT_MAX_RECORDS_PER_BUNDLE ,
725+ max_read_time_seconds : float = _DEFAULT_MAX_READ_TIME_SECONDS ):
704726 if not isinstance (source , UnboundedSource ):
705727 raise TypeError ('source must be an UnboundedSource, got %r' % (source , ))
706728 super ().__init__ ()
707729 self ._source = source
708730 self ._poll_interval = poll_interval
731+ self ._max_records_per_bundle = max_records_per_bundle
732+ self ._max_read_time_seconds = max_read_time_seconds
709733
710734 def expand (self , pbegin ):
711735 source = self ._source
@@ -717,7 +741,10 @@ def expand(self, pbegin):
717741 pbegin
718742 | 'Create' >> core .Create ([source ])
719743 | 'ReadUnbounded' >> core .ParDo (
720- _ReadFromUnboundedSourceDoFn (self ._poll_interval )))
744+ _ReadFromUnboundedSourceDoFn (
745+ self ._poll_interval ,
746+ self ._max_records_per_bundle ,
747+ self ._max_read_time_seconds )))
721748 # Surface an element type only when the global registry already maps it to
722749 # an equivalent coder. Avoid mutating ``coders.registry`` for a
723750 # parameterized coder whose instance state would be lost.
0 commit comments