Skip to content

Commit 8b2ff87

Browse files
committed
Add a WorkerLocal type which allow you to hold a value per Rayon worker thread
1 parent 7ab08c3 commit 8b2ff87

File tree

3 files changed

+78
-2
lines changed

3 files changed

+78
-2
lines changed

rayon-core/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ mod spawn;
4747
mod thread_pool;
4848
mod unwind;
4949
mod util;
50+
mod worker_local;
5051

5152
mod compile_fail;
5253
mod test;
@@ -61,6 +62,7 @@ pub use self::spawn::{spawn, spawn_fifo};
6162
pub use self::thread_pool::current_thread_has_pending_tasks;
6263
pub use self::thread_pool::current_thread_index;
6364
pub use self::thread_pool::ThreadPool;
65+
pub use worker_local::WorkerLocal;
6466

6567
use self::registry::{CustomSpawn, DefaultSpawn, ThreadSpawn};
6668

rayon-core/src/registry.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -551,12 +551,12 @@ pub(super) struct WorkerThread {
551551
/// local queue used for `spawn_fifo` indirection
552552
fifo: JobFifo,
553553

554-
index: usize,
554+
pub(crate) index: usize,
555555

556556
/// A weak random number generator.
557557
rng: XorShift64Star,
558558

559-
registry: Arc<Registry>,
559+
pub(crate) registry: Arc<Registry>,
560560
}
561561

562562
// This is a bit sketchy, but basically: the WorkerThread is

rayon-core/src/worker_local.rs

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use crate::registry::{Registry, WorkerThread};
2+
use std::fmt;
3+
use std::ops::Deref;
4+
use std::sync::Arc;
5+
6+
#[repr(align(64))]
7+
#[derive(Debug)]
8+
struct CacheAligned<T>(T);
9+
10+
/// Holds worker-locals values for each thread in a thread pool.
11+
/// You can only access the worker local value through the Deref impl
12+
/// on the thread pool it was constructed on. It will panic otherwise
13+
pub struct WorkerLocal<T> {
14+
locals: Vec<CacheAligned<T>>,
15+
registry: Arc<Registry>,
16+
}
17+
18+
unsafe impl<T> Send for WorkerLocal<T> {}
19+
unsafe impl<T> Sync for WorkerLocal<T> {}
20+
21+
impl<T> WorkerLocal<T> {
22+
/// Creates a new worker local where the `initial` closure computes the
23+
/// value this worker local should take for each thread in the thread pool.
24+
#[inline]
25+
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
26+
let registry = Registry::current();
27+
WorkerLocal {
28+
locals: (0..registry.num_threads())
29+
.map(|i| CacheAligned(initial(i)))
30+
.collect(),
31+
registry,
32+
}
33+
}
34+
35+
/// Returns the worker-local value for each thread
36+
#[inline]
37+
pub fn into_inner(self) -> Vec<T> {
38+
self.locals.into_iter().map(|c| c.0).collect()
39+
}
40+
41+
fn current(&self) -> &T {
42+
unsafe {
43+
let worker_thread = WorkerThread::current();
44+
if worker_thread.is_null()
45+
|| &*(*worker_thread).registry as *const _ != &*self.registry as *const _
46+
{
47+
panic!("WorkerLocal can only be used on the thread pool it was created on")
48+
}
49+
&self.locals[(*worker_thread).index].0
50+
}
51+
}
52+
}
53+
54+
impl<T> WorkerLocal<Vec<T>> {
55+
/// Joins the elements of all the worker locals into one Vec
56+
pub fn join(self) -> Vec<T> {
57+
self.into_inner().into_iter().flat_map(|v| v).collect()
58+
}
59+
}
60+
61+
impl<T: fmt::Debug> fmt::Debug for WorkerLocal<T> {
62+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
63+
fmt::Debug::fmt(&self.locals, f)
64+
}
65+
}
66+
67+
impl<T> Deref for WorkerLocal<T> {
68+
type Target = T;
69+
70+
#[inline(always)]
71+
fn deref(&self) -> &T {
72+
self.current()
73+
}
74+
}

0 commit comments

Comments
 (0)