@@ -206,7 +206,6 @@ impl Storage {
206206 if modified_count == 0 {
207207 return None ;
208208 }
209- let mut direct_snapshots: Vec < ( TaskId , Box < TaskStorage > ) > = Vec :: new ( ) ;
210209 let mut modified = Vec :: with_capacity ( modified_count as usize ) ;
211210 {
212211 let shard_guard = shard. read ( ) ;
@@ -222,44 +221,29 @@ impl Storage {
222221 // accompanied by modified flags (set_persistent_task_type calls
223222 // track_modification), so any_modified() is sufficient.
224223 if flags. any_modified ( ) {
225- debug_assert ! (
226- !key. is_transient( ) ,
227- "found a modified transient task: {:?}" ,
228- shared_value. get( ) . get_persistent_task_type( )
229- ) ;
230-
231- if flags. any_modified_during_snapshot ( ) {
232- // Task was modified during snapshot mode, so a snapshot
233- // copy must exist in the snapshots map (created by the
234- // (true, true) case in track_modification_internal).
235- // Remove the entry entirely so end_snapshot doesn't
236- // double-process this task. When iterating in `next` we will
237- // re-synchronize the task flags.
238- let ( _, snapshot) = self . snapshots . remove ( key) . expect (
239- "task with modified_during_snapshot must have a snapshots entry" ,
240- ) ;
241- let snapshot = snapshot. expect (
242- "snapshot entry for modified_during_snapshot task must contain a \
243- value",
244- ) ;
245- direct_snapshots. push ( ( * key, snapshot) ) ;
246- } else {
247- modified. push ( * key) ;
224+ if key. is_transient ( ) {
225+ if cfg ! ( debug_assertions) {
226+ unreachable ! (
227+ "found a modified transient task: {:?}" ,
228+ shared_value. get( ) . get_persistent_task_type( )
229+ ) ;
230+ }
231+ continue ;
248232 }
233+
234+ modified. push ( * key) ;
249235 }
250236 }
251237 // Safety: shard_guard must outlive the iterator.
252238 drop ( shard_guard) ;
253239 }
254240
255241 // Early return for shards with no entries at all
256- if direct_snapshots . is_empty ( ) && modified. is_empty ( ) {
242+ if modified. is_empty ( ) {
257243 return None ;
258244 }
259245
260246 Some ( SnapshotShard {
261- shard_idx,
262- direct_snapshots,
263247 modified,
264248 storage : self ,
265249 process,
@@ -560,8 +544,6 @@ impl Drop for SnapshotGuard<'_> {
560544}
561545
562546pub struct SnapshotShard < ' l , P > {
563- shard_idx : usize ,
564- direct_snapshots : Vec < ( TaskId , Box < TaskStorage > ) > ,
565547 modified : Vec < TaskId > ,
566548 storage : & ' l Storage ,
567549 process : & ' l P ,
@@ -599,63 +581,47 @@ where
599581 type Item = SnapshotItem ;
600582
601583 fn next ( & mut self ) -> Option < Self :: Item > {
602- // direct_snapshots: these tasks had a snapshot copy created by
603- // track_modification. We encode from the owned snapshot copy,
604- // clear the stale modified flags, and promote any _during_snapshot
605- // flags so the task stays dirty for the next cycle.
606- if let Some ( ( task_id, snapshot) ) = self . shard . direct_snapshots . pop ( ) {
607- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
608- // Clear pre-snapshot flags. Since we removed this task's entry from the
609- // snapshots map in take_snapshot, end_snapshot won't see it, so we must
610- // promote here.
584+ if let Some ( task_id) = self . shard . modified . pop ( ) {
611585 let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
586+ // Check whether any category was re-modified during snapshot mode.
587+ // If so, the snapshots map may hold a pre-modification copy for the
588+ // category(s) that were already modified (the `(true, true)` branch in
589+ // track_modification_internal). We must serialize that copy, not the
590+ // live data, so we persist the state that was current when the snapshot
591+ // started. Categories modified during snapshot for the first time
592+ // (`(true, false)` branch) produce a `None` entry — in that case the
593+ // live data for those categories is still the pre-snapshot state, so
594+ // encoding from `&inner` is correct.
595+ //
596+ // We remove the entry here so end_snapshot doesn't double-promote it;
597+ // instead we promote manually below.
598+ let item = if inner. flags . any_modified_during_snapshot ( ) {
599+ match self . shard . storage . snapshots . remove ( & task_id) {
600+ Some ( ( _, Some ( snapshot) ) ) => {
601+ // `(true, true)` case: serialize the pre-snapshot copy.
602+ ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer )
603+ }
604+ Some ( ( _, None ) ) | None => {
605+ // `(true, false)` case or no entry: live data is the pre-snapshot
606+ // state for these categories; serialize directly.
607+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
608+ }
609+ }
610+ } else {
611+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
612+ } ;
612613 inner. flags . set_data_modified ( false ) ;
613614 inner. flags . set_meta_modified ( false ) ;
614615 inner. flags . set_new_task ( false ) ;
615- self . shard
616- . storage
617- . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
616+ // If any _during_snapshot flags are set, promote them to modified so
617+ // the next snapshot cycle picks them up. end_snapshot won't see this
618+ // task because we removed its entry from snapshots above.
619+ self . shard . storage . promote_during_snapshot_flags (
620+ & mut inner,
621+ self . shard . storage . shard_index ( & task_id) ,
622+ ) ;
618623 return Some ( item) ;
619624 }
620- // modified tasks: acquire a write lock to encode and clear flags in one pass.
621- if let Some ( task_id) = self . shard . modified . pop ( ) {
622- let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
623- if !inner. flags . any_modified_during_snapshot ( ) {
624- let item = ( self . shard . process ) ( task_id, & inner, & mut self . buffer ) ;
625- inner. flags . set_data_modified ( false ) ;
626- inner. flags . set_meta_modified ( false ) ;
627- inner. flags . set_new_task ( false ) ;
628- return Some ( item) ;
629- } else {
630- // Task was modified again during snapshot mode. A snapshot copy was
631- // created in track_modification_internal. Remove it and encode it.
632- // end_snapshot must not also process it, so we take it out of the map.
633- // snapshots is a separate DashMap from map, so holding `inner` across
634- // the remove and encode is safe — no lock ordering issue.
635- let snapshot = self
636- . shard
637- . storage
638- . snapshots
639- . remove ( & task_id)
640- . expect ( "The snapshot bit was set, so it must be in Snapshot state" )
641- . 1
642- . expect (
643- "snapshot entry for modified_during_snapshot task must contain a value" ,
644- ) ;
645-
646- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
647- // Clear the modified flags that were captured into the snapshot copy,
648- // then promote modified_during_snapshot → modified so the task stays
649- // dirty for the next snapshot cycle.
650- inner. flags . set_data_modified ( false ) ;
651- inner. flags . set_meta_modified ( false ) ;
652- inner. flags . set_new_task ( false ) ;
653- self . shard
654- . storage
655- . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
656- return Some ( item) ;
657- }
658- }
659625 None
660626 }
661627}
@@ -697,20 +663,22 @@ mod tests {
697663 }
698664
699665 /// Regression test: a task modified before a snapshot and then modified *again* during
700- /// snapshot iteration must not trigger `debug_assert!(!inner.flags.any_modified())` in
701- /// `SnapshotShardIter:: next` .
666+ /// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
667+ /// modification forward to the next cycle .
702668 ///
703669 /// Sequence of events:
704670 /// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
705671 /// 2. `start_snapshot` puts us in snapshot mode.
706- /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` and
707- /// `any_modified_during_snapshot()=false` → task goes into the `modified` list.
708- /// 4. **Between scan and iteration**: `track_modification` is called on the task again. This is
709- /// the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of the
710- /// pre-snapshot state is created (carrying the modified bits) and stored in `snapshots`.
711- /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, finds
712- /// `any_modified_during_snapshot()=true`, clears the live modified flags (which were
713- /// captured into the snapshot), then asserts `!any_modified()` before promoting.
672+ /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
673+ /// `modified` list.
674+ /// 4. **Between scan and iteration**: `track_modification` is called on the same category. This
675+ /// is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of
676+ /// the pre-second-modification state is stored in `snapshots` as `Some(copy)`, and
677+ /// `data_modified_during_snapshot` is set.
678+ /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, detects
679+ /// `any_modified_during_snapshot()=true`, finds the `Some(copy)` in `snapshots`, encodes the
680+ /// pre-snapshot copy, clears the live modified flags, removes the snapshots entry, and
681+ /// promotes `data_modified_during_snapshot → data_modified` for the next cycle.
714682 // `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
715683 // requiring a multi-threaded Tokio runtime.
716684 #[ tokio:: test( flavor = "multi_thread" ) ]
@@ -744,8 +712,8 @@ mod tests {
744712 assert ! ( guard. flags. data_modified_during_snapshot( ) )
745713 }
746714
747- // Step 5: consume the iterator. The iterator clears the live modified flags
748- // before the assert, encodes the snapshot copy , and promotes
715+ // Step 5: consume the iterator. The iterator encodes from the pre-snapshot copy,
716+ // clears the live modified flags, removes the snapshots entry , and promotes
749717 // `data_modified_during_snapshot → data_modified` for the next cycle.
750718 let items: Vec < _ > = shards
751719 . into_iter ( )
@@ -758,7 +726,7 @@ mod tests {
758726
759727 {
760728 let guard = storage. access_mut ( task_id) ;
761- // Ending the snapshot should have promoted modified_during_snapshot → modified.
729+ // The iterator should have promoted modified_during_snapshot → modified.
762730 assert ! ( guard. flags. data_modified( ) ) ;
763731 }
764732
@@ -770,4 +738,73 @@ mod tests {
770738 "shard_modified_counts must be non-zero after promoting modified_during_snapshot"
771739 ) ;
772740 }
741+
742+ /// Regression test for the `(true, false)` during-snapshot case: a task modified in one
743+ /// category before a snapshot, then modified in a *different* category during snapshot
744+ /// iteration, must not panic and must carry both modifications forward correctly.
745+ ///
746+ /// Sequence of events:
747+ /// 1. Task meta is modified (meta_modified = true).
748+ /// 2. `start_snapshot` puts us in snapshot mode.
749+ /// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
750+ /// 4. Task data is modified during snapshot → `(true, false)` branch: data was not previously
751+ /// modified, so `snapshots` gets a `None` entry and `data_modified_during_snapshot` is set.
752+ /// 5. `SnapshotShardIter::next` processes the task: finds `any_modified_during_snapshot()`,
753+ /// sees `None` in snapshots, encodes from live data (correct — live data for the
754+ /// unmodified-before-snapshot category is still the pre-snapshot state), clears pre-snapshot
755+ /// flags, and promotes `data_modified_during_snapshot → data_modified`.
756+ #[ tokio:: test( flavor = "multi_thread" ) ]
757+ async fn modify_different_category_during_snapshot ( ) {
758+ let storage = Storage :: new ( 2 , true ) ;
759+ let task_id = non_transient_task ( 1 ) ;
760+
761+ // Step 1: modify meta only, outside snapshot mode.
762+ {
763+ let mut guard = storage. access_mut ( task_id) ;
764+ guard. track_modification ( SpecificTaskDataCategory :: Meta , "test" ) ;
765+ assert ! ( guard. flags. meta_modified( ) ) ;
766+ assert ! ( !guard. flags. data_modified( ) ) ;
767+ }
768+
769+ // Step 2: enter snapshot mode.
770+ let ( snapshot_guard, has_modifications) = storage. start_snapshot ( ) ;
771+ assert ! ( has_modifications) ;
772+
773+ // Step 3: take_snapshot — task goes into modified list (meta_modified = true).
774+ let shards = storage. take_snapshot ( snapshot_guard, & dummy_process) ;
775+
776+ // Step 4: modify data during snapshot. The `(true, false)` branch fires:
777+ // data was not previously modified, so snapshots gets a None entry.
778+ {
779+ let mut guard = storage. access_mut ( task_id) ;
780+ guard. track_modification ( SpecificTaskDataCategory :: Data , "test" ) ;
781+ assert ! ( guard. flags. data_modified_during_snapshot( ) ) ;
782+ assert ! ( !guard. flags. meta_modified_during_snapshot( ) ) ;
783+ }
784+
785+ // Step 5: consume the iterator — must not panic.
786+ let items: Vec < _ > = shards
787+ . into_iter ( )
788+ . flat_map ( |shard| shard. into_iter ( ) )
789+ . collect ( ) ;
790+
791+ assert_eq ! ( items. len( ) , 1 ) ;
792+ assert_eq ! ( items[ 0 ] . task_id, task_id) ;
793+
794+ {
795+ let guard = storage. access_mut ( task_id) ;
796+ // meta_modified was cleared by the iterator (it was the pre-snapshot flag).
797+ assert ! ( !guard. flags. meta_modified( ) ) ;
798+ // data_modified_during_snapshot was promoted to data_modified.
799+ assert ! ( guard. flags. data_modified( ) ) ;
800+ assert ! ( !guard. flags. data_modified_during_snapshot( ) ) ;
801+ }
802+
803+ // Next snapshot cycle must pick up the promoted data_modified.
804+ let ( _guard2, has_modifications) = storage. start_snapshot ( ) ;
805+ assert ! (
806+ has_modifications,
807+ "shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
808+ ) ;
809+ }
773810}
0 commit comments