Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 64 additions & 53 deletions core/src/co_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ mod state;
/// Creator for coroutine pool.
mod creator;

/// `task_name` -> `co_name`
static RUNNING_TASKS: Lazy<DashMap<&str, &str>> = Lazy::new(DashMap::new);
/// `task_id` -> `co_id`
static RUNNING_TASKS: Lazy<DashMap<u64, u64>> = Lazy::new(DashMap::new);

static CANCEL_TASKS: Lazy<DashSet<&str>> = Lazy::new(DashSet::new);
static CANCEL_TASKS: Lazy<DashSet<u64>> = Lazy::new(DashSet::new);

/// The coroutine pool impls.
#[repr(C)]
Expand All @@ -55,10 +55,10 @@ pub struct CoroutinePool<'p> {
//阻滞器
blocker: Arc<CondvarBlocker>,
//正在等待结果的
waits: DashMap<&'p str, Arc<(Mutex<bool>, Condvar)>>,
waits: DashMap<u64, Arc<(Mutex<bool>, Condvar)>>,
//任务执行结果
results: DashMap<String, Result<Option<usize>, &'p str>>,
no_waits: DashSet<&'p str>,
results: DashMap<u64, Result<Option<usize>, &'p str>>,
no_waits: DashSet<u64>,
}

impl Drop for CoroutinePool<'_> {
Expand Down Expand Up @@ -188,7 +188,7 @@ impl<'p> CoroutinePool<'p> {

/// Returns `true` if the task queue is empty.
pub fn is_empty(&self) -> bool {
self.size() == 0
self.task_queue.is_empty()
}

/// Returns the number of tasks owned by this pool.
Expand All @@ -210,7 +210,14 @@ impl<'p> CoroutinePool<'p> {
}

fn do_stop(&mut self, dur: Duration) -> std::io::Result<()> {
_ = self.try_timed_schedule_task(dur)?;
let timeout_time = get_timeout_time(dur);
loop {
_ = self.try_timeout_schedule_task(timeout_time)?;
if self.get_running_size() == 0 || timeout_time.saturating_sub(now()) == 0 {
break;
}
std::thread::sleep(Duration::from_millis(1));
}
assert_eq!(PoolState::Stopping, self.stopped()?);
self.do_clean();
Ok(())
Expand All @@ -219,11 +226,11 @@ impl<'p> CoroutinePool<'p> {
fn do_clean(&mut self) {
// clean up remaining wait tasks
for r in &self.waits {
let task_name = *r.key();
let task_id = *r.key();
_ = self
.results
.insert(task_name.to_string(), Err("The coroutine pool has stopped"));
self.notify(task_name);
.insert(task_id, Err("The coroutine pool has stopped"));
self.notify(task_id);
}
}

Expand All @@ -237,16 +244,22 @@ impl<'p> CoroutinePool<'p> {
func: impl FnOnce(Option<usize>) -> Option<usize> + 'p,
param: Option<usize>,
priority: Option<c_longlong>,
) -> std::io::Result<String> {
) -> std::io::Result<u64> {
match self.state() {
PoolState::Running => {}
PoolState::Stopping | PoolState::Stopped => {
return Err(Error::other("The coroutine pool is stopping or stopped !"))
}
}
let name = name.unwrap_or(format!("{}@{}", self.name(), uuid::Uuid::new_v4()));
self.submit_raw_task(Task::new(name.clone(), func, param, priority));
Ok(name)
let task = Task::new(
name.unwrap_or(format!("{}@{}", self.name(), uuid::Uuid::new_v4())),
func,
param,
priority,
);
let task_id = task.id();
self.submit_raw_task(task);
Ok(task_id)
}

/// Submit new task to this pool.
Expand All @@ -258,52 +271,51 @@ impl<'p> CoroutinePool<'p> {
self.blocker.notify();
}

/// Attempt to obtain task results with the given `task_name`.
pub fn try_take_task_result(&self, task_name: &str) -> Option<Result<Option<usize>, &'p str>> {
self.results.remove(task_name).map(|(_, r)| r)
/// Attempt to obtain task results with the given `task_id`.
pub fn try_take_task_result(&self, task_id: u64) -> Option<Result<Option<usize>, &'p str>> {
self.results.remove(&task_id).map(|(_, r)| r)
}

/// clean the task result data.
pub fn clean_task_result(&self, task_name: &str) {
if self.try_take_task_result(task_name).is_some() {
pub fn clean_task_result(&self, task_id: u64) {
if self.try_take_task_result(task_id).is_some() {
return;
}
_ = self.no_waits.insert(Box::leak(Box::from(task_name)));
_ = CANCEL_TASKS.remove(task_name);
_ = self.no_waits.insert(task_id);
_ = CANCEL_TASKS.remove(&task_id);
}

/// Use the given `task_name` to obtain task results, and if no results are found,
/// Use the given `task_id` to obtain task results, and if no results are found,
/// block the current thread for `wait_time`.
///
/// # Errors
/// if timeout
pub fn wait_task_result(
&self,
task_name: &str,
task_id: u64,
wait_time: Duration,
) -> std::io::Result<Result<Option<usize>, &str>> {
let key = Box::leak(Box::from(task_name));
if let Some(r) = self.try_take_task_result(key) {
self.notify(key);
if let Some(r) = self.try_take_task_result(task_id) {
self.notify(task_id);
return Ok(r);
}
if SchedulableCoroutine::current().is_some() {
let timeout_time = get_timeout_time(wait_time);
loop {
_ = self.try_run();
if let Some(r) = self.try_take_task_result(key) {
if let Some(r) = self.try_take_task_result(task_id) {
return Ok(r);
}
if timeout_time.saturating_sub(now()) == 0 {
return Err(Error::new(ErrorKind::TimedOut, "wait timeout"));
}
}
}
let arc = if let Some(arc) = self.waits.get(key) {
let arc = if let Some(arc) = self.waits.get(&task_id) {
arc.clone()
} else {
let arc = Arc::new((Mutex::new(true), Condvar::new()));
assert!(self.waits.insert(key, arc.clone()).is_none());
assert!(self.waits.insert(task_id, arc.clone()).is_none());
arc
};
let (lock, cvar) = &*arc;
Expand All @@ -315,8 +327,8 @@ impl<'p> CoroutinePool<'p> {
)
.map_err(|e| Error::other(format!("{e}")))?,
);
if let Some(r) = self.try_take_task_result(key) {
self.notify(key);
if let Some(r) = self.try_take_task_result(task_id) {
self.notify(task_id);
return Ok(r);
}
Err(Error::new(ErrorKind::TimedOut, "wait timeout"))
Expand Down Expand Up @@ -402,32 +414,31 @@ impl<'p> CoroutinePool<'p> {

fn try_run(&self) -> Option<()> {
self.task_queue.pop().map(|task| {
let tname = task.get_name().to_string().leak();
if CANCEL_TASKS.contains(tname) {
_ = CANCEL_TASKS.remove(tname);
warn!("Cancel task:{} successfully !", tname);
let task_id = task.id();
if CANCEL_TASKS.contains(&task_id) {
_ = CANCEL_TASKS.remove(&task_id);
warn!("Cancel task:{} successfully !", task_id);
return;
}
if let Some(co) = SchedulableCoroutine::current() {
_ = RUNNING_TASKS.insert(tname, co.name());
_ = RUNNING_TASKS.insert(task_id, co.id);
}
let (task_name, result) = task.run();
_ = RUNNING_TASKS.remove(tname);
let n = task_name.clone().leak();
if self.no_waits.contains(n) {
_ = self.no_waits.remove(n);
let (_, result) = task.run();
_ = RUNNING_TASKS.remove(&task_id);
Comment thread
loongs-zhang marked this conversation as resolved.
if self.no_waits.contains(&task_id) {
_ = self.no_waits.remove(&task_id);
return;
}
assert!(
self.results.insert(task_name.clone(), result).is_none(),
self.results.insert(task_id, result).is_none(),
"The previous result was not retrieved in a timely manner"
);
self.notify(&task_name);
self.notify(task_id);
})
}

fn notify(&self, task_name: &str) {
if let Some((_, arc)) = self.waits.remove(task_name) {
fn notify(&self, task_id: u64) {
if let Some((_, arc)) = self.waits.remove(&task_id) {
let (lock, cvar) = &*arc;
let mut pending = lock.lock().expect("notify task failed");
*pending = false;
Expand All @@ -436,9 +447,9 @@ impl<'p> CoroutinePool<'p> {
}

/// Try to cancel a task.
pub fn try_cancel_task(task_name: &str) {
pub fn try_cancel_task(task_id: u64) {
// 检查正在运行的任务是否是要取消的任务
if let Some(info) = RUNNING_TASKS.get(task_name) {
if let Some(info) = RUNNING_TASKS.get(&task_id) {
let co_name = *info;
// todo windows support
#[allow(unused_variables)]
Expand All @@ -450,26 +461,26 @@ impl<'p> CoroutinePool<'p> {
{
warn!(
"Attempt to cancel task:{} running on coroutine:{} by thread:{}, cancelling...",
task_name, co_name, pthread
task_id, co_name, pthread
);
} else {
error!(
"Attempt to cancel task:{} running on coroutine:{} by thread:{} failed !",
task_name, co_name, pthread
task_id, co_name, pthread
);
}
} else {
// 添加到待取消队列
Scheduler::try_cancel_coroutine(co_name);
warn!(
"Attempt to cancel task:{} running on coroutine:{}, cancelling...",
task_name, co_name
task_id, co_name
);
}
} else {
// 添加到待取消队列
_ = CANCEL_TASKS.insert(Box::leak(Box::from(task_name)));
warn!("Attempt to cancel task:{}, cancelling...", task_name);
_ = CANCEL_TASKS.insert(task_id);
warn!("Attempt to cancel task:{}, cancelling...", task_id);
}
}

Expand Down
14 changes: 13 additions & 1 deletion core/src/co_pool/task.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::catch;
use crate::common::ordered_work_steal::Ordered;
use std::ffi::c_longlong;
use std::hash::{DefaultHasher, Hash, Hasher};

/// 做C兼容时会用到
pub type UserTaskFunc = extern "C" fn(usize) -> usize;
Expand All @@ -10,6 +11,7 @@ pub type UserTaskFunc = extern "C" fn(usize) -> usize;
#[derive(educe::Educe)]
#[educe(Debug)]
pub struct Task<'t> {
id: u64,
name: String,
#[educe(Debug(ignore))]
func: Box<dyn FnOnce(Option<usize>) -> Option<usize> + 't>,
Expand All @@ -25,7 +27,11 @@ impl<'t> Task<'t> {
param: Option<usize>,
priority: Option<c_longlong>,
) -> Self {
let mut hasher = DefaultHasher::new();
name.hash(&mut hasher);
let id = hasher.finish();
Task {
id,
name,
Comment on lines +30 to 35
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Task IDs are derived from DefaultHasher(name), which is not collision-free. Because task_id is now used as the unique key in global maps (RUNNING_TASKS, CANCEL_TASKS, results, waits), a collision would cause incorrect cancellation/result delivery. Prefer generating IDs from a monotonic counter/UUID (or at least detect and resolve collisions when inserting into maps).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

func: Box::new(func),
param,
Expand All @@ -35,10 +41,16 @@ impl<'t> Task<'t> {

/// get the task name.
#[must_use]
pub fn get_name(&self) -> &str {
pub fn name(&self) -> &str {
&self.name
}

/// Returns the numeric id of this task.
///
/// The id is computed once in `Task::new` by hashing the task name with
/// `DefaultHasher`, so tasks created with the same name share the same id
/// and distinct names are not guaranteed to produce distinct ids.
#[must_use]
pub fn id(&self) -> u64 {
self.id
}

/// execute the task
///
/// # Errors
Expand Down
2 changes: 1 addition & 1 deletion core/src/common/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub const DEFAULT_STACK_SIZE: usize = 128 * 1024;

/// A user data used to indicate the timeout of `io_uring_enter`.
#[cfg(all(target_os = "linux", feature = "io_uring"))]
pub const IO_URING_TIMEOUT_USERDATA: usize = usize::MAX - 1;
pub const IO_URING_TIMEOUT_USERDATA: u64 = u64::MAX - 1;

/// Coroutine global queue bean name.
pub const COROUTINE_GLOBAL_QUEUE_BEAN: &str = "coroutineGlobalQueueBean";
Expand Down
6 changes: 6 additions & 0 deletions core/src/coroutine/korosensei.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::cell::{Cell, RefCell, UnsafeCell};
use std::collections::VecDeque;
use std::ffi::c_longlong;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::io::Error;

cfg_if::cfg_if! {
Expand All @@ -25,6 +26,7 @@ cfg_if::cfg_if! {
/// Use `corosensei` as the low-level coroutine.
#[repr(C)]
pub struct Coroutine<'c, Param, Yield, Return> {
pub(crate) id: u64,
pub(crate) name: String,
inner: corosensei::Coroutine<Param, Yield, Result<Return, &'static str>, DefaultStack>,
pub(crate) state: Cell<CoroutineState<Yield, Return>>,
Expand Down Expand Up @@ -427,8 +429,12 @@ where
co_name
)
});
let mut hasher = DefaultHasher::new();
name.hash(&mut hasher);
let id = hasher.finish();
#[allow(unused_mut)]
let mut co = Coroutine {
id,
name,
Comment on lines +432 to 438
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coroutine IDs are derived from DefaultHasher(name). Since co_id is used as the unique key for cancellation, syscall tables, and resumption, any hash collision would conflate two coroutines. Prefer a collision-free ID source (e.g., atomic counter or UUID bytes) instead of a hash, or add collision detection when inserting into syscall/RUNNING_COROUTINES/CANCEL_COROUTINES.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

inner,
stack_infos,
Expand Down
6 changes: 6 additions & 0 deletions core/src/coroutine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ impl<'c, Param, Yield, Return> Coroutine<'c, Param, Yield, Return> {
&self.name
}

/// Returns the numeric id of this coroutine.
///
/// The value is read straight from the `id` field. In the corosensei-backed
/// constructor that field is filled by hashing the coroutine name with
/// `DefaultHasher`, so it identifies the coroutine by name and is not
/// guaranteed to be collision-free across different names.
pub fn id(&self) -> u64 {
    // No numeric cast happens here (`u64` field, `u64` return), so no
    // truncation lint needs to be silenced.
    self.id
}

/// Returns the current state of this `StateCoroutine`.
pub fn state(&self) -> CoroutineState<Yield, Return>
where
Expand Down
Loading
Loading