@@ -419,6 +419,13 @@ def split(self, desired_num_splits, options=None):
419419
420420 self .assertEqual (list (provider .split (source , restriction )), [restriction ])
421421
422+ def test_truncate_returns_none_for_drain (self ):
423+ # On drain the SDF stops emitting; truncate yields no residual.
424+ provider = _UnboundedSourceRestrictionProvider ()
425+ source = UnboundedCountingSource (5 )
426+ restriction = _UnboundedSourceRestriction (source = source )
427+ self .assertIsNone (provider .truncate (source , restriction ))
428+
422429 def test_splittable_source_partitions_into_independent_subsources (self ):
423430 # A splittable source fans out into two sub-sources; reading each in
424431 # isolation yields the even and the odd integers, and their union is the
@@ -569,54 +576,122 @@ def register(self, callback):
569576 self .registered .append (callback )
570577
571578
579+ class _ManualClock :
580+ """A deterministic monotonic clock for the time-cap tests."""
581+ def __init__ (self , now = 0.0 ):
582+ self .now = now
583+
584+ def __call__ (self ):
585+ return self .now
586+
587+
572588class BundleCapTest (unittest .TestCase ):
573589 """A busy reader self-checkpoints once the per-bundle record or time cap is
574590 reached, so the runner can commit progress and run finalization."""
575- def _run_process (self , dofn , source ):
591+ def _bundle (self , dofn , source , checkpoint = None , estimator = None ):
592+ """Builds the SDF tracker chain and returns the process() generator plus the
593+ tracker, threadsafe tracker, finalizer, and watermark estimator."""
576594 tracker = _UnboundedSourceRestrictionTracker (
577- _UnboundedSourceRestriction (source = source ))
595+ _UnboundedSourceRestriction (source = source , checkpoint_mark = checkpoint ))
578596 threadsafe = sdf_utils .ThreadsafeRestrictionTracker (tracker )
579597 view = sdf_utils .RestrictionTrackerView (threadsafe )
580- outputs = list (
581- dofn .process (
582- None ,
583- bundle_finalizer = _RecordingBundleFinalizer (),
584- tracker = view ,
585- watermark_estimator = ManualWatermarkEstimator (None )))
586- return outputs , tracker , threadsafe
598+ finalizer = _RecordingBundleFinalizer ()
599+ estimator = estimator or ManualWatermarkEstimator (None )
600+ gen = dofn .process (
601+ None ,
602+ bundle_finalizer = finalizer ,
603+ tracker = view ,
604+ watermark_estimator = estimator )
605+ return gen , tracker , threadsafe , finalizer , estimator
587606
588607 def test_record_cap_checkpoints_busy_source (self ):
608+ finalize_log = []
589609 dofn = _ReadFromUnboundedSourceDoFn (
590610 poll_interval = 0 , max_records_per_bundle = 5 , max_read_time_seconds = 1e9 )
591611 # 1000 records is effectively unbounded against a cap of 5.
592- outputs , tracker , threadsafe = self ._run_process (
593- dofn , UnboundedCountingSource (1000 ))
612+ gen , tracker , threadsafe , finalizer , estimator = self ._bundle (
613+ dofn , UnboundedCountingSource (1000 , finalize_log = finalize_log ))
614+ outputs = list (gen )
594615
595616 self .assertEqual ([tv .value for tv in outputs ], [0 , 1 , 2 , 3 , 4 ])
596617 self .assertTrue (tracker .current_restriction ().is_done )
597618 self .assertTrue (tracker .check_done ())
598- residual , _delay = threadsafe .deferred_status ()
619+ # The estimator holds the last emitted record's source watermark.
620+ self .assertEqual (estimator .current_watermark (), _EVENT_TIME_BASE + 4 )
621+ # Residual resumes after the cut and carries no finalize hook.
622+ residual , _ = threadsafe .deferred_status ()
599623 self .assertEqual (residual .checkpoint_mark .last_index , 4 )
624+ self .assertIsNone (residual .finalization_checkpoint_mark )
625+ # Exactly one finalizer is registered; firing it commits the cut index once.
626+ self .assertEqual (len (finalizer .registered ), 1 )
627+ finalizer .registered [0 ]()
628+ finalizer .registered [0 ]()
629+ self .assertEqual (finalize_log , [4 ])
600630
601631 def test_time_cap_checkpoints_busy_source (self ):
632+ clock = _ManualClock (1000.0 )
602633 dofn = _ReadFromUnboundedSourceDoFn (
603- poll_interval = 0 , max_records_per_bundle = 10 ** 9 , max_read_time_seconds = 0 )
604- outputs , tracker , threadsafe = self ._run_process (
634+ poll_interval = 0 ,
635+ max_records_per_bundle = 10 ** 9 ,
636+ max_read_time_seconds = 5.0 ,
637+ _now = clock )
638+ gen , tracker , threadsafe , _ , _ = self ._bundle (
605639 dofn , UnboundedCountingSource (1000 ))
606640
607- # A zero time budget trips the deadline right after the first record.
608- self .assertEqual ([tv .value for tv in outputs ], [0 ])
641+ # The deadline arms at 1000 + 5 after the first record and is checked
642+ # between records, so records keep flowing until the clock passes it.
643+ self .assertEqual (next (gen ).value , 0 )
644+ self .assertEqual (next (gen ).value , 1 )
645+ self .assertEqual (next (gen ).value , 2 )
646+ clock .now = 1006.0
647+ with self .assertRaises (StopIteration ):
648+ next (gen )
649+
609650 self .assertTrue (tracker .current_restriction ().is_done )
610- residual , _delay = threadsafe .deferred_status ()
611- self .assertEqual (residual .checkpoint_mark .last_index , 0 )
651+ residual , _ = threadsafe .deferred_status ()
652+ self .assertEqual (residual .checkpoint_mark .last_index , 2 )
653+
654+ def test_cap_residual_resumes_in_next_bundle (self ):
655+ dofn = _ReadFromUnboundedSourceDoFn (
656+ poll_interval = 0 , max_records_per_bundle = 5 , max_read_time_seconds = 1e9 )
657+ source = UnboundedCountingSource (1000 )
658+ # Bundle 1 emits 0-4 and cuts a residual at index 4.
659+ gen1 , _ , threadsafe1 , _ , _ = self ._bundle (dofn , source )
660+ self .assertEqual ([tv .value for tv in gen1 ], [0 , 1 , 2 , 3 , 4 ])
661+ residual1 , _ = threadsafe1 .deferred_status ()
662+
663+ # Bundle 2 rebuilds the reader from the residual and emits 5-9.
664+ gen2 , _ , threadsafe2 , _ , _ = self ._bundle (
665+ dofn , source , checkpoint = residual1 .checkpoint_mark )
666+ self .assertEqual ([tv .value for tv in gen2 ], [5 , 6 , 7 , 8 , 9 ])
667+ residual2 , _ = threadsafe2 .deferred_status ()
668+ self .assertEqual (residual2 .checkpoint_mark .last_index , 9 )
669+
670+ def test_eof_exactly_at_cap_resumes_then_finishes (self ):
671+ dofn = _ReadFromUnboundedSourceDoFn (
672+ poll_interval = 0 , max_records_per_bundle = 5 , max_read_time_seconds = 1e9 )
673+ source = UnboundedCountingSource (5 ) # exactly cap records
674+ # Bundle 1 hits the cap on the last record before observing EOF.
675+ gen1 , t1 , threadsafe1 , _ , _ = self ._bundle (dofn , source )
676+ self .assertEqual ([tv .value for tv in gen1 ], [0 , 1 , 2 , 3 , 4 ])
677+ self .assertTrue (t1 .current_restriction ().is_done )
678+ residual1 , _ = threadsafe1 .deferred_status ()
679+ self .assertEqual (residual1 .checkpoint_mark .last_index , 4 )
680+
681+ # Bundle 2 resumes at index 5, finds EOF, and finishes with no output.
682+ gen2 , t2 , threadsafe2 , _ , _ = self ._bundle (
683+ dofn , source , checkpoint = residual1 .checkpoint_mark )
684+ self .assertEqual (list (gen2 ), [])
685+ self .assertTrue (t2 .current_restriction ().is_done )
686+ self .assertIsNone (threadsafe2 .deferred_status ())
612687
613688 def test_eof_before_cap_finishes_without_residual (self ):
614689 dofn = _ReadFromUnboundedSourceDoFn (
615690 poll_interval = 0 , max_records_per_bundle = 100 , max_read_time_seconds = 1e9 )
616- outputs , tracker , threadsafe = self ._run_process (
691+ gen , tracker , threadsafe , _ , _ = self ._bundle (
617692 dofn , UnboundedCountingSource (3 ))
618693
619- self .assertEqual ([tv .value for tv in outputs ], [0 , 1 , 2 ])
694+ self .assertEqual ([tv .value for tv in gen ], [0 , 1 , 2 ])
620695 self .assertTrue (tracker .current_restriction ().is_done )
621696 self .assertIsNone (threadsafe .deferred_status ())
622697
@@ -1022,6 +1097,15 @@ def test_non_source_argument_raises(self):
10221097 with self .assertRaises (TypeError ):
10231098 ReadFromUnboundedSource ('not-a-source' ) # type: ignore[arg-type]
10241099
1100+ def test_invalid_caps_raise (self ):
1101+ source = UnboundedCountingSource (1 )
1102+ with self .assertRaises (ValueError ):
1103+ ReadFromUnboundedSource (source , max_records_per_bundle = 0 )
1104+ with self .assertRaises (ValueError ):
1105+ ReadFromUnboundedSource (source , max_read_time_seconds = 0 )
1106+ with self .assertRaises (ValueError ):
1107+ ReadFromUnboundedSource (source , poll_interval = - 1 )
1108+
10251109
10261110if __name__ == '__main__' :
10271111 logging .getLogger ().setLevel (logging .INFO )
0 commit comments