Skip to content

Commit 3b682bb

Browse files
committed
feat: add rayon-compat crate
1 parent ff22eb4 commit 3b682bb

6 files changed

Lines changed: 236 additions & 12 deletions

File tree

Cargo.lock

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ repository = "https://github.com/NthTensor/Forte"
88

99
[workspace]
1010
resolver = "2"
11-
members = ["ci"]
12-
exclude = ["coz"]
11+
members = ["ci", "rayon-compat"]
1312

1413
[dependencies]
1514
async-task = "4.7.1"

rayon-compat/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "rayon-compat"
3+
version = "1.12.1"
4+
edition = "2024"
5+
6+
[dependencies]
7+
forte = { path = ".." }
8+
9+
[features]
10+
web_spin_lock = []

rayon-compat/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Rayon Compat
2+
3+
This is a way to run `rayon` on top of `forte`! The `rayon-compat` crate mocks the important bits of the api of `rayon_core` in a pretty simple and crude way, which is none-the-less enough to support most of what `rayon` needs.
4+
5+
To use this crate, apply the following cargo patch like one of these:
6+
```
7+
// If you want to clone forte and use it locally
8+
[patch.crates-io]
9+
rayon-core = { path = "path to this repo", package = "rayon-compat" }
10+
11+
// If you want to use the latest published version of forte
12+
[patch.crates-io]
13+
rayon-core = { path = "https://github.com/NthTensor/Forte", package = "rayon-compat" }
14+
```

rayon-compat/src/lib.rs

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
use std::{
2+
marker::PhantomData,
3+
num::NonZero,
4+
sync::atomic::{AtomicBool, Ordering},
5+
};
6+
7+
pub static THREAD_POOL: forte::ThreadPool = forte::ThreadPool::new();
8+
9+
pub static STARTED: AtomicBool = AtomicBool::new(false);
10+
11+
#[inline]
12+
fn ensure_started() {
13+
if !STARTED.load(Ordering::Relaxed) {
14+
if !STARTED.swap(true, Ordering::Relaxed) {
15+
THREAD_POOL.resize_to_available();
16+
}
17+
}
18+
}
19+
20+
// -----------------------------------------------------------------------------
21+
// Join
22+
23+
#[derive(Debug)]
24+
pub struct FnContext {
25+
/// True if the task was migrated.
26+
migrated: bool,
27+
}
28+
29+
impl FnContext {
30+
#[inline]
31+
pub fn migrated(&self) -> bool {
32+
self.migrated
33+
}
34+
}
35+
36+
pub fn join_context<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
37+
where
38+
A: FnOnce(FnContext) -> RA + Send,
39+
B: FnOnce(FnContext) -> RB + Send,
40+
RA: Send,
41+
RB: Send,
42+
{
43+
ensure_started();
44+
THREAD_POOL.join(
45+
|worker| {
46+
let ctx = FnContext {
47+
migrated: worker.migrated(),
48+
};
49+
oper_a(ctx)
50+
},
51+
|worker| {
52+
let ctx = FnContext {
53+
migrated: worker.migrated(),
54+
};
55+
oper_b(ctx)
56+
},
57+
)
58+
}
59+
60+
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
61+
where
62+
A: FnOnce() -> RA + Send,
63+
B: FnOnce() -> RB + Send,
64+
RA: Send,
65+
RB: Send,
66+
{
67+
ensure_started();
68+
THREAD_POOL.join(|_| oper_a(), |_| oper_b())
69+
}
70+
71+
pub fn current_num_threads() -> usize {
72+
std::thread::available_parallelism()
73+
.map(NonZero::get)
74+
.unwrap_or(1)
75+
}
76+
77+
pub fn current_thread_index() -> Option<usize> {
78+
forte::Worker::map_current(|worker| worker.index())
79+
}
80+
81+
// -----------------------------------------------------------------------------
82+
// Scope
83+
84+
pub struct Scope<'scope> {
85+
phantom: PhantomData<&'scope ()>,
86+
}
87+
88+
impl<'scope> Scope<'scope> {
89+
pub fn spawn<BODY>(&self, body: BODY)
90+
where
91+
BODY: FnOnce(&Scope) + Send + 'scope,
92+
{
93+
unimplemented!();
94+
}
95+
}
96+
97+
pub fn scope<'scope, OP, R>(op: OP) -> R
98+
where
99+
OP: FnOnce(&Scope<'scope>) -> R + Send,
100+
R: Send,
101+
{
102+
unimplemented!();
103+
}
104+
105+
pub fn in_place_scope<'scope, OP, R>(op: OP) -> R
106+
where
107+
OP: FnOnce(&Scope<'scope>) -> R,
108+
{
109+
unimplemented!();
110+
}
111+
112+
pub fn in_place_scope_fifo() {
113+
unimplemented!()
114+
}
115+
116+
pub fn scope_fifo() {
117+
unimplemented!()
118+
}
119+
120+
// -----------------------------------------------------------------------------
121+
// Spawn
122+
123+
pub fn spawn<F>(func: F)
124+
where
125+
F: FnOnce() + Send + 'static,
126+
{
127+
unimplemented!();
128+
}
129+
130+
// -----------------------------------------------------------------------------
131+
// Fake stuff that dosn't work
132+
133+
pub struct ThreadBuilder;
134+
135+
pub struct ThreadPool;
136+
137+
pub struct ThreadPoolBuildError;
138+
139+
pub struct ThreadPoolBuilder;
140+
141+
pub struct BroadcastContext;
142+
143+
pub struct ScopeFifo;
144+
145+
pub struct Yield;
146+
147+
pub fn broadcast() {
148+
unimplemented!()
149+
}
150+
151+
pub fn spawn_broadcast() {
152+
unimplemented!()
153+
}
154+
155+
pub fn max_num_threads() {
156+
unimplemented!()
157+
}
158+
159+
pub fn spawn_fifo() {
160+
unimplemented!()
161+
}
162+
163+
pub fn yield_local() {
164+
unimplemented!()
165+
}
166+
167+
pub fn yield_now() {
168+
unimplemented!()
169+
}

src/thread_pool.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,7 @@ thread_local! {
647647
/// Workers have one core memory-safety guarantee: Any jobs added to the worker
648648
/// will eventually be executed.
649649
pub struct Worker {
650+
pub(crate) migrated: Cell<bool>,
650651
pub(crate) lease: Lease,
651652
pub(crate) queue: JobQueue,
652653
}
@@ -736,6 +737,7 @@ impl Worker {
736737
// problem that the same thread can occupy multiple workers on the same
737738
// thread. We many eventually need to design something to prevent this.
738739
let worker = Worker {
740+
migrated: Cell::new(false),
739741
lease,
740742
queue: JobQueue::new(),
741743
};
@@ -749,7 +751,7 @@ impl Worker {
749751

750752
// Execute the work queue until it's empty
751753
while let Some(job_ref) = worker.queue.pop_front() {
752-
job_ref.execute(&worker);
754+
worker.execute(job_ref, false);
753755
}
754756

755757
// Swap back to pointing to the previous value (possibly null).
@@ -762,6 +764,10 @@ impl Worker {
762764
result
763765
}
764766

767+
pub fn index(&self) -> usize {
768+
self.lease.index
769+
}
770+
765771
/// Tries to promote the oldest job in the local stack to a shared job. If
766772
/// the local job queue is empty, or if the shared queue is full, this does
767773
/// nothing. If the promotion is successful, it tries to wake another
@@ -814,12 +820,15 @@ impl Worker {
814820

815821
/// Tries to find a job to execute, either in the local queue or shared on
816822
/// the threadpool.
823+
///
824+
/// The second value is true if the job was shared, or false if it was spawned locally.
817825
#[inline]
818-
pub fn find_work(&self) -> Option<JobRef> {
826+
pub fn find_work(&self) -> Option<(JobRef, bool)> {
819827
// We give preference first to things in our local deque, then in other
820828
// workers deques, and finally to injected jobs from the outside. The
821829
// idea is to finish what we started before we take on something new.
822-
self.queue.pop_back().or_else(|| self.claim_shared_job())
830+
self.queue.pop_back().map(|job| (job, false))
831+
.or_else(|| self.claim_shared_job().map(|job| (job, true)))
823832
}
824833

825834
/// Claims a shared job from the thread pool.
@@ -842,7 +851,7 @@ impl Worker {
842851
pub fn yield_local(&self) -> Yield {
843852
match self.queue.pop_back() {
844853
Some(job_ref) => {
845-
job_ref.execute(self);
854+
self.execute(job_ref, false);
846855
Yield::Executed
847856
}
848857
None => Yield::Idle,
@@ -860,13 +869,29 @@ impl Worker {
860869
#[inline]
861870
pub fn yield_now(&self) -> Yield {
862871
match self.find_work() {
863-
Some(job_ref) => {
864-
job_ref.execute(self);
872+
Some((job_ref, migrated)) => {
873+
self.execute(job_ref, migrated);
865874
Yield::Executed
866875
}
867876
None => Yield::Idle,
868877
}
869878
}
879+
880+
/// Returns `true` if the current job is executing on a different thread
881+
/// from the one on which it was created. Returns `false` if not executing a
882+
/// job, or if the current job was created on the current thread.
883+
#[inline]
884+
pub fn migrated(&self) -> bool {
885+
self.migrated.get()
886+
}
887+
888+
/// Executes a job on a worker
889+
#[inline]
890+
pub fn execute(&self, job_ref: JobRef, migrated: bool) {
891+
let migrated = self.migrated.replace(migrated);
892+
job_ref.execute(self);
893+
self.migrated.set(migrated);
894+
}
870895
}
871896

872897
// -----------------------------------------------------------------------------
@@ -1063,7 +1088,7 @@ impl Worker {
10631088

10641089
// Even if it's not the droid we were looking for, we must still
10651090
// execute the job.
1066-
job.execute(self);
1091+
self.execute(job, false);
10671092
}
10681093

10691094
// Wait for the job to complete, then return the result.
@@ -1217,7 +1242,7 @@ fn managed_worker(lease: Lease, halt: Arc<AtomicBool>, barrier: Arc<Barrier>) {
12171242
Worker::occupy(lease, |worker| {
12181243
while !halt.load(Ordering::Relaxed) {
12191244
if let Some(job) = worker.queue.pop_back() {
1220-
job.execute(worker);
1245+
worker.execute(job, false);
12211246
continue;
12221247
}
12231248

@@ -1226,7 +1251,7 @@ fn managed_worker(lease: Lease, halt: Arc<AtomicBool>, barrier: Arc<Barrier>) {
12261251
while !halt.load(Ordering::Relaxed) {
12271252
if let Some(job) = state.claim_shared_job() {
12281253
drop(state);
1229-
job.execute(worker);
1254+
worker.execute(job, true);
12301255
break;
12311256
}
12321257

0 commit comments

Comments
 (0)