Skip to content

Commit c14882f

Browse files
committed
Auto merge of #107782 - Zoxc:worker-local, r=cjgillot
Move the WorkerLocal type from the rustc-rayon fork into rustc_data_structures This PR moves the definition of the `WorkerLocal` type from `rustc-rayon` into `rustc_data_structures`. This is enabled by the introduction of the `Registry` type which allows you to group up threads to be used by `WorkerLocal` which is basically just an array with an per thread index. The `Registry` type mirrors the one in Rayon and each Rayon worker thread is also registered with the new `Registry`. Safety for `WorkerLocal` is ensured by having it keep a reference to the registry and checking on each access that we're still on the group of threads associated with the registry used to construct it. Accessing a `WorkerLocal` is micro-optimized due to it being hot since it's used for most arena allocations. Performance is slightly improved for the parallel compiler: <table><tr><td rowspan="2">Benchmark</td><td colspan="1"><b>Before</b></th><td colspan="2"><b>After</b></th></tr><tr><td align="right">Time</td><td align="right">Time</td><td align="right">%</th></tr><tr><td>🟣 <b>clap</b>:check</td><td align="right">1.9992s</td><td align="right">1.9949s</td><td align="right"> -0.21%</td></tr><tr><td>🟣 <b>hyper</b>:check</td><td align="right">0.2977s</td><td align="right">0.2970s</td><td align="right"> -0.22%</td></tr><tr><td>🟣 <b>regex</b>:check</td><td align="right">1.1335s</td><td align="right">1.1315s</td><td align="right"> -0.18%</td></tr><tr><td>🟣 <b>syn</b>:check</td><td align="right">1.8235s</td><td align="right">1.8171s</td><td align="right"> -0.35%</td></tr><tr><td>🟣 <b>syntex_syntax</b>:check</td><td align="right">6.9047s</td><td align="right">6.8930s</td><td align="right"> -0.17%</td></tr><tr><td>Total</td><td align="right">12.1586s</td><td align="right">12.1336s</td><td align="right"> -0.21%</td></tr><tr><td>Summary</td><td align="right">1.0000s</td><td align="right">0.9977s</td><td align="right"> -0.23%</td></tr></table> cc `@SparrowLii`
2 parents 901fdb3 + efe7cf4 commit c14882f

File tree

5 files changed

+198
-66
lines changed

5 files changed

+198
-66
lines changed

compiler/rustc_ast/src/attr/mod.rs

+4-32
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,10 @@ use crate::tokenstream::{DelimSpan, Spacing, TokenTree};
1010
use crate::tokenstream::{LazyAttrTokenStream, TokenStream};
1111
use crate::util::comments;
1212
use crate::util::literal::escape_string_symbol;
13-
use rustc_data_structures::sync::WorkerLocal;
1413
use rustc_index::bit_set::GrowableBitSet;
1514
use rustc_span::symbol::{sym, Ident, Symbol};
1615
use rustc_span::Span;
17-
use std::cell::Cell;
1816
use std::iter;
19-
#[cfg(debug_assertions)]
20-
use std::ops::BitXor;
21-
#[cfg(debug_assertions)]
2217
use std::sync::atomic::{AtomicU32, Ordering};
2318
use thin_vec::{thin_vec, ThinVec};
2419

@@ -40,39 +35,16 @@ impl MarkedAttrs {
4035
}
4136
}
4237

43-
pub struct AttrIdGenerator(WorkerLocal<Cell<u32>>);
44-
45-
#[cfg(debug_assertions)]
46-
static MAX_ATTR_ID: AtomicU32 = AtomicU32::new(u32::MAX);
38+
pub struct AttrIdGenerator(AtomicU32);
4739

4840
impl AttrIdGenerator {
4941
pub fn new() -> Self {
50-
// We use `(index as u32).reverse_bits()` to initialize the
51-
// starting value of AttrId in each worker thread.
52-
// The `index` is the index of the worker thread.
53-
// This ensures that the AttrId generated in each thread is unique.
54-
AttrIdGenerator(WorkerLocal::new(|index| {
55-
let index: u32 = index.try_into().unwrap();
56-
57-
#[cfg(debug_assertions)]
58-
{
59-
let max_id = ((index + 1).next_power_of_two() - 1).bitxor(u32::MAX).reverse_bits();
60-
MAX_ATTR_ID.fetch_min(max_id, Ordering::Release);
61-
}
62-
63-
Cell::new(index.reverse_bits())
64-
}))
42+
AttrIdGenerator(AtomicU32::new(0))
6543
}
6644

6745
pub fn mk_attr_id(&self) -> AttrId {
68-
let id = self.0.get();
69-
70-
// Ensure the assigned attr_id does not overlap the bits
71-
// representing the number of threads.
72-
#[cfg(debug_assertions)]
73-
assert!(id <= MAX_ATTR_ID.load(Ordering::Acquire));
74-
75-
self.0.set(id + 1);
46+
let id = self.0.fetch_add(1, Ordering::Relaxed);
47+
assert!(id != u32::MAX);
7648
AttrId::from_u32(id)
7749
}
7850
}

compiler/rustc_data_structures/src/sharded.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use crate::fx::{FxHashMap, FxHasher};
2-
use crate::sync::{Lock, LockGuard};
2+
use crate::sync::{CacheAligned, Lock, LockGuard};
33
use std::borrow::Borrow;
44
use std::collections::hash_map::RawEntryMut;
55
use std::hash::{Hash, Hasher};
66
use std::mem;
77

8-
#[derive(Default)]
9-
#[cfg_attr(parallel_compiler, repr(align(64)))]
10-
struct CacheAligned<T>(T);
11-
128
#[cfg(parallel_compiler)]
139
// 32 shards is sufficient to reduce contention on an 8-core Ryzen 7 1700,
1410
// but this should be tested on higher core count CPUs. How the `Sharded` type gets used

compiler/rustc_data_structures/src/sync.rs

+7-29
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ use std::hash::{BuildHasher, Hash};
4545
use std::ops::{Deref, DerefMut};
4646
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
4747

48+
mod worker_local;
49+
pub use worker_local::{Registry, WorkerLocal};
50+
4851
pub use std::sync::atomic::Ordering;
4952
pub use std::sync::atomic::Ordering::SeqCst;
5053

@@ -205,33 +208,6 @@ cfg_if! {
205208

206209
use std::cell::Cell;
207210

208-
#[derive(Debug)]
209-
pub struct WorkerLocal<T>(OneThread<T>);
210-
211-
impl<T> WorkerLocal<T> {
212-
/// Creates a new worker local where the `initial` closure computes the
213-
/// value this worker local should take for each thread in the thread pool.
214-
#[inline]
215-
pub fn new<F: FnMut(usize) -> T>(mut f: F) -> WorkerLocal<T> {
216-
WorkerLocal(OneThread::new(f(0)))
217-
}
218-
219-
/// Returns the worker-local value for each thread
220-
#[inline]
221-
pub fn into_inner(self) -> Vec<T> {
222-
vec![OneThread::into_inner(self.0)]
223-
}
224-
}
225-
226-
impl<T> Deref for WorkerLocal<T> {
227-
type Target = T;
228-
229-
#[inline(always)]
230-
fn deref(&self) -> &T {
231-
&self.0
232-
}
233-
}
234-
235211
pub type MTLockRef<'a, T> = &'a mut MTLock<T>;
236212

237213
#[derive(Debug, Default)]
@@ -351,8 +327,6 @@ cfg_if! {
351327
};
352328
}
353329

354-
pub use rayon_core::WorkerLocal;
355-
356330
pub use rayon::iter::ParallelIterator;
357331
use rayon::iter::IntoParallelIterator;
358332

@@ -383,6 +357,10 @@ pub fn assert_send<T: ?Sized + Send>() {}
383357
pub fn assert_send_val<T: ?Sized + Send>(_t: &T) {}
384358
pub fn assert_send_sync_val<T: ?Sized + Sync + Send>(_t: &T) {}
385359

360+
#[derive(Default)]
361+
#[cfg_attr(parallel_compiler, repr(align(64)))]
362+
pub struct CacheAligned<T>(pub T);
363+
386364
pub trait HashMapExt<K, V> {
387365
/// Same as HashMap::insert, but it may panic if there's already an
388366
/// entry for `key` with a value not equal to `value`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
use crate::sync::Lock;
2+
use std::cell::Cell;
3+
use std::cell::OnceCell;
4+
use std::ops::Deref;
5+
use std::ptr;
6+
use std::sync::Arc;
7+
8+
#[cfg(parallel_compiler)]
9+
use {crate::cold_path, crate::sync::CacheAligned};
10+
11+
/// A pointer to the `RegistryData` which uniquely identifies a registry.
12+
/// This identifier can be reused if the registry gets freed.
13+
#[derive(Clone, Copy, PartialEq)]
14+
struct RegistryId(*const RegistryData);
15+
16+
impl RegistryId {
17+
#[inline(always)]
18+
/// Verifies that the current thread is associated with the registry and returns its unique
19+
/// index within the registry. This panics if the current thread is not associated with this
20+
/// registry.
21+
///
22+
/// Note that there's a race possible where the identifer in `THREAD_DATA` could be reused
23+
/// so this can succeed from a different registry.
24+
#[cfg(parallel_compiler)]
25+
fn verify(self) -> usize {
26+
let (id, index) = THREAD_DATA.with(|data| (data.registry_id.get(), data.index.get()));
27+
28+
if id == self {
29+
index
30+
} else {
31+
cold_path(|| panic!("Unable to verify registry association"))
32+
}
33+
}
34+
}
35+
36+
struct RegistryData {
37+
thread_limit: usize,
38+
threads: Lock<usize>,
39+
}
40+
41+
/// Represents a list of threads which can access worker locals.
42+
#[derive(Clone)]
43+
pub struct Registry(Arc<RegistryData>);
44+
45+
thread_local! {
46+
/// The registry associated with the thread.
47+
/// This allows the `WorkerLocal` type to clone the registry in its constructor.
48+
static REGISTRY: OnceCell<Registry> = OnceCell::new();
49+
}
50+
51+
struct ThreadData {
52+
registry_id: Cell<RegistryId>,
53+
index: Cell<usize>,
54+
}
55+
56+
thread_local! {
57+
/// A thread local which contains the identifer of `REGISTRY` but allows for faster access.
58+
/// It also holds the index of the current thread.
59+
static THREAD_DATA: ThreadData = const { ThreadData {
60+
registry_id: Cell::new(RegistryId(ptr::null())),
61+
index: Cell::new(0),
62+
}};
63+
}
64+
65+
impl Registry {
66+
/// Creates a registry which can hold up to `thread_limit` threads.
67+
pub fn new(thread_limit: usize) -> Self {
68+
Registry(Arc::new(RegistryData { thread_limit, threads: Lock::new(0) }))
69+
}
70+
71+
/// Gets the registry associated with the current thread. Panics if there's no such registry.
72+
pub fn current() -> Self {
73+
REGISTRY.with(|registry| registry.get().cloned().expect("No assocated registry"))
74+
}
75+
76+
/// Registers the current thread with the registry so worker locals can be used on it.
77+
/// Panics if the thread limit is hit or if the thread already has an associated registry.
78+
pub fn register(&self) {
79+
let mut threads = self.0.threads.lock();
80+
if *threads < self.0.thread_limit {
81+
REGISTRY.with(|registry| {
82+
if registry.get().is_some() {
83+
drop(threads);
84+
panic!("Thread already has a registry");
85+
}
86+
registry.set(self.clone()).ok();
87+
THREAD_DATA.with(|data| {
88+
data.registry_id.set(self.id());
89+
data.index.set(*threads);
90+
});
91+
*threads += 1;
92+
});
93+
} else {
94+
drop(threads);
95+
panic!("Thread limit reached");
96+
}
97+
}
98+
99+
/// Gets the identifer of this registry.
100+
fn id(&self) -> RegistryId {
101+
RegistryId(&*self.0)
102+
}
103+
}
104+
105+
/// Holds worker local values for each possible thread in a registry. You can only access the
106+
/// worker local value through the `Deref` impl on the registry associated with the thread it was
107+
/// created on. It will panic otherwise.
108+
pub struct WorkerLocal<T> {
109+
#[cfg(not(parallel_compiler))]
110+
local: T,
111+
#[cfg(parallel_compiler)]
112+
locals: Box<[CacheAligned<T>]>,
113+
#[cfg(parallel_compiler)]
114+
registry: Registry,
115+
}
116+
117+
// This is safe because the `deref` call will return a reference to a `T` unique to each thread
118+
// or it will panic for threads without an associated local. So there isn't a need for `T` to do
119+
// it's own synchronization. The `verify` method on `RegistryId` has an issue where the the id
120+
// can be reused, but `WorkerLocal` has a reference to `Registry` which will prevent any reuse.
121+
#[cfg(parallel_compiler)]
122+
unsafe impl<T: Send> Sync for WorkerLocal<T> {}
123+
124+
impl<T> WorkerLocal<T> {
125+
/// Creates a new worker local where the `initial` closure computes the
126+
/// value this worker local should take for each thread in the registry.
127+
#[inline]
128+
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
129+
#[cfg(parallel_compiler)]
130+
{
131+
let registry = Registry::current();
132+
WorkerLocal {
133+
locals: (0..registry.0.thread_limit).map(|i| CacheAligned(initial(i))).collect(),
134+
registry,
135+
}
136+
}
137+
#[cfg(not(parallel_compiler))]
138+
{
139+
WorkerLocal { local: initial(0) }
140+
}
141+
}
142+
143+
/// Returns the worker-local values for each thread
144+
#[inline]
145+
pub fn into_inner(self) -> impl Iterator<Item = T> {
146+
#[cfg(parallel_compiler)]
147+
{
148+
self.locals.into_vec().into_iter().map(|local| local.0)
149+
}
150+
#[cfg(not(parallel_compiler))]
151+
{
152+
std::iter::once(self.local)
153+
}
154+
}
155+
}
156+
157+
impl<T> WorkerLocal<Vec<T>> {
158+
/// Joins the elements of all the worker locals into one Vec
159+
pub fn join(self) -> Vec<T> {
160+
self.into_inner().into_iter().flat_map(|v| v).collect()
161+
}
162+
}
163+
164+
impl<T> Deref for WorkerLocal<T> {
165+
type Target = T;
166+
167+
#[inline(always)]
168+
#[cfg(not(parallel_compiler))]
169+
fn deref(&self) -> &T {
170+
&self.local
171+
}
172+
173+
#[inline(always)]
174+
#[cfg(parallel_compiler)]
175+
fn deref(&self) -> &T {
176+
// This is safe because `verify` will only return values less than
177+
// `self.registry.thread_limit` which is the size of the `self.locals` array.
178+
unsafe { &self.locals.get_unchecked(self.registry.id().verify()).0 }
179+
}
180+
}

compiler/rustc_interface/src/util.rs

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use libloading::Library;
44
use rustc_ast as ast;
55
use rustc_codegen_ssa::traits::CodegenBackend;
66
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
7+
#[cfg(parallel_compiler)]
8+
use rustc_data_structures::sync;
79
use rustc_errors::registry::Registry;
810
use rustc_parse::validate_attr;
911
use rustc_session as session;
@@ -170,6 +172,7 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
170172
use rustc_middle::ty::tls;
171173
use rustc_query_impl::{deadlock, QueryContext, QueryCtxt};
172174

175+
let registry = sync::Registry::new(threads);
173176
let mut builder = rayon::ThreadPoolBuilder::new()
174177
.thread_name(|_| "rustc".to_string())
175178
.acquire_thread_handler(jobserver::acquire_thread)
@@ -200,6 +203,9 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
200203
.build_scoped(
201204
// Initialize each new worker thread when created.
202205
move |thread: rayon::ThreadBuilder| {
206+
// Register the thread for use with the `WorkerLocal` type.
207+
registry.register();
208+
203209
rustc_span::set_session_globals_then(session_globals, || thread.run())
204210
},
205211
// Run `f` on the first thread in the thread pool.

0 commit comments

Comments
 (0)