Skip to content

Commit 742502c

Browse files
committed
Move new WorkerLocal implementation to the old location
1 parent b7e97a9 commit 742502c

5 files changed

Lines changed: 133 additions & 199 deletions

File tree

compiler/rustc_data_structures/src/sync.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@ pub use self::parallel::{
4242
try_par_for_each_in,
4343
};
4444
pub use self::vec::{AppendOnlyIndexVec, AppendOnlyVec};
45-
pub use self::worker_local::{Registry, WorkerLocal};
45+
pub use rustc_thread_pool::{ComplementaryRegistry, WorkerLocal};
4646
pub use crate::marker::*;
4747

4848
mod freeze;
4949
mod lock;
5050
mod parallel;
5151
mod vec;
52-
mod worker_local;
5352

5453
/// Keep the conditional imports together in a submodule, so that import-sorting
5554
/// doesn't split them up.

compiler/rustc_data_structures/src/sync/worker_local.rs

Lines changed: 0 additions & 149 deletions
This file was deleted.

compiler/rustc_interface/src/util.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ pub(crate) fn run_in_thread_pool_with_globals<
190190

191191
let thread_stack_size = init_stack_size(thread_builder_diag);
192192

193-
let registry = sync::Registry::new(std::num::NonZero::new(threads).unwrap());
193+
let registry =
194+
sync::ComplementaryRegistry::new(std::num::NonZero::new(threads).unwrap());
194195

195196
let Some(proof) = sync::check_dyn_thread_safe() else {
196197
return run_in_thread_with_globals(

compiler/rustc_thread_pool/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ mod tests;
9090

9191
pub mod tlv;
9292

93-
pub use worker_local::WorkerLocal;
93+
pub use worker_local::{ComplementaryRegistry, WorkerLocal};
9494

9595
pub use self::broadcast::{BroadcastContext, broadcast, spawn_broadcast};
9696
pub use self::join::{join, join_context};
Lines changed: 129 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,142 @@
1-
use std::fmt;
1+
use std::cell::{Cell, OnceCell};
2+
use std::num::NonZero;
23
use std::ops::Deref;
3-
use std::sync::Arc;
4+
use std::ptr;
5+
use std::sync::{Arc, Mutex};
46

5-
use crate::registry::{Registry, WorkerThread};
7+
use crossbeam_utils::CachePadded;
68

7-
#[repr(align(64))]
8-
#[derive(Debug)]
9-
struct CacheAligned<T>(T);
9+
/// A pointer to the `ComplementaryRegistryData` which uniquely identifies a registry.
10+
/// This identifier can be reused if the registry gets freed.
11+
#[derive(Clone, Copy, PartialEq)]
12+
struct ComplementaryRegistryId(*const ComplementaryRegistryData);
1013

11-
/// Holds worker-locals values for each thread in a thread pool.
12-
/// You can only access the worker local value through the Deref impl
13-
/// on the thread pool it was constructed on. It will panic otherwise
14+
impl ComplementaryRegistryId {
15+
#[inline(always)]
16+
/// Verifies that the current thread is associated with the registry and returns its unique
17+
/// index within the registry. This panics if the current thread is not associated with this
18+
/// registry.
19+
///
20+
/// Note that there's a race possible where the identifier in `THREAD_DATA` could be reused
21+
/// so this can succeed from a different registry.
22+
fn verify(self) -> usize {
23+
let (id, index) = THREAD_DATA.with(|data| (data.registry_id.get(), data.index.get()));
24+
25+
if id == self { index } else { ComplementaryRegistryId::verification_error() }
26+
}
27+
28+
#[cold]
29+
#[inline(never)]
30+
fn verification_error() -> ! {
31+
panic!("Unable to verify registry association")
32+
}
33+
}
34+
35+
struct ComplementaryRegistryData {
36+
thread_limit: NonZero<usize>,
37+
threads: Mutex<usize>,
38+
}
39+
40+
/// Represents a list of threads which can access worker locals.
41+
#[derive(Clone)]
42+
pub struct ComplementaryRegistry(Arc<ComplementaryRegistryData>);
43+
44+
thread_local! {
45+
/// The registry associated with the thread.
46+
/// This allows the `WorkerLocal` type to clone the registry in its constructor.
47+
static REGISTRY: OnceCell<ComplementaryRegistry> = const { OnceCell::new() };
48+
}
49+
50+
struct ThreadData {
51+
registry_id: Cell<ComplementaryRegistryId>,
52+
index: Cell<usize>,
53+
}
54+
55+
thread_local! {
56+
/// A thread local which contains the identifier of `REGISTRY` but allows for faster access.
57+
/// It also holds the index of the current thread.
58+
static THREAD_DATA: ThreadData = const { ThreadData {
59+
registry_id: Cell::new(ComplementaryRegistryId(ptr::null())),
60+
index: Cell::new(0),
61+
}};
62+
}
63+
64+
impl ComplementaryRegistry {
65+
/// Creates a registry which can hold up to `thread_limit` threads.
66+
pub fn new(thread_limit: NonZero<usize>) -> Self {
67+
ComplementaryRegistry(Arc::new(ComplementaryRegistryData {
68+
thread_limit,
69+
threads: Mutex::new(0),
70+
}))
71+
}
72+
73+
/// Gets the registry associated with the current thread. Panics if there's no such registry.
74+
pub fn current() -> Self {
75+
REGISTRY.with(|registry| registry.get().cloned().expect("No associated registry"))
76+
}
77+
78+
/// Registers the current thread with the registry so worker locals can be used on it.
79+
/// Panics if the thread limit is hit or if the thread already has an associated registry.
80+
pub fn register(&self) {
81+
let mut threads = self.0.threads.lock().unwrap();
82+
if *threads < self.0.thread_limit.get() {
83+
REGISTRY.with(|registry| {
84+
if registry.get().is_some() {
85+
drop(threads);
86+
panic!("Thread already has a registry");
87+
}
88+
registry.set(self.clone()).ok();
89+
THREAD_DATA.with(|data| {
90+
data.registry_id.set(self.id());
91+
data.index.set(*threads);
92+
});
93+
*threads += 1;
94+
});
95+
} else {
96+
drop(threads);
97+
panic!("Thread limit reached");
98+
}
99+
}
100+
101+
/// Gets the identifier of this registry.
102+
fn id(&self) -> ComplementaryRegistryId {
103+
ComplementaryRegistryId(&*self.0)
104+
}
105+
}
106+
107+
/// Holds worker local values for each possible thread in a registry. You can only access the
108+
/// worker local value through the `Deref` impl on the registry associated with the thread it was
109+
/// created on. It will panic otherwise.
14110
pub struct WorkerLocal<T> {
15-
locals: Vec<CacheAligned<T>>,
16-
registry: Arc<Registry>,
111+
locals: Box<[CachePadded<T>]>,
112+
registry: ComplementaryRegistry,
17113
}
18114

19-
/// We prevent concurrent access to the underlying value in the
20-
/// Deref impl, thus any values safe to send across threads can
21-
/// be used with WorkerLocal.
115+
// This is safe because the `deref` call will return a reference to a `T` unique to each thread
116+
// or it will panic for threads without an associated local. So there isn't a need for `T` to do
117+
// it's own synchronization. The `verify` method on `RegistryId` has an issue where the id
118+
// can be reused, but `WorkerLocal` has a reference to `Registry` which will prevent any reuse.
22119
unsafe impl<T: Send> Sync for WorkerLocal<T> {}
23120

24121
impl<T> WorkerLocal<T> {
25122
/// Creates a new worker local where the `initial` closure computes the
26-
/// value this worker local should take for each thread in the thread pool.
123+
/// value this worker local should take for each thread in the registry.
27124
#[inline]
125+
#[track_caller]
28126
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
29-
let registry = Registry::current();
127+
let registry = ComplementaryRegistry::current();
30128
WorkerLocal {
31-
locals: (0..registry.num_threads()).map(|i| CacheAligned(initial(i))).collect(),
129+
locals: (0..registry.0.thread_limit.get())
130+
.map(|i| CachePadded::new(initial(i)))
131+
.collect(),
32132
registry,
33133
}
34134
}
35135

36-
/// Returns the worker-local value for each thread
136+
/// Returns the worker-local values for each thread
37137
#[inline]
38-
pub fn into_inner(self) -> Vec<T> {
39-
self.locals.into_iter().map(|c| c.0).collect()
40-
}
41-
42-
fn current(&self) -> &T {
43-
unsafe {
44-
let worker_thread = WorkerThread::current();
45-
if worker_thread.is_null()
46-
|| !std::ptr::eq(&*(*worker_thread).registry, &*self.registry)
47-
{
48-
panic!("WorkerLocal can only be used on the thread pool it was created on")
49-
}
50-
&self.locals[(*worker_thread).index].0
51-
}
52-
}
53-
}
54-
55-
impl<T> WorkerLocal<Vec<T>> {
56-
/// Joins the elements of all the worker locals into one Vec
57-
pub fn join(self) -> Vec<T> {
58-
self.into_inner().into_iter().flatten().collect()
59-
}
60-
}
61-
62-
impl<T: fmt::Debug> fmt::Debug for WorkerLocal<T> {
63-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64-
f.debug_struct("WorkerLocal").field("registry", &self.registry.id()).finish()
138+
pub fn into_inner(self) -> impl Iterator<Item = T> {
139+
self.locals.into_vec().into_iter().map(CachePadded::into_inner)
65140
}
66141
}
67142

@@ -70,6 +145,14 @@ impl<T> Deref for WorkerLocal<T> {
70145

71146
#[inline(always)]
72147
fn deref(&self) -> &T {
73-
self.current()
148+
// This is safe because `verify` will only return values less than
149+
// `self.registry.thread_limit` which is the size of the `self.locals` array.
150+
unsafe { &*self.locals.get_unchecked(self.registry.id().verify()) }
151+
}
152+
}
153+
154+
impl<T: Default> Default for WorkerLocal<T> {
155+
fn default() -> Self {
156+
WorkerLocal::new(|_| T::default())
74157
}
75158
}

0 commit comments

Comments
 (0)