@@ -19,13 +19,13 @@ use std::sync::{
1919use parking_lot:: { Condvar , Mutex } ;
2020use rustc_hash:: FxHashSet ;
2121
22- use crate :: utils:: ptr_eq_arc:: PtrEqArc ;
22+ use crate :: { backend :: AnyOperation , utils:: ptr_eq_arc:: PtrEqArc } ;
2323
2424/// High bit: set while a snapshot is requested or in flight.
2525/// Low bits: count of operations currently executing (not suspended).
2626const SNAPSHOT_REQUESTED_BIT : usize = 1 << ( usize:: BITS - 1 ) ;
2727
28- /// State protected by the mutex. Kept tiny so critical sections stay short.
28+ /// State protected by the mutex.
2929struct State < O > {
3030 /// `true` between `begin_snapshot` and `SnapshotPhase::drop`.
3131 snapshot_requested : bool ,
@@ -41,12 +41,12 @@ struct State<O> {
4141/// Generic over the operation type the caller wants to suspend. The
4242/// coordinator only requires `O: Send + Sync + 'static`; it never inspects
4343/// the value, just stores it via [`PtrEqArc`].
44- pub struct SnapshotCoordinator < O > {
44+ pub struct SnapshotCoordinator < O = AnyOperation > {
4545 /// Combined count + bit. See [`SNAPSHOT_REQUESTED_BIT`].
4646 in_progress_operations : AtomicUsize ,
4747 state : Mutex < State < O > > ,
4848 /// Notified by the last operation to drain (count drops to `BIT` while
49- /// `BIT ` is set). Awaited by [`begin_snapshot`].
49+ /// `SNAPSHOT_REQUESTED_BIT` is set). Awaited by [`begin_snapshot`].
5050 operations_drained : Condvar ,
5151 /// Notified by [`SnapshotPhase::drop`]. Awaited by operations that hit a
5252 /// suspend point or arrive while a snapshot is in flight.
@@ -137,7 +137,9 @@ impl<O> SnapshotCoordinator<O> {
137137 // and acquiring the mutex. Nothing to do.
138138 return ;
139139 }
140- state. suspended_operations . insert ( op. clone ( ) . into ( ) ) ;
140+ state
141+ . suspended_operations
142+ . insert ( PtrEqArc :: from ( op. clone ( ) ) ) ;
141143 // Decrement the count so the snapshotter can drain.
142144 let prev = this. in_progress_operations . fetch_sub ( 1 , Ordering :: AcqRel ) ;
143145 // Protocol violation if either invariant fails. Keep as a regular
@@ -155,7 +157,7 @@ impl<O> SnapshotCoordinator<O> {
155157 . wait_while ( & mut state, |s| s. snapshot_requested ) ;
156158 // Resume: re-increment and remove ourselves from the suspended set.
157159 this. in_progress_operations . fetch_add ( 1 , Ordering :: AcqRel ) ;
158- state. suspended_operations . remove ( & op . into ( ) ) ;
160+ state. suspended_operations . remove ( & PtrEqArc :: from ( op ) ) ;
159161 }
160162 suspend_point_cold ( self , suspend) ;
161163 }
@@ -305,7 +307,11 @@ impl<O> Drop for SnapshotPhase<'_, O> {
305307#[ cfg( test) ]
306308mod tests {
307309 use std:: {
308- sync:: { Arc , atomic:: AtomicUsize } ,
310+ sync:: {
311+ Arc ,
312+ atomic:: { AtomicBool , AtomicUsize } ,
313+ mpsc:: { self , RecvTimeoutError } ,
314+ } ,
309315 thread,
310316 time:: Duration ,
311317 } ;
@@ -315,6 +321,16 @@ mod tests {
315321 /// Trivial operation type for tests — just a u32 tag.
316322 type Op = u32 ;
317323
324+ /// Busy-wait until `coord.snapshot_pending()` returns true, yielding the
325+ /// thread on every iteration so the snapshotter is not starved on
326+ /// single-core CI. Replaces fixed `thread::sleep` waits (flaky when too
327+ /// short, slow when too long). No timeout: spins until the flag is set.
328+ fn wait_for_snapshot_pending < O > ( coord : & SnapshotCoordinator < O > ) {
329+ while !coord. snapshot_pending ( ) {
330+ thread:: yield_now ( ) ;
331+ }
332+ }
333+
318334 #[ test]
319335 fn no_snapshot_pending_initially ( ) {
320336 let coord = SnapshotCoordinator :: < Op > :: new ( ) ;
@@ -356,8 +372,10 @@ mod tests {
356372 }
357373 } ) ;
358374
359- // Give the snapshotter time to set the bit and start waiting.
360- thread:: sleep ( Duration :: from_millis ( 50 ) ) ;
375+ // Wait for the snapshotter to set the bit. It can't make progress
376+ // past begin_snapshot while we hold `g`, so started_snapshot must
377+ // still be 0.
378+ wait_for_snapshot_pending ( & coord) ;
361379 assert_eq ! ( started_snapshot. load( Ordering :: Acquire ) , 0 ) ;
362380
363381 // Drop the operation — snapshotter should now proceed.
@@ -371,17 +389,29 @@ mod tests {
371389 let coord = Arc :: new ( SnapshotCoordinator :: < Op > :: new ( ) ) ;
372390 let phase = coord. begin_snapshot ( ) ;
373391 let started_op = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
392+ let arrived = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
374393
375394 let coord2 = coord. clone ( ) ;
376395 let op_thread = thread:: spawn ( {
377396 let started_op = started_op. clone ( ) ;
397+ let arrived = arrived. clone ( ) ;
378398 move || {
399+ arrived. store ( 1 , Ordering :: Release ) ;
379400 let _guard = coord2. begin_operation ( ) ;
380401 started_op. store ( 1 , Ordering :: Release ) ;
381402 }
382403 } ) ;
383404
384- thread:: sleep ( Duration :: from_millis ( 50 ) ) ;
405+ // Wait until the worker is alive and about to call begin_operation.
406+ // We can't directly observe it entering begin_operation (its
407+ // fetch_add is transient — it backs out and parks before we can
408+ // sample), but since we hold `phase` the worker provably cannot
409+ // set started_op=1 from anywhere inside begin_operation. So
410+ // observing started_op==0 after the worker is running and on its
411+ // way into begin_operation is a real check, not a vacuous one.
412+ while arrived. load ( Ordering :: Acquire ) == 0 {
413+ thread:: yield_now ( ) ;
414+ }
385415 assert_eq ! ( started_op. load( Ordering :: Acquire ) , 0 ) ;
386416
387417 drop ( phase) ;
@@ -409,7 +439,7 @@ mod tests {
409439 }
410440 } ) ;
411441
412- thread :: sleep ( Duration :: from_millis ( 20 ) ) ;
442+ wait_for_snapshot_pending ( & coord ) ;
413443 // Snapshotter is now waiting for our operation to drain. Calling
414444 // suspend_point should let it proceed.
415445 coord. suspend_point ( || 42u32 ) ;
@@ -420,109 +450,102 @@ mod tests {
420450 drop ( g) ;
421451 }
422452
423- /// Spawn a watchdog thread that aborts the process if the test doesn't
424- /// signal completion within the timeout. Aborting (vs. panicking) is the
425- /// only way to fail a test cleanly when its main thread is parked on a
426- /// missed-wakeup — a panic in another thread doesn't unblock a join.
427- fn spawn_watchdog (
453+ /// Run `body` on a worker thread and wait up to `timeout` for it to
454+ /// finish. Panics on timeout (the worker is left running) or if `body` panicked.
455+ fn run_with_timeout (
428456 label : & ' static str ,
429457 timeout : Duration ,
430- ) -> Arc < std:: sync:: atomic:: AtomicBool > {
431- let done = Arc :: new ( std:: sync:: atomic:: AtomicBool :: new ( false ) ) ;
432-
433- thread:: spawn ( {
434- let done_watch = done. clone ( ) ;
435- move || {
436- let deadline = std:: time:: Instant :: now ( ) + timeout;
437- while std:: time:: Instant :: now ( ) < deadline {
438- if done_watch. load ( Ordering :: Acquire ) {
439- return ;
440- }
441- thread:: sleep ( Duration :: from_millis ( 50 ) ) ;
442- }
443- eprintln ! (
444- "[watchdog] {label}: timed out after {timeout:?}, missed-wakeup race likely; \
445- aborting"
458+ body : impl FnOnce ( ) + Send + ' static ,
459+ ) {
460+ let ( tx, rx) = mpsc:: channel :: < ( ) > ( ) ;
461+ let handle = thread:: spawn ( move || {
462+ body ( ) ;
463+ let _ = tx. send ( ( ) ) ;
464+ } ) ;
465+ match rx. recv_timeout ( timeout) {
466+ // Worker either finished normally or panicked (dropping the
467+ // sender). Either way it's no longer running, so join to
468+ // propagate any panic.
469+ Ok ( ( ) ) | Err ( RecvTimeoutError :: Disconnected ) => {
470+ handle. join ( ) . unwrap ( ) ;
471+ }
472+ Err ( RecvTimeoutError :: Timeout ) => {
473+ panic ! (
474+ "[watchdog] {label}: timed out after {timeout:?}, missed-wakeup race likely"
446475 ) ;
447- std:: process:: abort ( ) ;
448476 }
449- } ) ;
450- done
477+ }
451478 }
452479
453480 /// Targeted stress test that reproduces the parking_lot notify-all
454481 /// fast-path missed-wakeup race when `OperationGuard::drop` does NOT
455- /// take the state mutex. Many tiny operations and back-to-back
456- /// snapshots maximize the chance of catching a wake during the window
457- /// where the snapshotter has called `wait_while` but parking_lot's
458- /// internal `state` is still null.
482+ /// take the state mutex.
459483 #[ test]
460484 fn stress_no_missed_wakeups ( ) {
461- let done = spawn_watchdog ( "stress_no_missed_wakeups" , Duration :: from_secs ( 60 ) ) ;
485+ run_with_timeout ( "stress_no_missed_wakeups" , Duration :: from_secs ( 60 ) , || {
486+ let coord = Arc :: new ( SnapshotCoordinator :: < Op > :: new ( ) ) ;
487+ let snapshot_lock = Arc :: new ( Mutex :: new ( ( ) ) ) ;
488+ let stop = Arc :: new ( AtomicBool :: new ( false ) ) ;
489+ let snap_count = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
490+
491+ let mut op_handles = Vec :: new ( ) ;
492+ for _ in 0 ..8 {
493+ let coord = coord. clone ( ) ;
494+ op_handles. push ( thread:: spawn ( {
495+ let stop = stop. clone ( ) ;
496+ move || {
497+ while !stop. load ( Ordering :: Relaxed ) {
498+ let _g = coord. begin_operation ( ) ;
499+ }
500+ }
501+ } ) ) ;
502+ }
503+ let mut snap_handles = Vec :: new ( ) ;
504+ for _ in 0 ..2 {
505+ snap_handles. push ( thread:: spawn ( {
506+ let coord = coord. clone ( ) ;
507+ let snapshot_lock = snapshot_lock. clone ( ) ;
508+ let snap_count = snap_count. clone ( ) ;
509+ move || {
510+ for _ in 0 ..200 {
511+ let _ser = snapshot_lock. lock ( ) ;
512+ let _phase = coord. begin_snapshot ( ) ;
513+ snap_count. fetch_add ( 1 , Ordering :: Relaxed ) ;
514+ }
515+ }
516+ } ) ) ;
517+ }
462518
463- let coord = Arc :: new ( SnapshotCoordinator :: < Op > :: new ( ) ) ;
464- let snapshot_lock = Arc :: new ( Mutex :: new ( ( ) ) ) ;
465- let stop = Arc :: new ( std:: sync:: atomic:: AtomicBool :: new ( false ) ) ;
466- let snap_count = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
519+ // Progress watchdog: print snapshot count every second so we can see
520+ // if the test is making progress or actually wedged.
521+ let stop_progress = Arc :: new ( AtomicBool :: new ( false ) ) ;
467522
468- let mut op_handles = Vec :: new ( ) ;
469- for _ in 0 ..8 {
470- let coord = coord. clone ( ) ;
471- op_handles. push ( thread:: spawn ( {
472- let stop = stop. clone ( ) ;
473- move || {
474- while !stop. load ( Ordering :: Relaxed ) {
475- let _g = coord. begin_operation ( ) ;
476- }
477- }
478- } ) ) ;
479- }
480- let mut snap_handles = Vec :: new ( ) ;
481- for _ in 0 ..2 {
482- snap_handles. push ( thread:: spawn ( {
483- let coord = coord. clone ( ) ;
484- let snapshot_lock = snapshot_lock. clone ( ) ;
523+ let progress = thread:: spawn ( {
524+ let stop_progress = stop_progress. clone ( ) ;
485525 let snap_count = snap_count. clone ( ) ;
486526 move || {
487- for _ in 0 ..200 {
488- let _ser = snapshot_lock. lock ( ) ;
489- let _phase = coord. begin_snapshot ( ) ;
490- snap_count. fetch_add ( 1 , Ordering :: Relaxed ) ;
527+ while !stop_progress. load ( Ordering :: Relaxed ) {
528+ thread:: sleep ( Duration :: from_secs ( 1 ) ) ;
529+ eprintln ! (
530+ "[stress] snapshots completed: {}" ,
531+ snap_count. load( Ordering :: Relaxed ) ,
532+ ) ;
491533 }
492534 }
493- } ) ) ;
494- }
495-
496- // Progress watchdog: print snapshot count every 5s so we can see
497- // if the test is making progress or actually wedged.
498- let stop_progress = Arc :: new ( std:: sync:: atomic:: AtomicBool :: new ( false ) ) ;
535+ } ) ;
499536
500- let progress = thread:: spawn ( {
501- let stop_progress = stop_progress. clone ( ) ;
502- let snap_count = snap_count. clone ( ) ;
503- move || {
504- while !stop_progress. load ( Ordering :: Relaxed ) {
505- thread:: sleep ( Duration :: from_secs ( 1 ) ) ;
506- eprintln ! (
507- "[stress] snapshots completed: {}" ,
508- snap_count. load( Ordering :: Relaxed ) ,
509- ) ;
510- }
537+ for h in snap_handles {
538+ h. join ( ) . unwrap ( ) ;
511539 }
512- } ) ;
513-
514- for h in snap_handles {
515- h. join ( ) . unwrap ( ) ;
516- }
517- stop. store ( true , Ordering :: Relaxed ) ;
518- for h in op_handles {
519- h. join ( ) . unwrap ( ) ;
520- }
521- stop_progress. store ( true , Ordering :: Relaxed ) ;
522- let _ = progress. join ( ) ;
540+ stop. store ( true , Ordering :: Relaxed ) ;
541+ for h in op_handles {
542+ h. join ( ) . unwrap ( ) ;
543+ }
544+ stop_progress. store ( true , Ordering :: Relaxed ) ;
545+ let _ = progress. join ( ) ;
523546
524- assert_eq ! ( coord. in_progress_operations. load( Ordering :: Acquire ) , 0 ) ;
525- done . store ( true , Ordering :: Release ) ;
547+ assert_eq ! ( coord. in_progress_operations. load( Ordering :: Acquire ) , 0 ) ;
548+ } ) ;
526549 }
527550
528551 #[ test]
0 commit comments