diff --git a/libgo/pool/async_coroutine_pool.cpp b/libgo/pool/async_coroutine_pool.cpp index 04ea4c4e..6d9fc5ab 100644 --- a/libgo/pool/async_coroutine_pool.cpp +++ b/libgo/pool/async_coroutine_pool.cpp @@ -7,9 +7,10 @@ AsyncCoroutinePool * AsyncCoroutinePool::Create(size_t maxCallbackPoints) { return new AsyncCoroutinePool(maxCallbackPoints); } -void AsyncCoroutinePool::InitCoroutinePool(size_t maxCoroutineCount) +void AsyncCoroutinePool::InitCoroutinePool(size_t maxCoroutineCount, size_t stackSize) { maxCoroutineCount_ = maxCoroutineCount; + stackSize_ = stackSize; } void AsyncCoroutinePool::Start(int minThreadNumber, int maxThreadNumber) { @@ -21,11 +22,21 @@ void AsyncCoroutinePool::Start(int minThreadNumber, int maxThreadNumber) maxCoroutineCount_ = (std::max)(minThreadNumber * 128, maxThreadNumber); maxCoroutineCount_ = (std::min)(maxCoroutineCount_, 10240); } - for (size_t i = 0; i < maxCoroutineCount_; ++i) { - go co_scheduler(scheduler_) [this]{ - this->Go(); - }; + + if(stackSize_ > 0) { + for (size_t i = 0; i < maxCoroutineCount_; ++i) { + go_stack(stackSize_) co_scheduler(scheduler_) [this]{ + this->Go(); + }; + } + } else { + for (size_t i = 0; i < maxCoroutineCount_; ++i) { + go co_scheduler(scheduler_) [this]{ + this->Go(); + }; + } } + } void AsyncCoroutinePool::Go() { @@ -33,21 +44,28 @@ void AsyncCoroutinePool::Go() PoolTask task; tasks_ >> task; + taskRunningPoints++; + if (task.func_) task.func_(); - if (!task.cb_) + if (!task.cb_) { + taskRunningPoints--; continue; + } size_t pointsCount = pointsCount_; if (!pointsCount) { task.cb_(); + taskRunningPoints--; continue; } size_t idx = ++robin_ % pointsCount; points_[idx]->Post(std::move(task.cb_)); points_[idx]->Notify(); + + taskRunningPoints--; } } void AsyncCoroutinePool::Post(Func const& func, Func const& callback) @@ -55,6 +73,13 @@ void AsyncCoroutinePool::Post(Func const& func, Func const& callback) PoolTask task{func, callback}; tasks_ << std::move(task); } + +void AsyncCoroutinePool::Post(Func const& func) +{ + PoolTask task{func, NULL}; + tasks_ << std::move(task); +} + bool AsyncCoroutinePool::AddCallbackPoint(AsyncCoroutinePool::CallbackPoint * point) { size_t writeIdx = writePointsCount_++; @@ -76,6 +101,11 @@ AsyncCoroutinePool::AsyncCoroutinePool(size_t maxCallbackPoints) points_ = new CallbackPoint*[maxCallbackPoints_]; } +void AsyncCoroutinePool::WaitStop() +{ + while (!tasks_.empty() || taskRunningPoints.load() != 0); +} + size_t AsyncCoroutinePool::CallbackPoint::Run(size_t maxTrigger) { size_t i = 0; diff --git a/libgo/pool/async_coroutine_pool.h b/libgo/pool/async_coroutine_pool.h index e6321ae4..7226836e 100644 --- a/libgo/pool/async_coroutine_pool.h +++ b/libgo/pool/async_coroutine_pool.h @@ -16,13 +16,17 @@ class AsyncCoroutinePool typedef std::function Func; // 初始化协程数量 - void InitCoroutinePool(size_t maxCoroutineCount); + void InitCoroutinePool(size_t maxCoroutineCount, size_t stackSize = 0); // 启动协程池 - void Start(int minThreadNumber, int maxThreadNumber = 0); + void Start(int minThreadNumber, int maxThreadNumber); void Post(Func const& func, Func const& callback); + void Post(Func const& func); + + void WaitStop(); + template void Post(Channel const& ret, std::function const& func) { Post([=]{ ret << func(); }, NULL); @@ -75,11 +79,13 @@ class AsyncCoroutinePool private: size_t maxCoroutineCount_; + size_t stackSize_; std::atomic coroutineCount_{0}; Scheduler* scheduler_; Channel tasks_; std::atomic pointsCount_{0}; std::atomic writePointsCount_{0}; + std::atomic taskRunningPoints{0}; size_t maxCallbackPoints_; std::atomic robin_{0}; CallbackPoint ** points_; diff --git a/libgo/routine_sync/rutex.h b/libgo/routine_sync/rutex.h index 11aa9ee4..2ed1ec1b 100644 --- a/libgo/routine_sync/rutex.h +++ b/libgo/routine_sync/rutex.h @@ -19,7 +19,7 @@ struct IntValue { public: inline std::atomic* value() { return ptr_; } - inline void ref(std::atomic* ptr) { ptr_ = ptr; } + inline void ref(std::atomic* ptr) { ptr_ = {ptr}; } protected: std::atomic* ptr_ {nullptr}; @@ -32,7 +32,7 @@ struct IntValue inline std::atomic* value() { return &value_; } protected: - std::atomic value_ {0}; + std::atomic value_ = {0}; }; struct RutexBase