Skip to content

Commit 0753268

Browse files
committed
[Python] Apply Gemini review on UnboundedSource SDF wrapper
* Make the restriction coder stateless and source-dynamic: encode/decode read the source's checkpoint_mark_coder from the restriction itself rather than from the coder's constructor. This removes the source-specific dependency that forced the provider and DoFn to be defined inside ``ReadFromUnboundedSource.expand``.
* Move ``_UnboundedSourceRestrictionProvider`` and ``_ReadFromUnboundedSourceDoFn`` to module level, backed by a stateless ``_PROVIDER`` singleton. Closure-defined DoFns serialise only via cloudpickle; lifting both to module level lets stdlib pickle and any runner that does not use cloudpickle handle the DoFn too. (Forwarding PipelineOptions to ``UnboundedSource.split`` becomes W2 work; today the provider passes ``None``.)
* Register the source-declared output coder against the pipeline-specific ``pbegin.pipeline.coder_registry`` instead of the process-global ``coders.registry`` so the registration does not leak across pipelines running in the same process.
* Use ``RestrictionProgress(completed=, remaining=)`` instead of ``fraction=`` so ``completed_work`` / ``remaining_work`` resolve directly.
* Make the DoFn's ``finally`` tracker-unwrap chain hasattr-driven so a future ``RestrictionTrackerView`` refactor degrades gracefully instead of skipping reader close silently.
* Apply yapf + isort across the four files (CI ``beam_PreCommit_PythonFormatter`` was failing on ``iobase_test.py``).

Tests: 42/42 ``unbounded_source_test.py``, 16/16 ``iobase_test.py``.

Tracking #19137.
1 parent 278ceb0 commit 0753268

3 files changed

Lines changed: 171 additions & 135 deletions

File tree

sdks/python/apache_beam/io/iobase_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@
2121

2222
import unittest
2323

24-
import mock
25-
2624
import apache_beam as beam
27-
from apache_beam.io.concat_source import ConcatSource
28-
from apache_beam.io.concat_source_test import RangeSource
25+
import mock
2926
from apache_beam.io import iobase
3027
from apache_beam.io import range_trackers
28+
from apache_beam.io.concat_source import ConcatSource
29+
from apache_beam.io.concat_source_test import RangeSource
3130
from apache_beam.io.iobase import SourceBundle
3231
from apache_beam.options.pipeline_options import DebugOptions
3332
from apache_beam.testing.util import assert_that
@@ -225,7 +224,6 @@ class UseSdfUnboundedSourcesTests(unittest.TestCase):
225224
iobase.Read.expand(). Uses CountingSource from unbounded_source_test as the
226225
fake finite UnboundedSource (avoids dragging the network in).
227226
"""
228-
229227
def test_read_dispatches_to_read_from_unbounded_source(self):
230228
from apache_beam.io.unbounded_source_test import CountingSource
231229
with mock.patch(

sdks/python/apache_beam/io/unbounded_source.py

Lines changed: 160 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -288,37 +288,54 @@ class _UnboundedSourceRestriction(object):
288288
class _UnboundedSourceRestrictionCoder(Coder):
289289
"""Encodes :class:`_UnboundedSourceRestriction` as a fixed 5-tuple.
290290
291-
Shape: pickled source + nullable resume checkpoint (encoded with the
292-
source's own checkpoint coder if provided, else pickle) + watermark +
293-
done flag + nullable finalization checkpoint (same coder as resume).
291+
Stateless: at encode time the source's own
292+
:meth:`UnboundedSource.get_checkpoint_mark_coder` is looked up from the
293+
restriction; at decode time the source is decoded FIRST and its coder
294+
drives the checkpoint-mark decoding. This avoids passing source-specific
295+
coder state into the coder's constructor, which in turn lets
296+
:class:`_UnboundedSourceRestrictionProvider` and
297+
:class:`_ReadFromUnboundedSourceDoFn` be module-level classes (avoiding
298+
stdlib-pickle gotchas for closure-defined DoFns on some runners).
299+
300+
Wire shape: source_bytes / checkpoint_bytes / watermark / done /
301+
finalization_checkpoint_bytes -- the checkpoint and finalization bytes
302+
are independently encoded with the (source-declared) checkpoint coder
303+
wrapped in :class:`NullableCoder`.
294304
"""
295-
def __init__(self, checkpoint_mark_coder: Optional[Coder] = None):
296-
nullable_checkpoint = NullableCoder(
297-
checkpoint_mark_coder or _MemoizingPickleCoder())
305+
def __init__(self):
306+
self._source_coder = _MemoizingPickleCoder()
307+
self._bytes_coder = coders.BytesCoder()
298308
self._tuple_coder = TupleCoder((
299-
_MemoizingPickleCoder(), # source
300-
nullable_checkpoint, # checkpoint_mark (RESUME state, may be None)
309+
self._bytes_coder, # source (pickled bytes)
310+
self._bytes_coder, # checkpoint_mark (nullable-encoded bytes)
301311
TimestampCoder(), # watermark
302312
BooleanCoder(), # is_done
303-
nullable_checkpoint)) # finalization_checkpoint_mark (commit hook)
313+
self._bytes_coder)) # finalization_checkpoint_mark (nullable-encoded)
314+
315+
def _checkpoint_coder(self, source: UnboundedSource) -> Coder:
316+
return NullableCoder(source.get_checkpoint_mark_coder())
304317

305318
def encode(self, restriction: '_UnboundedSourceRestriction') -> bytes:
319+
source_bytes = self._source_coder.encode(restriction.source)
320+
cp_coder = self._checkpoint_coder(restriction.source)
306321
return self._tuple_coder.encode((
307-
restriction.source,
308-
restriction.checkpoint_mark,
322+
source_bytes,
323+
cp_coder.encode(restriction.checkpoint_mark),
309324
restriction.watermark,
310325
restriction.is_done,
311-
restriction.finalization_checkpoint_mark))
326+
cp_coder.encode(restriction.finalization_checkpoint_mark)))
312327

313328
def decode(self, encoded: bytes) -> '_UnboundedSourceRestriction':
314-
(source, checkpoint_mark, watermark, is_done,
315-
finalization_checkpoint_mark) = self._tuple_coder.decode(encoded)
329+
(source_bytes, checkpoint_bytes, watermark, is_done,
330+
finalization_bytes) = self._tuple_coder.decode(encoded)
331+
source = self._source_coder.decode(source_bytes)
332+
cp_coder = self._checkpoint_coder(source)
316333
return _UnboundedSourceRestriction(
317334
source=source,
318-
checkpoint_mark=checkpoint_mark,
335+
checkpoint_mark=cp_coder.decode(checkpoint_bytes),
319336
watermark=watermark,
320337
is_done=is_done,
321-
finalization_checkpoint_mark=finalization_checkpoint_mark)
338+
finalization_checkpoint_mark=cp_coder.decode(finalization_bytes))
322339

323340
def is_deterministic(self) -> bool:
324341
# The source and checkpoint are pickled, which is not guaranteed
@@ -524,23 +541,31 @@ def check_done(self) -> bool:
524541

525542
def current_progress(self) -> 'iobase.RestrictionProgress':
526543
# Backlog-based progress is out of scope; report a coarse done/not-done
527-
# fraction so the runner has a (recommended) signal.
528-
return iobase.RestrictionProgress(
529-
fraction=1.0 if self._restriction.is_done else 0.0)
544+
# signal so the runner has a (recommended) progress report. Use ``completed`` /
545+
# ``remaining`` rather than ``fraction`` so downstream consumers that
546+
# read ``RestrictionProgress.completed_work`` / ``remaining_work`` see
547+
# the expected values directly.
548+
if self._restriction.is_done:
549+
return iobase.RestrictionProgress(completed=1.0, remaining=0.0)
550+
return iobase.RestrictionProgress(completed=0.0, remaining=1.0)
530551

531552
def is_bounded(self) -> bool:
532553
return False
533554

534555

535556
class _UnboundedSourceRestrictionProvider(core.RestrictionProvider):
536-
"""Wraps an :class:`UnboundedSource` element as an SDF restriction."""
537-
def __init__(
538-
self,
539-
checkpoint_mark_coder: Optional[Coder] = None,
540-
options: Optional[Any] = None):
541-
self._restriction_coder = _UnboundedSourceRestrictionCoder(
542-
checkpoint_mark_coder)
543-
self._options = options
557+
"""Wraps an :class:`UnboundedSource` element as an SDF restriction.
558+
559+
Stateless module-level singleton (see :data:`_PROVIDER`): all
560+
source-specific state (e.g. the source's checkpoint coder) is derived
561+
per-call from the restriction's ``source`` field, which lets
562+
:class:`_ReadFromUnboundedSourceDoFn` live at module level too --
563+
avoiding stdlib-pickle gotchas for closure-defined DoFns. PipelineOptions
564+
forwarded to ``UnboundedSource.split`` are W2 work; today the provider
565+
always passes ``None``.
566+
"""
567+
def __init__(self):
568+
self._restriction_coder = _UnboundedSourceRestrictionCoder()
544569

545570
def initial_restriction(
546571
self, element: UnboundedSource) -> _UnboundedSourceRestriction:
@@ -553,8 +578,7 @@ def initial_restriction(
553578
def create_tracker(
554579
self, restriction: _UnboundedSourceRestriction
555580
) -> _UnboundedSourceRestrictionTracker:
556-
return _UnboundedSourceRestrictionTracker(
557-
restriction, options=self._options)
581+
return _UnboundedSourceRestrictionTracker(restriction)
558582

559583
def split(self, element,
560584
restriction) -> Iterable[_UnboundedSourceRestriction]:
@@ -568,7 +592,7 @@ def split(self, element,
568592
# ``BoundedSourceAsSDF`` behaviour.
569593
try:
570594
split_sources = list(
571-
restriction.source.split(_DEFAULT_DESIRED_NUM_SPLITS, self._options))
595+
restriction.source.split(_DEFAULT_DESIRED_NUM_SPLITS, None))
572596
except Exception: # pylint: disable=broad-except
573597
_LOGGER.warning(
574598
'Exception while splitting UnboundedSource. Source not split.',
@@ -612,6 +636,98 @@ def truncate(self, element, restriction):
612636
return None
613637

614638

639+
# Module-level singleton -- the SDF framework captures this via
640+
# ``RestrictionParam`` at class-def time of :class:`_ReadFromUnboundedSourceDoFn`.
641+
# Stateless by design (see provider docstring).
642+
_PROVIDER = _UnboundedSourceRestrictionProvider()
643+
644+
645+
class _ReadFromUnboundedSourceDoFn(core.DoFn):
646+
"""SDF wrapper driving an :class:`UnboundedReader` for one restriction.
647+
648+
Module-level (not nested inside ``ReadFromUnboundedSource.expand``) so
649+
stdlib ``pickle`` -- not just cloudpickle -- can serialise the DoFn. The
650+
per-pipeline ``poll_interval_seconds`` is passed via ``__init__``; the
651+
restriction provider is the module-level :data:`_PROVIDER` singleton.
652+
"""
653+
def __init__(self, poll_interval_seconds: float):
654+
self._poll_interval_seconds = poll_interval_seconds
655+
656+
@core.DoFn.unbounded_per_element()
657+
def process(
658+
self,
659+
unused_element,
660+
bundle_finalizer=core.DoFn.BundleFinalizerParam,
661+
tracker=core.DoFn.RestrictionParam(_PROVIDER),
662+
watermark_estimator=core.DoFn.WatermarkEstimatorParam(
663+
ManualWatermarkEstimator.default_provider())):
664+
# Parameter order matters: positionally-injected params (the element and
665+
# the bundle finalizer) must precede the kwarg-injected ones (the
666+
# restriction tracker and watermark estimator), which the SDF invoker
667+
# passes by name (runners/common.py _get_arg_placeholders).
668+
assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
669+
initial = tracker.current_restriction()
670+
try:
671+
while True:
672+
holder = [None]
673+
if not tracker.try_claim(holder):
674+
# EOF (restriction is_done==True, watermark already set to MAX in
675+
# the tracker). Mirrors Java Read.java:625 -- advance the
676+
# watermark estimator unconditionally on the terminal path so
677+
# downstream event-time windows can close, otherwise the
678+
# estimator would stay at the last reported watermark.
679+
_set_watermark_if_greater(
680+
watermark_estimator, tracker.current_restriction().watermark)
681+
break
682+
record = holder[0]
683+
if record is _NO_DATA:
684+
# No data right now: advance the watermark and self-checkpoint so
685+
# the runner reschedules us later. Resume via defer_remainder() +
686+
# break -- NOT yield ProcessContinuation (the portable SDF path).
687+
_set_watermark_if_greater(
688+
watermark_estimator, tracker.current_restriction().watermark)
689+
tracker.defer_remainder(Duration(seconds=self._poll_interval_seconds))
690+
break
691+
# Data path: advance the estimator with the SOURCE's reported
692+
# watermark (third tuple slot), NOT the record's event time.
693+
# Mirrors Java Read.java:594. The record's event time is used
694+
# only as the TimestampedValue label so the downstream sees the
695+
# real per-record timestamp.
696+
value, record_timestamp, source_watermark = record
697+
_set_watermark_if_greater(watermark_estimator, source_watermark)
698+
yield TimestampedValue(value, record_timestamp)
699+
finally:
700+
current = tracker.current_restriction()
701+
# Register finalization only when a real checkpoint was cut this
702+
# bundle. Restriction identity (``current is not initial``) mirrors
703+
# Java's reference-equality gate in Read.java. We read the explicit
704+
# finalization channel, NOT ``checkpoint_mark`` (which is the RESUME
705+
# state and may belong to the residual after a split).
706+
finalize_mark = current.finalization_checkpoint_mark
707+
if current is not initial and finalize_mark is not None:
708+
bundle_finalizer.register(finalize_mark.finalize_checkpoint)
709+
# Release the underlying reader on every exit path, including the
710+
# exception path where a downstream yield raised between two
711+
# try_claim calls (reader-method failures are already closed inside
712+
# the tracker). ``RestrictionTrackerView`` does not expose the inner
713+
# tracker, so we unwrap dynamically: each layer is optional, so a
714+
# future wrapper-chain change degrades gracefully rather than
715+
# crashing the bundle.
716+
inner_tracker = tracker
717+
if hasattr(inner_tracker, '_threadsafe_restriction_tracker'):
718+
inner_tracker = inner_tracker._threadsafe_restriction_tracker
719+
if hasattr(inner_tracker, '_restriction_tracker'):
720+
inner_tracker = inner_tracker._restriction_tracker
721+
if isinstance(inner_tracker, _UnboundedSourceRestrictionTracker):
722+
inner_tracker._close_reader_if_open()
723+
elif inner_tracker is not tracker:
724+
_LOGGER.warning(
725+
'UnboundedSource DoFn could not reach an inner tracker of type '
726+
'_UnboundedSourceRestrictionTracker via the SDF wrapper chain; '
727+
'reader close on exception path skipped, relying on GC. Beam '
728+
'SDF wrapper internals may have changed -- file an issue.')
729+
730+
615731
def _set_watermark_if_greater(
616732
watermark_estimator, new_watermark: Timestamp) -> None:
617733
# ManualWatermarkEstimator.set_watermark raises if the watermark regresses, so
@@ -654,117 +770,36 @@ def __init__(
654770

655771
def expand(self, pbegin):
656772
source = self._source
657-
poll_interval_seconds = self._poll_interval_seconds
658773
output_coder = source.default_output_coder()
659-
provider = _UnboundedSourceRestrictionProvider(
660-
checkpoint_mark_coder=source.get_checkpoint_mark_coder())
661-
662-
# The DoFn is defined inside ``expand`` so it can close over the
663-
# source-specific ``provider`` (which holds the source's checkpoint coder)
664-
# and the user-tuned ``poll_interval_seconds``. Lifting it to module level
665-
# would require a stateless provider (losing per-source checkpoint coder
666-
# selection), so this is a deliberate trade-off. Cloudpickle, Beam's
667-
# default, handles closure-defined classes; stdlib ``pickle`` does not.
668-
class _ReadFromUnboundedSourceDoFn(core.DoFn):
669-
"""SDF wrapper driving an :class:`UnboundedReader` for one restriction."""
670-
@core.DoFn.unbounded_per_element()
671-
def process(
672-
self,
673-
unused_element,
674-
bundle_finalizer=core.DoFn.BundleFinalizerParam,
675-
tracker=core.DoFn.RestrictionParam(provider),
676-
watermark_estimator=core.DoFn.WatermarkEstimatorParam(
677-
ManualWatermarkEstimator.default_provider())):
678-
# Parameter order matters: positionally-injected params (the element and
679-
# the bundle finalizer) must precede the kwarg-injected ones (the
680-
# restriction tracker and watermark estimator), which the SDF invoker
681-
# passes by name (runners/common.py _get_arg_placeholders).
682-
assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
683-
initial = tracker.current_restriction()
684-
try:
685-
while True:
686-
holder = [None]
687-
if not tracker.try_claim(holder):
688-
# EOF (restriction is_done==True, watermark already set to MAX in
689-
# the tracker). Mirrors Java Read.java:625 -- advance the
690-
# watermark estimator unconditionally on the terminal path so
691-
# downstream event-time windows can close, otherwise the
692-
# estimator would stay at the last reported watermark.
693-
_set_watermark_if_greater(
694-
watermark_estimator, tracker.current_restriction().watermark)
695-
break
696-
record = holder[0]
697-
if record is _NO_DATA:
698-
# No data right now: advance the watermark and self-checkpoint so
699-
# the runner reschedules us later. Resume via defer_remainder() +
700-
# break -- NOT yield ProcessContinuation (the portable SDF path).
701-
_set_watermark_if_greater(
702-
watermark_estimator, tracker.current_restriction().watermark)
703-
tracker.defer_remainder(Duration(seconds=poll_interval_seconds))
704-
break
705-
# Data path: advance the estimator with the SOURCE's reported
706-
# watermark (third tuple slot), NOT the record's event time.
707-
# Mirrors Java Read.java:594. The record's event time is used
708-
# only as the TimestampedValue label so the downstream sees the
709-
# real per-record timestamp.
710-
value, record_timestamp, source_watermark = record
711-
_set_watermark_if_greater(watermark_estimator, source_watermark)
712-
yield TimestampedValue(value, record_timestamp)
713-
finally:
714-
current = tracker.current_restriction()
715-
# Register finalization only when a real checkpoint was cut this
716-
# bundle. Restriction identity (`current is not initial`) mirrors
717-
# Java's reference-equality gate in Read.java. We read the explicit
718-
# finalization channel, NOT ``checkpoint_mark`` (which is the
719-
# RESUME state and may belong to the residual after a split).
720-
finalize_mark = current.finalization_checkpoint_mark
721-
if current is not initial and finalize_mark is not None:
722-
bundle_finalizer.register(finalize_mark.finalize_checkpoint)
723-
# Release the underlying reader on every exit path, including the
724-
# exception path where a downstream yield raised between two
725-
# try_claim calls (reader-method failures are already closed inside
726-
# the tracker). ``RestrictionTrackerView`` does not expose the inner
727-
# tracker, so traverse the (stable-but-private) wrapper chain. If
728-
# the chain changes in a future Beam version we log a warning and
729-
# let GC eventually close; never call ``close`` on an unrelated
730-
# tracker subclass.
731-
threadsafe = getattr(tracker, '_threadsafe_restriction_tracker', None)
732-
inner_tracker = getattr(threadsafe, '_restriction_tracker', None)
733-
if isinstance(inner_tracker, _UnboundedSourceRestrictionTracker):
734-
inner_tracker._close_reader_if_open()
735-
elif inner_tracker is not None or threadsafe is not None:
736-
_LOGGER.warning(
737-
'UnboundedSource DoFn could not reach the inner tracker via '
738-
'_threadsafe_restriction_tracker._restriction_tracker; reader '
739-
'close on exception path skipped, relying on GC. Beam SDF '
740-
'wrapper internals may have changed -- file an issue.')
741-
742774
output = (
743775
pbegin
744776
| 'Impulse' >> Impulse()
745777
| 'EmitSource' >> core.Map(lambda _: source)
746-
| 'ReadUnbounded' >> core.ParDo(_ReadFromUnboundedSourceDoFn()))
778+
| 'ReadUnbounded' >> core.ParDo(
779+
_ReadFromUnboundedSourceDoFn(self._poll_interval_seconds)))
747780
# Wire the source's declared output coder onto the output PCollection.
748781
# Setting ``element_type`` alone is not enough: the runner derives the
749-
# PCollection's coder via ``coders.registry.get_coder(element_type)``,
782+
# PCollection's coder via ``coder_registry.get_coder(element_type)``,
750783
# which may resolve to a registry default that does NOT match the
751784
# source's declared coder (silently downgrading custom coders to pickle).
752-
# Register the source-declared coder against the element type so the
753-
# registry lookup returns it.
785+
# Register against the pipeline-specific ``coder_registry`` rather than
786+
# the global ``coders.registry`` so the registration does not leak
787+
# across pipelines running in the same process.
754788
try:
755789
type_hint = output_coder.to_type_hint()
756790
except NotImplementedError:
757791
type_hint = None
758792
if type_hint is not None:
759793
try:
760-
coders.registry.register_coder(type_hint, type(output_coder))
794+
pbegin.pipeline.coder_registry.register_coder(
795+
type_hint, type(output_coder))
761796
except Exception: # pylint: disable=broad-except
762-
# Some Beam versions / coder classes refuse class-only registration
763-
# (e.g. coders parameterised by non-default constructor args). The
764-
# element_type below still flows through the registry's standard
765-
# lookup; users with parameterised coders must register their coder
766-
# explicitly via ``coders.registry.register_coder`` before pipeline
767-
# construction. Logged so the gap is observable.
797+
# Some coder classes refuse class-only registration (e.g. coders
798+
# parameterised by non-default constructor args). The element_type
799+
# below still flows through the registry's standard lookup; users
800+
# with parameterised coders must register their coder explicitly
801+
# via ``pipeline.coder_registry.register_coder`` before pipeline
802+
# construction.
768803
_LOGGER.warning(
769804
'Could not register %s for element type %s; users must register '
770805
'their coder explicitly for non-default coders.',

0 commit comments

Comments
 (0)