Skip to content

Commit ab1e90a

Browse files
committed
feat: add rayon-compat crate
1 parent ff22eb4 commit ab1e90a

9 files changed

Lines changed: 341 additions & 43 deletions

File tree

Cargo.lock

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ repository = "https://github.com/NthTensor/Forte"
88

99
[workspace]
1010
resolver = "2"
11-
members = ["ci"]
12-
exclude = ["coz"]
11+
members = ["ci", "rayon-compat"]
1312

1413
[dependencies]
1514
async-task = "4.7.1"

rayon-compat/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "rayon-compat"
3+
version = "1.12.1"
4+
edition = "2024"
5+
6+
[dependencies]
7+
forte = { path = ".." }
8+
9+
[features]
10+
web_spin_lock = []

rayon-compat/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Rayon Compat
2+
3+
This is a way to run `rayon` on top of `forte`! The `rayon-compat` crate mocks the important bits of the api of `rayon_core` in a pretty simple and crude way, which is none-the-less enough to support most of what `rayon` needs.
4+
5+
To use this crate, apply the following cargo patch like one of these:
6+
```
7+
// If you want to clone forte and use it locally
8+
[patch.crates-io]
9+
rayon-core = { path = "path to this repo", package = "rayon-compat" }
10+
11+
// If you want to use the latest published version of forte
12+
[patch.crates-io]
13+
rayon-core = { path = "https://github.com/NthTensor/Forte", package = "rayon-compat" }
14+
```

rayon-compat/src/lib.rs

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
use std::{
2+
pin::Pin,
3+
sync::atomic::{AtomicBool, Ordering},
4+
};
5+
6+
pub static THREAD_POOL: forte::ThreadPool = const { forte::ThreadPool::new() };
7+
8+
pub static STARTED: AtomicBool = const { AtomicBool::new(false) };
9+
10+
#[inline(always)]
11+
fn ensure_started() {
12+
if !STARTED.load(Ordering::Relaxed) {
13+
if !STARTED.swap(true, Ordering::Relaxed) {
14+
THREAD_POOL.resize_to_available();
15+
}
16+
}
17+
}
18+
19+
// -----------------------------------------------------------------------------
20+
// Join
21+
22+
#[derive(Debug)]
23+
pub struct FnContext {
24+
/// True if the task was migrated.
25+
migrated: bool,
26+
}
27+
28+
impl FnContext {
29+
#[inline(always)]
30+
pub fn migrated(&self) -> bool {
31+
self.migrated
32+
}
33+
}
34+
35+
#[inline(always)]
36+
pub fn join_context<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
37+
where
38+
A: FnOnce(FnContext) -> RA + Send,
39+
B: FnOnce(FnContext) -> RB + Send,
40+
RA: Send,
41+
RB: Send,
42+
{
43+
ensure_started();
44+
THREAD_POOL.join(
45+
|worker| {
46+
let ctx = FnContext {
47+
migrated: worker.migrated(),
48+
};
49+
oper_a(ctx)
50+
},
51+
|worker| {
52+
let ctx = FnContext {
53+
migrated: worker.migrated(),
54+
};
55+
oper_b(ctx)
56+
},
57+
)
58+
}
59+
60+
#[inline(always)]
61+
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
62+
where
63+
A: FnOnce() -> RA + Send,
64+
B: FnOnce() -> RB + Send,
65+
RA: Send,
66+
RB: Send,
67+
{
68+
ensure_started();
69+
THREAD_POOL.join(|_| oper_a(), |_| oper_b())
70+
}
71+
72+
#[inline(always)]
73+
pub fn current_num_threads() -> usize {
74+
64 // Forte prefers smaller tasks, so it's better to lie to rayon about the size of the pool
75+
}
76+
77+
#[inline(always)]
78+
pub fn current_thread_index() -> Option<usize> {
79+
forte::Worker::map_current(|worker| worker.index())
80+
}
81+
82+
// -----------------------------------------------------------------------------
83+
// Scope
84+
85+
pub struct Scope<'r, 'scope: 'r> {
86+
inner: Pin<&'r forte::Scope<'scope>>,
87+
}
88+
89+
impl<'scope> Scope<'_, 'scope> {
90+
#[inline(always)]
91+
pub fn spawn<BODY>(&self, body: BODY)
92+
where
93+
BODY: FnOnce(&Scope) + Send + 'scope,
94+
{
95+
self.inner.spawn(|inner| {
96+
let scope = Scope { inner };
97+
body(&scope)
98+
});
99+
}
100+
}
101+
102+
#[inline(always)]
103+
pub fn scope<'scope, OP, R>(op: OP) -> R
104+
where
105+
OP: FnOnce(&Scope<'_, 'scope>) -> R + Send,
106+
R: Send,
107+
{
108+
THREAD_POOL.scope(|inner| {
109+
let scope = Scope { inner };
110+
op(&scope)
111+
})
112+
}
113+
114+
#[inline(always)]
115+
pub fn in_place_scope<'scope, OP, R>(op: OP) -> R
116+
where
117+
OP: FnOnce(&Scope<'_, 'scope>) -> R,
118+
{
119+
THREAD_POOL.scope(|inner| {
120+
let scope = Scope { inner };
121+
op(&scope)
122+
})
123+
}
124+
125+
// -----------------------------------------------------------------------------
126+
// Spawn
127+
128+
#[inline(always)]
129+
pub fn spawn<F>(func: F)
130+
where
131+
F: FnOnce() + Send + 'static,
132+
{
133+
THREAD_POOL.spawn(|_| func())
134+
}
135+
136+
// -----------------------------------------------------------------------------
137+
// Fake stuff that dosn't work
138+
139+
pub struct ThreadBuilder;
140+
141+
pub struct ThreadPool;
142+
143+
pub struct ThreadPoolBuildError;
144+
145+
pub struct ThreadPoolBuilder;
146+
147+
pub struct BroadcastContext;
148+
149+
pub struct ScopeFifo;
150+
151+
pub struct Yield;
152+
153+
pub fn broadcast() {
154+
unimplemented!()
155+
}
156+
157+
pub fn spawn_broadcast() {
158+
unimplemented!()
159+
}
160+
161+
pub fn max_num_threads() {
162+
unimplemented!()
163+
}
164+
165+
pub fn scope_fifo() {
166+
unimplemented!()
167+
}
168+
169+
pub fn in_place_scope_fifo() {
170+
unimplemented!()
171+
}
172+
173+
pub fn spawn_fifo() {
174+
unimplemented!()
175+
}
176+
177+
pub fn yield_local() {
178+
unimplemented!()
179+
}
180+
181+
pub fn yield_now() {
182+
unimplemented!()
183+
}

src/job.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ use alloc::collections::VecDeque;
1616
use core::cell::UnsafeCell;
1717
use core::mem::ManuallyDrop;
1818
use core::ptr::NonNull;
19+
use std::thread::Result as ThreadResult;
1920

2021
use crate::signal::Signal;
2122
use crate::thread_pool::Worker;
23+
use crate::unwind;
2224

2325
// -----------------------------------------------------------------------------
2426
// Runnable
@@ -152,7 +154,7 @@ impl JobQueue {
152154
/// This is analogous to the chili type `JobStack` and the rayon type `StackJob`.
153155
pub struct StackJob<F, T> {
154156
f: UnsafeCell<ManuallyDrop<F>>,
155-
signal: Signal<T>,
157+
signal: Signal<ThreadResult<T>>,
156158
}
157159

158160
impl<F, T> StackJob<F, T>
@@ -210,7 +212,7 @@ where
210212
/// closure's return value is sent over this signal after the job is
211213
/// executed.
212214
#[inline(always)]
213-
pub fn signal(&self) -> &Signal<T> {
215+
pub fn signal(&self) -> &Signal<ThreadResult<T>> {
214216
&self.signal
215217
}
216218
}
@@ -235,20 +237,25 @@ where
235237
// SAFETY: The caller ensures `this` can be converted into an immutable
236238
// reference.
237239
let this = unsafe { this.cast::<Self>().as_ref() };
240+
// Create an abort guard. If the closure panics, this will convert the
241+
// panic into an abort. Doing so prevents use-after-free for other elements of the stack.
242+
let abort_guard = unwind::AbortOnDrop;
238243
// SAFETY: This memory location is accessed only in this function and in
239244
// `unwrap`. The latter cannot have been called, because it drops the
240245
// stack job, so, since this function is called only once, we can
241246
// guarantee that we have exclusive access.
242247
let f_ref = unsafe { &mut *this.f.get() };
243248
// SAFETY: The caller ensures this function is called only once.
244249
let f = unsafe { ManuallyDrop::take(f_ref) };
245-
// Run the job.
246-
let result = f(worker);
250+
// Run the job. If the job panics, we propagate the panic back to the main thread.
251+
let result = unwind::halt_unwinding(|| f(worker));
247252
// SAFETY: This is valid for the access used by `send` because
248253
// `&this.signal` is an immutable reference to a `Signal`. Because
249254
// `send` is only called in this function, and this function is never
250255
// called again, `send` is never called again.
251256
unsafe { Signal::send(&this.signal, result) }
257+
// Forget the abort guard, re-enabling panics.
258+
core::mem::forget(abort_guard);
252259
}
253260
}
254261

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ mod job;
4040
mod scope;
4141
mod signal;
4242
mod thread_pool;
43+
mod unwind;
4344

4445
// -----------------------------------------------------------------------------
4546
// Top-level exports

0 commit comments

Comments
 (0)