|
20 | 20 | #include "absl/functional/any_invocable.h" |
21 | 21 | #include "absl/log/check.h" |
22 | 22 | #include "absl/status/status.h" |
| 23 | +#include "absl/strings/str_join.h" |
23 | 24 | #include "absl/synchronization/blocking_counter.h" |
24 | 25 | #include "absl/synchronization/mutex.h" |
| 26 | +#include "absl/time/clock.h" |
| 27 | +#include "absl/time/time.h" |
25 | 28 | #include "status/status_macros.h" |
| 29 | +#include "vmsdk/src/log.h" |
26 | 30 | #include "vmsdk/src/module_config.h" |
27 | 31 |
|
28 | 32 | namespace { |
@@ -135,16 +139,44 @@ void ThreadPool::JoinWorkers() { |
135 | 139 | suspend_workers_ = false; |
136 | 140 | } |
137 | 141 |
|
138 | | - threads_.ClearWithCallback( |
139 | | - [](auto thread) { pthread_join(thread->thread_id, nullptr); }); |
| 142 | + // Wait up to 5s for every worker in threads_ to flag itself joinable. |
| 143 | + const absl::Time deadline = absl::Now() + absl::Seconds(5); |
| 144 | + auto count_unjoinable = [this]() { |
| 145 | + return threads_.CountIf( |
| 146 | + [](const std::shared_ptr<Thread> &t) { return !t->IsJoinable(); }); |
| 147 | + }; |
| 148 | + while (count_unjoinable() > 0 && absl::Now() < deadline) { |
| 149 | + absl::SleepFor(absl::Milliseconds(10)); |
| 150 | + } |
| 151 | + if (count_unjoinable() > 0) { |
| 152 | + std::vector<std::string> hung; |
| 153 | + threads_.ForEach([&hung](const std::shared_ptr<Thread> &t) { |
| 154 | + if (!t->IsJoinable()) { |
| 155 | + hung.push_back(t->name); |
| 156 | + } |
| 157 | + }); |
| 158 | + VMSDK_LOG(WARNING, nullptr) |
| 159 | + << "ThreadPool shutdown timed out after 5s waiting for workers to exit;" |
| 160 | + << " hung threads: " << absl::StrJoin(hung, ", "); |
| 161 | + CHECK(false) << "ThreadPool shutdown timeout: " << hung.size() |
| 162 | + << " worker(s) did not become joinable within 5s"; |
| 163 | + } |
140 | 164 | started_ = false; |
141 | | - |
142 | 165 | JoinTerminatedWorkers(); |
| 166 | + CHECK(threads_.IsEmpty()) |
| 167 | + << "threads_ not empty after JoinTerminatedWorkers in JoinWorkers"; |
143 | 168 | } |
144 | 169 |
|
145 | 170 | void ThreadPool::JoinTerminatedWorkers() { |
146 | | - pending_join_threads_.ClearWithCallback( |
147 | | - [](auto thread) { pthread_join(thread->thread_id, nullptr); }); |
| 171 | + while (true) { |
| 172 | + auto thread = threads_.PopIf( |
| 173 | + [](const std::shared_ptr<Thread> &t) { return t->IsJoinable(); }); |
| 174 | + if (!thread.has_value()) { |
| 175 | + break; |
| 176 | + } |
| 177 | + // pthread_join intentionally runs with no ThreadSafeVector mutex held. |
| 178 | + pthread_join((*thread)->thread_id, nullptr); |
| 179 | + } |
148 | 180 | } |
149 | 181 |
|
150 | 182 | absl::Status ThreadPool::SuspendWorkers() { |
@@ -218,20 +250,15 @@ void ThreadPool::WorkerThread(std::shared_ptr<Thread> thread) { |
218 | 250 | while (!condition.Eval()) { |
219 | 251 | condition_.WaitWithTimeout(&queue_mutex_, absl::Seconds(1)); |
220 | 252 | if (thread->IsShutdown()) { |
221 | | - thread->InvokeShutdownCallback(); |
222 | | - // remove this thread from the threads list and place it in the |
223 | | - // pending join list |
224 | | - threads_.PopIf([thread](std::shared_ptr<Thread> t) { |
225 | | - return t->thread_id == thread->thread_id; |
226 | | - }); |
227 | | - pending_join_threads_.Add(thread); |
| 253 | + thread->MarkJoinable(); |
228 | 254 | return; |
229 | 255 | } |
230 | 256 | } |
231 | 257 | if (stop_mode_.has_value() && |
232 | 258 | (stop_mode_.value() == StopMode::kAbrupt || |
233 | 259 | std::all_of(priority_tasks_.begin(), priority_tasks_.end(), |
234 | 260 | [](const auto &tasks) { return tasks.empty(); }))) { |
| 261 | + thread->MarkJoinable(); |
235 | 262 | return; |
236 | 263 | } |
237 | 264 | if (suspend_workers_) { |
@@ -261,33 +288,52 @@ void ThreadPool::IncrThreadCountBy(size_t count) { |
261 | 288 | std::shared_ptr<Thread> thread_ptr = std::make_shared<Thread>(); |
262 | 289 | ThreadRunContext *context = new ThreadRunContext{this, thread_ptr}; |
263 | 290 | pthread_create(&thread_ptr->thread_id, nullptr, RunWorkerThread, context); |
264 | | -#ifndef __APPLE__ |
265 | 291 | size_t thread_num = threads_.Size(); |
266 | | - pthread_setname_np(thread_ptr->thread_id, |
267 | | - (name_prefix_ + std::to_string(thread_num)).c_str()); |
| 292 | + thread_ptr->name = name_prefix_ + std::to_string(thread_num); |
| 293 | +#ifndef __APPLE__ |
| 294 | + pthread_setname_np(thread_ptr->thread_id, thread_ptr->name.c_str()); |
268 | 295 | #endif |
269 | 296 | threads_.Add(thread_ptr); |
270 | 297 | } |
271 | 298 | } |
272 | 299 |
|
273 | 300 | void ThreadPool::DecrThreadCountBy(size_t count, bool sync) { |
274 | | - auto threads = threads_.PopBackMulti(count); |
275 | | - absl::BlockingCounter counter{static_cast<int>(threads.size())}; |
276 | | - for (const auto &thread : threads) { |
277 | | - if (sync) { |
278 | | - thread->Shutdown([&counter]() { |
279 | | - counter.DecrementCount(); |
280 | | - }); // signal the thread to exit |
281 | | - } else { |
282 | | - thread->Shutdown(); |
| 301 | + // Don't pop: leave the targeted threads in threads_ so JoinTerminatedWorkers |
| 302 | + // can reap them via the joinable_flag once they actually exit. We pick the |
| 303 | + // *last* `count` threads that are still active (not already shutdown). |
| 304 | + std::vector<std::shared_ptr<Thread>> targets; |
| 305 | + threads_.ForEach([&targets](const std::shared_ptr<Thread> &t) { |
| 306 | + if (!t->IsShutdown()) { |
| 307 | + targets.push_back(t); |
283 | 308 | } |
| 309 | + }); |
| 310 | + if (targets.size() > count) { |
| 311 | + targets.erase(targets.begin(), targets.end() - count); |
| 312 | + } |
| 313 | + for (const auto &thread : targets) { |
| 314 | + thread->Shutdown(); |
| 315 | + } |
| 316 | + // Wake idle workers so they observe shutdown_flag without waiting 1s. |
| 317 | + { |
| 318 | + absl::MutexLock lock(&queue_mutex_); |
| 319 | + condition_.SignalAll(); |
284 | 320 | } |
285 | | - |
286 | 321 | if (sync) { |
287 | | - counter.Wait(); |
| 322 | + for (const auto &thread : targets) { |
| 323 | + while (!thread->IsJoinable()) { |
| 324 | + absl::SleepFor(absl::Milliseconds(1)); |
| 325 | + } |
| 326 | + } |
| 327 | + JoinTerminatedWorkers(); |
288 | 328 | } |
289 | 329 | } |
290 | 330 |
|
| 331 | +size_t ThreadPool::Size() const { |
| 332 | + return threads_.CountIf([](const std::shared_ptr<Thread> &t) { |
| 333 | + return !t->IsShutdown() && !t->IsJoinable(); |
| 334 | + }); |
| 335 | +} |
| 336 | + |
291 | 337 | void ThreadPool::Resize(size_t count, bool wait_for_resize) { |
292 | 338 | size_t current_size = Size(); |
293 | 339 | if (count == current_size) { |
|
0 commit comments