Skip to content

Commit 1082a9f

Browse files
committed
feat: add rayon-compat crate
1 parent ff22eb4 commit 1082a9f

9 files changed

Lines changed: 339 additions & 43 deletions

File tree

Cargo.lock

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ repository = "https://github.com/NthTensor/Forte"
88

99
[workspace]
1010
resolver = "2"
11-
members = ["ci"]
12-
exclude = ["coz"]
11+
members = ["ci", "rayon-compat"]
1312

1413
[dependencies]
1514
async-task = "4.7.1"

rayon-compat/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "rayon-compat"
3+
version = "1.12.1"
4+
edition = "2024"
5+
6+
[dependencies]
7+
forte = { path = ".." }
8+
9+
[features]
10+
web_spin_lock = []

rayon-compat/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Rayon Compat
2+
3+
This is a way to run `rayon` on top of `forte`! The `rayon-compat` crate mocks the important bits of the api of `rayon_core` in a pretty simple and crude way, which is none-the-less enough to support most of what `rayon` needs.
4+
5+
To use this crate, apply the following cargo patch like one of these:
6+
```
7+
// If you want to clone forte and use it locally
8+
[patch.crates-io]
9+
rayon-core = { path = "path to this repo", package = "rayon-compat" }
10+
11+
// If you want to use the latest published version of forte
12+
[patch.crates-io]
13+
rayon-core = { path = "https://github.com/NthTensor/Forte", package = "rayon-compat" }
14+
```

rayon-compat/src/lib.rs

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
use std::{
2+
pin::Pin,
3+
sync::atomic::{AtomicBool, Ordering},
4+
};
5+
6+
pub static THREAD_POOL: forte::ThreadPool = const { forte::ThreadPool::new() };
7+
8+
pub static STARTED: AtomicBool = const { AtomicBool::new(false) };
9+
10+
#[inline(always)]
11+
fn ensure_started() {
12+
if !STARTED.load(Ordering::Relaxed) && !STARTED.swap(true, Ordering::Relaxed) {
13+
THREAD_POOL.resize_to_available();
14+
}
15+
}
16+
17+
// -----------------------------------------------------------------------------
18+
// Join
19+
20+
#[derive(Debug)]
21+
pub struct FnContext {
22+
/// True if the task was migrated.
23+
migrated: bool,
24+
}
25+
26+
impl FnContext {
27+
#[inline(always)]
28+
pub fn migrated(&self) -> bool {
29+
self.migrated
30+
}
31+
}
32+
33+
#[inline(always)]
34+
pub fn join_context<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
35+
where
36+
A: FnOnce(FnContext) -> RA + Send,
37+
B: FnOnce(FnContext) -> RB + Send,
38+
RA: Send,
39+
RB: Send,
40+
{
41+
ensure_started();
42+
THREAD_POOL.join(
43+
|worker| {
44+
let ctx = FnContext {
45+
migrated: worker.migrated(),
46+
};
47+
oper_a(ctx)
48+
},
49+
|worker| {
50+
let ctx = FnContext {
51+
migrated: worker.migrated(),
52+
};
53+
oper_b(ctx)
54+
},
55+
)
56+
}
57+
58+
#[inline(always)]
59+
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
60+
where
61+
A: FnOnce() -> RA + Send,
62+
B: FnOnce() -> RB + Send,
63+
RA: Send,
64+
RB: Send,
65+
{
66+
ensure_started();
67+
THREAD_POOL.join(|_| oper_a(), |_| oper_b())
68+
}
69+
70+
#[inline(always)]
71+
pub fn current_num_threads() -> usize {
72+
64 // Forte prefers smaller tasks, so it's better to lie to rayon about the size of the pool
73+
}
74+
75+
#[inline(always)]
76+
pub fn current_thread_index() -> Option<usize> {
77+
forte::Worker::map_current(|worker| worker.index())
78+
}
79+
80+
// -----------------------------------------------------------------------------
81+
// Scope
82+
83+
pub struct Scope<'r, 'scope: 'r> {
84+
inner: Pin<&'r forte::Scope<'scope>>,
85+
}
86+
87+
impl<'scope> Scope<'_, 'scope> {
88+
#[inline(always)]
89+
pub fn spawn<BODY>(&self, body: BODY)
90+
where
91+
BODY: FnOnce(&Scope) + Send + 'scope,
92+
{
93+
self.inner.spawn(|inner| {
94+
let scope = Scope { inner };
95+
body(&scope)
96+
});
97+
}
98+
}
99+
100+
#[inline(always)]
101+
pub fn scope<'scope, OP, R>(op: OP) -> R
102+
where
103+
OP: FnOnce(&Scope<'_, 'scope>) -> R + Send,
104+
R: Send,
105+
{
106+
THREAD_POOL.scope(|inner| {
107+
let scope = Scope { inner };
108+
op(&scope)
109+
})
110+
}
111+
112+
#[inline(always)]
113+
pub fn in_place_scope<'scope, OP, R>(op: OP) -> R
114+
where
115+
OP: FnOnce(&Scope<'_, 'scope>) -> R,
116+
{
117+
THREAD_POOL.scope(|inner| {
118+
let scope = Scope { inner };
119+
op(&scope)
120+
})
121+
}
122+
123+
// -----------------------------------------------------------------------------
124+
// Spawn
125+
126+
#[inline(always)]
127+
pub fn spawn<F>(func: F)
128+
where
129+
F: FnOnce() + Send + 'static,
130+
{
131+
THREAD_POOL.spawn(|_| func())
132+
}
133+
134+
// -----------------------------------------------------------------------------
135+
// Fake stuff that dosn't work
136+
137+
pub struct ThreadBuilder;
138+
139+
pub struct ThreadPool;
140+
141+
pub struct ThreadPoolBuildError;
142+
143+
pub struct ThreadPoolBuilder;
144+
145+
pub struct BroadcastContext;
146+
147+
pub struct ScopeFifo;
148+
149+
pub struct Yield;
150+
151+
pub fn broadcast() {
152+
unimplemented!()
153+
}
154+
155+
pub fn spawn_broadcast() {
156+
unimplemented!()
157+
}
158+
159+
pub fn max_num_threads() {
160+
unimplemented!()
161+
}
162+
163+
pub fn scope_fifo() {
164+
unimplemented!()
165+
}
166+
167+
pub fn in_place_scope_fifo() {
168+
unimplemented!()
169+
}
170+
171+
pub fn spawn_fifo() {
172+
unimplemented!()
173+
}
174+
175+
pub fn yield_local() {
176+
unimplemented!()
177+
}
178+
179+
pub fn yield_now() {
180+
unimplemented!()
181+
}

src/job.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ use alloc::collections::VecDeque;
1616
use core::cell::UnsafeCell;
1717
use core::mem::ManuallyDrop;
1818
use core::ptr::NonNull;
19+
use std::thread::Result as ThreadResult;
1920

2021
use crate::signal::Signal;
2122
use crate::thread_pool::Worker;
23+
use crate::unwind;
2224

2325
// -----------------------------------------------------------------------------
2426
// Runnable
@@ -152,7 +154,7 @@ impl JobQueue {
152154
/// This is analogous to the chili type `JobStack` and the rayon type `StackJob`.
153155
pub struct StackJob<F, T> {
154156
f: UnsafeCell<ManuallyDrop<F>>,
155-
signal: Signal<T>,
157+
signal: Signal<ThreadResult<T>>,
156158
}
157159

158160
impl<F, T> StackJob<F, T>
@@ -210,7 +212,7 @@ where
210212
/// closure's return value is sent over this signal after the job is
211213
/// executed.
212214
#[inline(always)]
213-
pub fn signal(&self) -> &Signal<T> {
215+
pub fn signal(&self) -> &Signal<ThreadResult<T>> {
214216
&self.signal
215217
}
216218
}
@@ -235,20 +237,25 @@ where
235237
// SAFETY: The caller ensures `this` can be converted into an immutable
236238
// reference.
237239
let this = unsafe { this.cast::<Self>().as_ref() };
240+
// Create an abort guard. If the closure panics, this will convert the
241+
// panic into an abort. Doing so prevents use-after-free for other elements of the stack.
242+
let abort_guard = unwind::AbortOnDrop;
238243
// SAFETY: This memory location is accessed only in this function and in
239244
// `unwrap`. The latter cannot have been called, because it drops the
240245
// stack job, so, since this function is called only once, we can
241246
// guarantee that we have exclusive access.
242247
let f_ref = unsafe { &mut *this.f.get() };
243248
// SAFETY: The caller ensures this function is called only once.
244249
let f = unsafe { ManuallyDrop::take(f_ref) };
245-
// Run the job.
246-
let result = f(worker);
250+
// Run the job. If the job panics, we propagate the panic back to the main thread.
251+
let result = unwind::halt_unwinding(|| f(worker));
247252
// SAFETY: This is valid for the access used by `send` because
248253
// `&this.signal` is an immutable reference to a `Signal`. Because
249254
// `send` is only called in this function, and this function is never
250255
// called again, `send` is never called again.
251256
unsafe { Signal::send(&this.signal, result) }
257+
// Forget the abort guard, re-enabling panics.
258+
core::mem::forget(abort_guard);
252259
}
253260
}
254261

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ mod job;
4040
mod scope;
4141
mod signal;
4242
mod thread_pool;
43+
mod unwind;
4344

4445
// -----------------------------------------------------------------------------
4546
// Top-level exports

0 commit comments

Comments
 (0)