@@ -213,7 +213,6 @@ impl Storage {
213213 if modified_count == 0 {
214214 return None ;
215215 }
216- let mut direct_snapshots: Vec < ( TaskId , Box < TaskStorage > ) > = Vec :: new ( ) ;
217216 let mut modified = Vec :: with_capacity ( modified_count as usize ) ;
218217 {
219218 let shard_guard = shard. read ( ) ;
@@ -229,44 +228,27 @@ impl Storage {
229228 // accompanied by modified flags (set_persistent_task_type calls
230229 // track_modification), so any_modified() is sufficient.
231230 if flags. any_modified ( ) {
232- debug_assert ! (
233- !key. is_transient( ) ,
234- "found a modified transient task: {:?}" ,
235- shared_value. get( ) . get_persistent_task_type( )
236- ) ;
237-
238- if flags. any_modified_during_snapshot ( ) {
239- // Task was modified during snapshot mode, so a snapshot
240- // copy must exist in the snapshots map (created by the
241- // (true, true) case in track_modification_internal).
242- // Remove the entry entirely so end_snapshot doesn't
243- // double-process this task. When iterating in `next` we will
244- // re-synchronize the task flags.
245- let ( _, snapshot) = self . snapshots . remove ( key) . expect (
246- "task with modified_during_snapshot must have a snapshots entry" ,
247- ) ;
248- let snapshot = snapshot. expect (
249- "snapshot entry for modified_during_snapshot task must contain a \
250- value",
251- ) ;
252- direct_snapshots. push ( ( * key, snapshot) ) ;
253- } else {
254- modified. push ( * key) ;
231+ if key. is_transient ( ) {
232+ if cfg ! ( debug_assertions) {
233+ unreachable ! (
234+ "found a modified transient task: {:?}" ,
235+ shared_value. get( ) . get_persistent_task_type( )
236+ ) ;
237+ }
238+ continue ;
255239 }
240+
241+ modified. push ( * key) ;
256242 }
257243 }
258244 // Safety: shard_guard must outlive the iterator.
259245 drop ( shard_guard) ;
260246 }
261247
262- // Early return for shards with no entries at all
263- if direct_snapshots. is_empty ( ) && modified. is_empty ( ) {
264- return None ;
265- }
248+ debug_assert ! ( !modified. is_empty( ) ) ;
266249
267250 Some ( SnapshotShard {
268251 shard_idx,
269- direct_snapshots,
270252 modified,
271253 storage : self ,
272254 process,
@@ -568,7 +550,6 @@ impl Drop for SnapshotGuard<'_> {
568550
569551pub struct SnapshotShard < ' l , P > {
570552 shard_idx : usize ,
571- direct_snapshots : Vec < ( TaskId , Box < TaskStorage > ) > ,
572553 modified : Vec < TaskId > ,
573554 storage : & ' l Storage ,
574555 process : & ' l P ,
@@ -606,16 +587,27 @@ where
606587 type Item = SnapshotItem ;
607588
608589 fn next ( & mut self ) -> Option < Self :: Item > {
609- // direct_snapshots: these tasks had a snapshot copy created by
610- // track_modification. We encode from the owned snapshot copy,
611- // clear the stale modified flags, and promote any _during_snapshot
612- // flags so the task stays dirty for the next cycle.
613- if let Some ( ( task_id, snapshot) ) = self . shard . direct_snapshots . pop ( ) {
614- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
615- // Clear pre-snapshot flags. Since we removed this task's entry from the
616- // snapshots map in take_snapshot, end_snapshot won't see it, so we must
617- // promote here.
590+ if let Some ( task_id) = self . shard . modified . pop ( ) {
618591 let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
592+ // If the task was re-modified during snapshot, the snapshots map may
593+ // hold a pre-modification copy we must serialize instead of the live
594+ // data. Remove the entry so end_snapshot doesn't double-promote it;
595+ // we promote manually below.
596+ let item = if inner. flags . any_modified_during_snapshot ( ) {
597+ match self . shard . storage . snapshots . remove ( & task_id) {
598+ Some ( ( _, Some ( snapshot) ) ) => {
599+ ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer )
600+ }
601+ Some ( ( _, None ) ) | None => {
602+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
603+ }
604+ }
605+ } else {
606+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
607+ } ;
 608+ // Clear the pre-snapshot modified flags (when a snapshot copy exists they
 609+ // were captured into it), then promote modified_during_snapshot → modified
 610+ // so the task stays dirty for the next snapshot cycle.
619611 inner. flags . set_data_modified ( false ) ;
620612 inner. flags . set_meta_modified ( false ) ;
621613 inner. flags . set_new_task ( false ) ;
@@ -624,45 +616,6 @@ where
624616 . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
625617 return Some ( item) ;
626618 }
627- // modified tasks: acquire a write lock to encode and clear flags in one pass.
628- if let Some ( task_id) = self . shard . modified . pop ( ) {
629- let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
630- if !inner. flags . any_modified_during_snapshot ( ) {
631- let item = ( self . shard . process ) ( task_id, & inner, & mut self . buffer ) ;
632- inner. flags . set_data_modified ( false ) ;
633- inner. flags . set_meta_modified ( false ) ;
634- inner. flags . set_new_task ( false ) ;
635- return Some ( item) ;
636- } else {
637- // Task was modified again during snapshot mode. A snapshot copy was
638- // created in track_modification_internal. Remove it and encode it.
639- // end_snapshot must not also process it, so we take it out of the map.
640- // snapshots is a separate DashMap from map, so holding `inner` across
641- // the remove and encode is safe — no lock ordering issue.
642- let snapshot = self
643- . shard
644- . storage
645- . snapshots
646- . remove ( & task_id)
647- . expect ( "The snapshot bit was set, so it must be in Snapshot state" )
648- . 1
649- . expect (
650- "snapshot entry for modified_during_snapshot task must contain a value" ,
651- ) ;
652-
653- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
654- // Clear the modified flags that were captured into the snapshot copy,
655- // then promote modified_during_snapshot → modified so the task stays
656- // dirty for the next snapshot cycle.
657- inner. flags . set_data_modified ( false ) ;
658- inner. flags . set_meta_modified ( false ) ;
659- inner. flags . set_new_task ( false ) ;
660- self . shard
661- . storage
662- . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
663- return Some ( item) ;
664- }
665- }
666619 None
667620 }
668621}
@@ -704,20 +657,22 @@ mod tests {
704657 }
705658
706659 /// Regression test: a task modified before a snapshot and then modified *again* during
707- /// snapshot iteration must not trigger `debug_assert!(!inner.flags.any_modified())` in
708- /// `SnapshotShardIter:: next` .
660+ /// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
661+ /// modification forward to the next cycle .
709662 ///
710663 /// Sequence of events:
711664 /// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
712665 /// 2. `start_snapshot` puts us in snapshot mode.
713- /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` and
714- /// `any_modified_during_snapshot()=false` → task goes into the `modified` list.
715- /// 4. **Between scan and iteration**: `track_modification` is called on the task again. This is
716- /// the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of the
717- /// pre-snapshot state is created (carrying the modified bits) and stored in `snapshots`.
718- /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, finds
719- /// `any_modified_during_snapshot()=true`, clears the live modified flags (which were
720- /// captured into the snapshot), then asserts `!any_modified()` before promoting.
666+ /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
667+ /// `modified` list.
668+ /// 4. **Between scan and iteration**: `track_modification` is called on the same category. This
669+ /// is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of
670+ /// the pre-second-modification state is stored in `snapshots` as `Some(copy)`, and
671+ /// `data_modified_during_snapshot` is set.
672+ /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, detects
673+ /// `any_modified_during_snapshot()=true`, finds the `Some(copy)` in `snapshots`, encodes the
 674+ /// pre-second-modification copy, clears the live modified flags, removes the snapshots entry, and
675+ /// promotes `data_modified_during_snapshot → data_modified` for the next cycle.
721676 // `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
722677 // requiring a multi-threaded Tokio runtime.
723678 #[ tokio:: test( flavor = "multi_thread" ) ]
@@ -751,8 +706,8 @@ mod tests {
751706 assert ! ( guard. flags. data_modified_during_snapshot( ) )
752707 }
753708
754- // Step 5: consume the iterator. The iterator clears the live modified flags
755- // before the assert, encodes the snapshot copy , and promotes
709+ // Step 5: consume the iterator. The iterator encodes from the pre-snapshot copy,
710+ // clears the live modified flags, removes the snapshots entry , and promotes
756711 // `data_modified_during_snapshot → data_modified` for the next cycle.
757712 let items: Vec < _ > = shards
758713 . into_iter ( )
@@ -765,7 +720,7 @@ mod tests {
765720
766721 {
767722 let guard = storage. access_mut ( task_id) ;
768- // Ending the snapshot should have promoted modified_during_snapshot → modified.
723+ // The iterator should have promoted modified_during_snapshot → modified.
769724 assert ! ( guard. flags. data_modified( ) ) ;
770725 }
771726
@@ -777,4 +732,73 @@ mod tests {
777732 "shard_modified_counts must be non-zero after promoting modified_during_snapshot"
778733 ) ;
779734 }
735+
736+ /// Regression test for the `(true, false)` during-snapshot case: a task modified in one
737+ /// category before a snapshot, then modified in a *different* category during snapshot
738+ /// iteration, must not panic and must carry both modifications forward correctly.
739+ ///
740+ /// Sequence of events:
741+ /// 1. Task meta is modified (meta_modified = true).
742+ /// 2. `start_snapshot` puts us in snapshot mode.
743+ /// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
744+ /// 4. Task data is modified during snapshot → `(true, false)` branch: data was not previously
745+ /// modified, so `snapshots` gets a `None` entry and `data_modified_during_snapshot` is set.
746+ /// 5. `SnapshotShardIter::next` processes the task: finds `any_modified_during_snapshot()`,
747+ /// sees `None` in snapshots, encodes from live data (correct — live data for the
748+ /// unmodified-before-snapshot category is still the pre-snapshot state), clears pre-snapshot
749+ /// flags, and promotes `data_modified_during_snapshot → data_modified`.
750+ #[ tokio:: test( flavor = "multi_thread" ) ]
751+ async fn modify_different_category_during_snapshot ( ) {
752+ let storage = Storage :: new ( 2 , true ) ;
753+ let task_id = non_transient_task ( 1 ) ;
754+
755+ // Step 1: modify meta only, outside snapshot mode.
756+ {
757+ let mut guard = storage. access_mut ( task_id) ;
758+ guard. track_modification ( SpecificTaskDataCategory :: Meta , "test" ) ;
759+ assert ! ( guard. flags. meta_modified( ) ) ;
760+ assert ! ( !guard. flags. data_modified( ) ) ;
761+ }
762+
763+ // Step 2: enter snapshot mode.
764+ let ( snapshot_guard, has_modifications) = storage. start_snapshot ( ) ;
765+ assert ! ( has_modifications) ;
766+
767+ // Step 3: take_snapshot — task goes into modified list (meta_modified = true).
768+ let shards = storage. take_snapshot ( snapshot_guard, & dummy_process) ;
769+
770+ // Step 4: modify data during snapshot. The `(true, false)` branch fires:
771+ // data was not previously modified, so snapshots gets a None entry.
772+ {
773+ let mut guard = storage. access_mut ( task_id) ;
774+ guard. track_modification ( SpecificTaskDataCategory :: Data , "test" ) ;
775+ assert ! ( guard. flags. data_modified_during_snapshot( ) ) ;
776+ assert ! ( !guard. flags. meta_modified_during_snapshot( ) ) ;
777+ }
778+
779+ // Step 5: consume the iterator — must not panic.
780+ let items: Vec < _ > = shards
781+ . into_iter ( )
782+ . flat_map ( |shard| shard. into_iter ( ) )
783+ . collect ( ) ;
784+
785+ assert_eq ! ( items. len( ) , 1 ) ;
786+ assert_eq ! ( items[ 0 ] . task_id, task_id) ;
787+
788+ {
789+ let guard = storage. access_mut ( task_id) ;
790+ // meta_modified was cleared by the iterator (it was the pre-snapshot flag).
791+ assert ! ( !guard. flags. meta_modified( ) ) ;
792+ // data_modified_during_snapshot was promoted to data_modified.
793+ assert ! ( guard. flags. data_modified( ) ) ;
794+ assert ! ( !guard. flags. data_modified_during_snapshot( ) ) ;
795+ }
796+
797+ // Next snapshot cycle must pick up the promoted data_modified.
798+ let ( _guard2, has_modifications) = storage. start_snapshot ( ) ;
799+ assert ! (
800+ has_modifications,
801+ "shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
802+ ) ;
803+ }
780804}