1- use std:: fmt;
1+ use std:: cell:: { Cell , OnceCell } ;
2+ use std:: num:: NonZero ;
23use std:: ops:: Deref ;
3- use std:: sync:: Arc ;
4+ use std:: ptr;
5+ use std:: sync:: { Arc , Mutex } ;
46
5- use crate :: registry :: { Registry , WorkerThread } ;
7+ use crossbeam_utils :: CachePadded ;
68
7- #[ repr( align( 64 ) ) ]
8- #[ derive( Debug ) ]
9- struct CacheAligned < T > ( T ) ;
9+ /// A pointer to the `ComplementaryRegistryData` which uniquely identifies a registry.
10+ /// This identifier can be reused if the registry gets freed.
11+ #[ derive( Clone , Copy , PartialEq ) ]
12+ struct ComplementaryRegistryId ( * const ComplementaryRegistryData ) ;
1013
11- /// Holds worker-locals values for each thread in a thread pool.
12- /// You can only access the worker local value through the Deref impl
13- /// on the thread pool it was constructed on. It will panic otherwise
14+ impl ComplementaryRegistryId {
15+ #[ inline( always) ]
16+ /// Verifies that the current thread is associated with the registry and returns its unique
17+ /// index within the registry. This panics if the current thread is not associated with this
18+ /// registry.
19+ ///
20+ /// Note that there's a race possible where the identifier in `THREAD_DATA` could be reused
21+ /// so this can succeed from a different registry.
22+ fn verify ( self ) -> usize {
23+ let ( id, index) = THREAD_DATA . with ( |data| ( data. registry_id . get ( ) , data. index . get ( ) ) ) ;
24+
25+ if id == self { index } else { ComplementaryRegistryId :: verification_error ( ) }
26+ }
27+
28+ #[ cold]
29+ #[ inline( never) ]
30+ fn verification_error ( ) -> ! {
31+ panic ! ( "Unable to verify registry association" )
32+ }
33+ }
34+
35+ struct ComplementaryRegistryData {
36+ thread_limit : NonZero < usize > ,
37+ threads : Mutex < usize > ,
38+ }
39+
40+ /// Represents a list of threads which can access worker locals.
41+ #[ derive( Clone ) ]
42+ pub struct ComplementaryRegistry ( Arc < ComplementaryRegistryData > ) ;
43+
44+ thread_local ! {
45+ /// The registry associated with the thread.
46+ /// This allows the `WorkerLocal` type to clone the registry in its constructor.
47+ static REGISTRY : OnceCell <ComplementaryRegistry > = const { OnceCell :: new( ) } ;
48+ }
49+
50+ struct ThreadData {
51+ registry_id : Cell < ComplementaryRegistryId > ,
52+ index : Cell < usize > ,
53+ }
54+
55+ thread_local ! {
56+ /// A thread local which contains the identifier of `REGISTRY` but allows for faster access.
57+ /// It also holds the index of the current thread.
58+ static THREAD_DATA : ThreadData = const { ThreadData {
59+ registry_id: Cell :: new( ComplementaryRegistryId ( ptr:: null( ) ) ) ,
60+ index: Cell :: new( 0 ) ,
61+ } } ;
62+ }
63+
64+ impl ComplementaryRegistry {
65+ /// Creates a registry which can hold up to `thread_limit` threads.
66+ pub fn new ( thread_limit : NonZero < usize > ) -> Self {
67+ ComplementaryRegistry ( Arc :: new ( ComplementaryRegistryData {
68+ thread_limit,
69+ threads : Mutex :: new ( 0 ) ,
70+ } ) )
71+ }
72+
73+ /// Gets the registry associated with the current thread. Panics if there's no such registry.
74+ pub fn current ( ) -> Self {
75+ REGISTRY . with ( |registry| registry. get ( ) . cloned ( ) . expect ( "No associated registry" ) )
76+ }
77+
78+ /// Registers the current thread with the registry so worker locals can be used on it.
79+ /// Panics if the thread limit is hit or if the thread already has an associated registry.
80+ pub fn register ( & self ) {
81+ let mut threads = self . 0 . threads . lock ( ) . unwrap ( ) ;
82+ if * threads < self . 0 . thread_limit . get ( ) {
83+ REGISTRY . with ( |registry| {
84+ if registry. get ( ) . is_some ( ) {
85+ drop ( threads) ;
86+ panic ! ( "Thread already has a registry" ) ;
87+ }
88+ registry. set ( self . clone ( ) ) . ok ( ) ;
89+ THREAD_DATA . with ( |data| {
90+ data. registry_id . set ( self . id ( ) ) ;
91+ data. index . set ( * threads) ;
92+ } ) ;
93+ * threads += 1 ;
94+ } ) ;
95+ } else {
96+ drop ( threads) ;
97+ panic ! ( "Thread limit reached" ) ;
98+ }
99+ }
100+
101+ /// Gets the identifier of this registry.
102+ fn id ( & self ) -> ComplementaryRegistryId {
103+ ComplementaryRegistryId ( & * self . 0 )
104+ }
105+ }
106+
107+ /// Holds worker local values for each possible thread in a registry. You can only access the
108+ /// worker local value through the `Deref` impl, and only from a thread registered with the same
109+ /// registry as the thread the `WorkerLocal` was created on. It will panic otherwise.
14110pub struct WorkerLocal < T > {
15- locals : Vec < CacheAligned < T > > ,
16- registry : Arc < Registry > ,
111+ locals : Box < [ CachePadded < T > ] > ,
112+ registry : ComplementaryRegistry ,
17113}
18114
19- /// We prevent concurrent access to the underlying value in the
20- /// Deref impl, thus any values safe to send across threads can
21- /// be used with WorkerLocal.
115+ // This is safe because the `deref` call will return a reference to a `T` unique to each thread
116+ // or it will panic for threads without an associated local. So there isn't a need for `T` to do
117+ // its own synchronization. The `verify` method on `ComplementaryRegistryId` has an issue where
118+ // the id can be reused, but `WorkerLocal` holds a `ComplementaryRegistry` which prevents reuse.
22119unsafe impl < T : Send > Sync for WorkerLocal < T > { }
23120
24121impl < T > WorkerLocal < T > {
25122 /// Creates a new worker local where the `initial` closure computes the
26- /// value this worker local should take for each thread in the thread pool .
123+ /// value this worker local should take for each thread in the registry .
27124 #[ inline]
125+ #[ track_caller]
28126 pub fn new < F : FnMut ( usize ) -> T > ( mut initial : F ) -> WorkerLocal < T > {
29- let registry = Registry :: current ( ) ;
127+ let registry = ComplementaryRegistry :: current ( ) ;
30128 WorkerLocal {
31- locals : ( 0 ..registry. num_threads ( ) ) . map ( |i| CacheAligned ( initial ( i) ) ) . collect ( ) ,
129+ locals : ( 0 ..registry. 0 . thread_limit . get ( ) )
130+ . map ( |i| CachePadded :: new ( initial ( i) ) )
131+ . collect ( ) ,
32132 registry,
33133 }
34134 }
35135
36- /// Returns the worker-local value for each thread
136+ /// Returns the worker-local values for each thread
37137 #[ inline]
38- pub fn into_inner ( self ) -> Vec < T > {
39- self . locals . into_iter ( ) . map ( |c| c. 0 ) . collect ( )
40- }
41-
42- fn current ( & self ) -> & T {
43- unsafe {
44- let worker_thread = WorkerThread :: current ( ) ;
45- if worker_thread. is_null ( )
46- || !std:: ptr:: eq ( & * ( * worker_thread) . registry , & * self . registry )
47- {
48- panic ! ( "WorkerLocal can only be used on the thread pool it was created on" )
49- }
50- & self . locals [ ( * worker_thread) . index ] . 0
51- }
52- }
53- }
54-
55- impl < T > WorkerLocal < Vec < T > > {
56- /// Joins the elements of all the worker locals into one Vec
57- pub fn join ( self ) -> Vec < T > {
58- self . into_inner ( ) . into_iter ( ) . flatten ( ) . collect ( )
59- }
60- }
61-
62- impl < T : fmt:: Debug > fmt:: Debug for WorkerLocal < T > {
63- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
64- f. debug_struct ( "WorkerLocal" ) . field ( "registry" , & self . registry . id ( ) ) . finish ( )
138+ pub fn into_inner ( self ) -> impl Iterator < Item = T > {
139+ self . locals . into_vec ( ) . into_iter ( ) . map ( CachePadded :: into_inner)
65140 }
66141}
67142
@@ -70,6 +145,14 @@ impl<T> Deref for WorkerLocal<T> {
70145
71146 #[ inline( always) ]
72147 fn deref ( & self ) -> & T {
73- self . current ( )
148+ // SAFETY: `verify` will only return values less than
149+ // `self.registry.0.thread_limit`, which is the length of the `self.locals` slice.
150+ unsafe { & * self . locals . get_unchecked ( self . registry . id ( ) . verify ( ) ) }
151+ }
152+ }
153+
154+ impl < T : Default > Default for WorkerLocal < T > {
155+ fn default ( ) -> Self {
156+ WorkerLocal :: new ( |_| T :: default ( ) )
74157 }
75158}
0 commit comments