Skip to content

Commit 99e3cd0

Browse files
committed
[Python] Bound UnboundedSource per-bundle read and address review feedback
Keep UnboundedSource on the expanded composite path: limit the deprecated READ serialization to BoundedSource, add runner API regression coverage, and clarify checkpoint finalization semantics. Bound the per-bundle read so a continuously busy reader cannot stay in process() forever. Cap the has-data path at max_records_per_bundle (10000) or max_read_time_seconds (10), then self-checkpoint via defer_remainder so the runner commits progress and resumes immediately. Validate the caps at the transform boundary (max_records_per_bundle >= 1, max_read_time_seconds > 0, poll_interval >= 0), arm the read deadline on the first emitted record so reader startup is excluded, and add an injectable clock seam for deterministic time-cap tests. Log and swallow finalize_checkpoint() failures to match the CheckpointMark best-effort contract, and lift the no-data resume delay into a configurable ReadFromUnboundedSource poll_interval. Document the between-records time-cap limitation, the finalize idempotency-on-retry contract, and the advisory residual watermark. Expand BundleCapTest with cross-bundle resume, EOF-exactly-at-cap, finalize registration, watermark, drain-truncate, and cap-validation coverage.
1 parent 0750e11 commit 99e3cd0

4 files changed

Lines changed: 293 additions & 180 deletions

File tree

sdks/python/apache_beam/io/iobase.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -993,16 +993,16 @@ def to_runner_api_parameter(
993993
timestamp_attribute=self.source.timestamp_attribute,
994994
with_attributes=self.source.with_attributes,
995995
id_attribute=self.source.id_label))
996-
# Local import to avoid a circular dependency.
997-
from apache_beam.io.unbounded_source import UnboundedSource
998-
if isinstance(self.source, (BoundedSource, UnboundedSource)):
996+
if isinstance(self.source, BoundedSource):
999997
return (
1000998
common_urns.deprecated_primitives.READ.urn,
1001999
beam_runner_api_pb2.ReadPayload(
10021000
source=self.source.to_runner_api(context),
1003-
is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
1004-
if self.source.is_bounded() else
1005-
beam_runner_api_pb2.IsBounded.UNBOUNDED))
1001+
is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED))
1002+
# Local import to avoid a circular dependency.
1003+
from apache_beam.io.unbounded_source import UnboundedSource
1004+
if isinstance(self.source, UnboundedSource):
1005+
return super().to_runner_api_parameter(context)
10061006
elif isinstance(self.source, ptransform.PTransform):
10071007
return self.source.to_runner_api_parameter(context)
10081008
raise NotImplementedError(

sdks/python/apache_beam/io/iobase_test.py

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from apache_beam.io.concat_source_test import RangeSource
3030
from apache_beam.io.iobase import SourceBundle
3131
from apache_beam.options.pipeline_options import DebugOptions
32+
from apache_beam.portability import common_urns
33+
from apache_beam.portability import python_urns
3234
from apache_beam.testing.util import assert_that
3335
from apache_beam.testing.util import equal_to
3436

@@ -232,54 +234,31 @@ def test_read_end_to_end_unbounded(self):
232234

233235
def test_read_unbounded_pcollection_is_unbounded(self):
234236
from apache_beam.io.unbounded_source_test import UnboundedCountingSource
235-
with beam.Pipeline() as p:
236-
out = p | beam.io.Read(UnboundedCountingSource(3))
237-
self.assertFalse(out.is_bounded)
238-
239-
def test_to_runner_api_emits_unbounded_read_payload(self):
240-
"""``Read.to_runner_api_parameter`` must serialize an UnboundedSource as
241-
``READ.urn`` with ``IsBounded.UNBOUNDED`` so the wire format round-trips
242-
consistently for pipeline persistence and cross-runner submission.
243-
"""
244-
from apache_beam.io.unbounded_source_test import UnboundedCountingSource
245-
from apache_beam.portability import common_urns
246-
from apache_beam.portability.api import beam_runner_api_pb2
247-
from apache_beam.runners.pipeline_context import PipelineContext
248-
249-
read = beam.io.Read(UnboundedCountingSource(5))
250-
urn, payload = read.to_runner_api_parameter(PipelineContext())
237+
p = beam.Pipeline()
238+
out = p | beam.io.Read(UnboundedCountingSource(3))
239+
self.assertFalse(out.is_bounded)
251240

252-
self.assertEqual(urn, common_urns.deprecated_primitives.READ.urn)
253-
self.assertIsInstance(payload, beam_runner_api_pb2.ReadPayload)
254-
self.assertEqual(
255-
payload.is_bounded, beam_runner_api_pb2.IsBounded.UNBOUNDED)
256-
# The source field must be populated -- a non-empty FunctionSpec proto.
257-
self.assertTrue(payload.source.urn)
258-
259-
def test_read_unbounded_round_trips_through_runner_api(self):
260-
"""Encode then decode via the runner-API protobuf. The restored
261-
transform must be a ``Read`` wrapping an equivalent UnboundedSource.
262-
"""
263-
from apache_beam.io.unbounded_source import UnboundedSource
241+
def test_read_unbounded_serializes_as_expanded_composite(self):
264242
from apache_beam.io.unbounded_source_test import UnboundedCountingSource
265-
from apache_beam.portability.api import beam_runner_api_pb2
266-
from apache_beam.runners.pipeline_context import PipelineContext
267-
268-
original = beam.io.Read(UnboundedCountingSource(7))
269-
context = PipelineContext()
270-
urn, payload = original.to_runner_api_parameter(context)
271-
272-
transform_proto = beam_runner_api_pb2.PTransform()
273-
transform_proto.spec.urn = urn
274-
restored = iobase.Read.from_runner_api_parameter(
275-
transform_proto, payload, context)
276-
277-
self.assertIsInstance(restored, iobase.Read)
278-
self.assertIsInstance(restored.source, UnboundedSource)
279-
self.assertIsInstance(restored.source, UnboundedCountingSource)
280-
self.assertFalse(restored.source.is_bounded())
281-
# Verify the source's internal state survived pickle round-trip.
282-
self.assertEqual(restored.source._count, 7)
243+
p = beam.Pipeline()
244+
p | 'ReadIt' >> beam.io.Read(UnboundedCountingSource(3))
245+
246+
proto = p.to_runner_api(use_fake_coders=True)
247+
transforms = proto.components.transforms.values()
248+
deprecated_reads = [
249+
transform.unique_name for transform in transforms
250+
if transform.spec.urn == common_urns.deprecated_primitives.READ.urn
251+
]
252+
read_transforms = [
253+
transform for transform in proto.components.transforms.values()
254+
if transform.unique_name == 'ReadIt'
255+
]
256+
257+
self.assertEqual([], deprecated_reads)
258+
self.assertEqual(1, len(read_transforms))
259+
self.assertEqual(
260+
python_urns.GENERIC_COMPOSITE_TRANSFORM, read_transforms[0].spec.urn)
261+
self.assertTrue(read_transforms[0].subtransforms)
283262

284263

285264
if __name__ == '__main__':

sdks/python/apache_beam/io/unbounded_source.py

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, position):
3939
4040
def finalize_checkpoint(self):
4141
# Commit/acknowledge records up to ``position`` upstream, e.g. ack the
42-
# consumed messages on a queue. Must be idempotent.
42+
# consumed messages on a queue.
4343
...
4444
4545
class MyReader(UnboundedReader):
@@ -83,11 +83,12 @@ def get_checkpoint_mark_coder(self):
8383
p | beam.io.Read(MySource()) | beam.Map(print)
8484
"""
8585

86-
# pytype: skip-file
87-
8886
import dataclasses
8987
import logging
88+
import threading
89+
import time
9090
from typing import Any
91+
from typing import Callable
9192
from typing import Iterable
9293
from typing import Optional
9394

@@ -124,6 +125,8 @@ def get_checkpoint_mark_coder(self):
124125

125126
_DEFAULT_POLL_INTERVAL_SECONDS = 1.0
126127
_DEFAULT_DESIRED_NUM_SPLITS = 20
128+
_DEFAULT_MAX_RECORDS_PER_BUNDLE = 10000
129+
_DEFAULT_MAX_READ_TIME_SECONDS = 10.0
127130

128131
# ------------------------------------------------------------------------------
129132
# Public abstract base classes.
@@ -143,9 +146,11 @@ def finalize_checkpoint(self) -> None:
143146
Override to acknowledge/commit upstream (for example, ack the consumed
144147
messages on a queue). The default is a no-op.
145148
146-
Implementations must be idempotent: the runner may retry the callback on
147-
the same mark, and each bundle produces a fresh mark covering the records
148-
read so far. An exception raised here is logged but not retried.
149+
The runner calls this at most once for a committed checkpoint mark.
150+
Finalization is best effort; a mark may never be finalized. An exception
151+
raised here is logged. On bundle retry an uncommitted mark may be re-cut
152+
over an overlapping span, so this method must be idempotent (acknowledge by
153+
absolute position).
149154
"""
150155
pass
151156

@@ -466,6 +471,8 @@ def _try_split_inner(self, fraction_of_remainder):
466471
if self._reader is None or not self._started or self._restriction.is_done:
467472
return None
468473
checkpoint = self._reader.get_checkpoint_mark()
474+
# The residual watermark is advisory; the SDF watermark estimator state is
475+
# the authoritative cross-bundle watermark.
469476
watermark = self._reader.get_watermark()
470477
# Keep the two channels independent: the primary carries only the finalize
471478
# hook, the residual only the resume state.
@@ -588,12 +595,46 @@ def truncate(self, element, restriction):
588595
_PROVIDER = _UnboundedSourceRestrictionProvider()
589596

590597

598+
class _FinalizeCheckpointOnce(object):
599+
def __init__(self, checkpoint_mark: CheckpointMark):
600+
self._checkpoint_mark = checkpoint_mark
601+
# The lock keeps finalization idempotent if a runner ever invokes the
602+
# callback more than once.
603+
self._lock = threading.Lock()
604+
self._finalized = False
605+
606+
def __call__(self) -> None:
607+
with self._lock:
608+
if self._finalized:
609+
return
610+
self._finalized = True
611+
# Finalization is best effort: log and swallow so a failing user override
612+
# does not fail the bundle (matches CheckpointMark.finalize_checkpoint).
613+
try:
614+
self._checkpoint_mark.finalize_checkpoint()
615+
except Exception: # pylint: disable=broad-except
616+
_LOGGER.warning(
617+
'Error finalizing UnboundedSource checkpoint mark.', exc_info=True)
618+
619+
591620
class _ReadFromUnboundedSourceDoFn(core.DoFn):
592621
"""SDF wrapper driving an :class:`UnboundedReader` for one restriction.
593622
594623
Module-level so stdlib pickle and cloudpickle can serialise the DoFn. The
595624
restriction provider is the module-level :data:`_PROVIDER` singleton.
596625
"""
626+
def __init__(
627+
self,
628+
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS,
629+
max_records_per_bundle: int = _DEFAULT_MAX_RECORDS_PER_BUNDLE,
630+
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS,
631+
_now: Optional[Callable[[], float]] = None):
632+
self._poll_interval = poll_interval
633+
self._max_records_per_bundle = max_records_per_bundle
634+
self._max_read_time_seconds = max_read_time_seconds
635+
# Monotonic clock seam; tests inject a deterministic clock.
636+
self._now = _now
637+
597638
@core.DoFn.unbounded_per_element()
598639
def process(
599640
self,
@@ -606,6 +647,10 @@ def process(
606647
# kwarg-injected ones (tracker, watermark estimator).
607648
assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
608649
initial = tracker.current_restriction()
650+
now = self._now or time.monotonic
651+
records_emitted = 0
652+
# Armed on the first emitted record so reader startup is excluded.
653+
read_deadline = None # type: Optional[float]
609654
try:
610655
while True:
611656
holder = [None]
@@ -617,25 +662,38 @@ def process(
617662
break
618663
record = holder[0]
619664
if record is _NO_DATA:
620-
# No data is available now: advance the watermark and self-checkpoint
621-
# so the runner reschedules us after a short delay.
665+
# No data now: advance the watermark and self-checkpoint with the
666+
# poll delay so an idle source backs off before resuming.
622667
_set_watermark_if_greater(
623668
watermark_estimator, tracker.current_restriction().watermark)
624-
tracker.defer_remainder(
625-
Duration(seconds=_DEFAULT_POLL_INTERVAL_SECONDS))
669+
tracker.defer_remainder(Duration(seconds=self._poll_interval))
626670
break
627671
# The third tuple field is the source watermark. The record timestamp
628672
# remains the output event time.
629673
value, record_timestamp, source_watermark = record
630674
_set_watermark_if_greater(watermark_estimator, source_watermark)
631675
yield TimestampedValue(value, record_timestamp)
676+
records_emitted += 1
677+
if read_deadline is None:
678+
read_deadline = now() + self._max_read_time_seconds
679+
# A busy reader never hits the EOF or no-data branch. Bound the bundle
680+
# by record count and elapsed time so the runner commits the checkpoint
681+
# and runs finalization, then resume with no delay. The deadline is
682+
# checked between records; a reader that blocks inside advance() can
683+
# overrun it, so the record cap is the hard backstop.
684+
reached_record_cap = records_emitted >= self._max_records_per_bundle
685+
if reached_record_cap or now() >= read_deadline:
686+
tracker.defer_remainder()
687+
break
632688
finally:
633689
current = tracker.current_restriction()
634690
try:
635691
# Register finalization only when a checkpoint was cut this bundle.
692+
# The SDK bundle finalizer applies no deadline, so finalization is
693+
# unbounded best effort.
636694
finalize_mark = current.finalization_checkpoint_mark
637695
if current is not initial and finalize_mark is not None:
638-
bundle_finalizer.register(finalize_mark.finalize_checkpoint)
696+
bundle_finalizer.register(_FinalizeCheckpointOnce(finalize_mark))
639697
finally:
640698
# Release the reader on downstream-yield errors.
641699
inner_tracker = tracker
@@ -670,12 +728,44 @@ class ReadFromUnboundedSource(PTransform):
670728
``UnboundedSource`` here automatically::
671729
672730
p | beam.io.Read(MyUnboundedSource())
731+
732+
Args:
733+
source: the :class:`UnboundedSource` to read.
734+
poll_interval: resume delay in seconds applied when the reader has no data,
735+
which bounds how often an idle source is polled. Must be >= 0.
736+
max_records_per_bundle: a busy reader self-checkpoints after emitting this
737+
many records in one bundle. Must be >= 1. Defaults to 10000.
738+
max_read_time_seconds: a busy reader self-checkpoints after this many
739+
seconds in one bundle. Must be > 0. Defaults to 10.0. The deadline is
740+
checked between records, so a reader that blocks inside ``advance()`` may
741+
overrun it; ``max_records_per_bundle`` is the hard backstop.
742+
743+
The bundle self-checkpoints as soon as either cap is reached.
673744
"""
674-
def __init__(self, source: UnboundedSource):
745+
def __init__(
746+
self,
747+
source: UnboundedSource,
748+
poll_interval: float = _DEFAULT_POLL_INTERVAL_SECONDS,
749+
max_records_per_bundle: int = _DEFAULT_MAX_RECORDS_PER_BUNDLE,
750+
max_read_time_seconds: float = _DEFAULT_MAX_READ_TIME_SECONDS):
675751
if not isinstance(source, UnboundedSource):
676752
raise TypeError('source must be an UnboundedSource, got %r' % (source, ))
753+
if max_records_per_bundle < 1:
754+
raise ValueError(
755+
'max_records_per_bundle must be >= 1, got %r' %
756+
(max_records_per_bundle, ))
757+
if max_read_time_seconds <= 0:
758+
raise ValueError(
759+
'max_read_time_seconds must be > 0, got %r' %
760+
(max_read_time_seconds, ))
761+
if poll_interval < 0:
762+
raise ValueError(
763+
'poll_interval must be >= 0, got %r' % (poll_interval, ))
677764
super().__init__()
678765
self._source = source
766+
self._poll_interval = poll_interval
767+
self._max_records_per_bundle = max_records_per_bundle
768+
self._max_read_time_seconds = max_read_time_seconds
679769

680770
def expand(self, pbegin):
681771
source = self._source
@@ -686,7 +776,11 @@ def expand(self, pbegin):
686776
output = (
687777
pbegin
688778
| 'Create' >> core.Create([source])
689-
| 'ReadUnbounded' >> core.ParDo(_ReadFromUnboundedSourceDoFn()))
779+
| 'ReadUnbounded' >> core.ParDo(
780+
_ReadFromUnboundedSourceDoFn(
781+
self._poll_interval,
782+
self._max_records_per_bundle,
783+
self._max_read_time_seconds)))
690784
# Surface an element type only when the global registry already maps it to
691785
# an equivalent coder. Avoid mutating ``coders.registry`` for a
692786
# parameterized coder whose instance state would be lost.

0 commit comments

Comments
 (0)