@@ -26,8 +26,8 @@ namespace math {
2626 * - tid=0 runs on the caller thread
2727 * - tid=1..n-1 run on persistent worker threads
2828 *
29- * This is optimized for repeated short parallel regions (like reduce_sum/map_rect),
30- * avoiding task-queue overhead.
29+ * This is optimized for repeated short parallel regions (like
30+ * reduce_sum/map_rect), avoiding task-queue overhead.
3131 *
3232 * This version uses a single atomic wake generation counter (wake_gen_) and
3333 * removes the older "epoch" concept entirely.
@@ -42,9 +42,11 @@ class TeamThreadPool {
4242 return pool;
4343 }
4444
45- /* * Set total participants INCLUDING caller (tid=0). Call before instance(). */
45+ /* * Set total participants INCLUDING caller (tid=0). Call before instance().
46+ */
4647 static void set_num_threads (int n) {
47- if (n < 1 ) n = 1 ;
48+ if (n < 1 )
49+ n = 1 ;
4850 configured_threads_.store (n, std::memory_order_release);
4951 }
5052
@@ -54,26 +56,24 @@ class TeamThreadPool {
5456 }
5557
5658 /* * Total participants INCLUDING caller (tid=0). */
57- std::size_t team_size () const noexcept {
58- return workers_.size () + 1 ;
59- }
59+ std::size_t team_size () const noexcept { return workers_.size () + 1 ; }
6060
6161 /* * Number of worker threads (excludes caller). */
62- std::size_t worker_count () const noexcept {
63- return workers_.size ();
64- }
62+ std::size_t worker_count () const noexcept { return workers_.size (); }
6563
6664 template <typename F>
6765 void parallel_region (std::size_t n, F&& fn) {
68- if (n == 0 ) return ;
66+ if (n == 0 )
67+ return ;
6968
7069 // Clamp to actual team size
7170 const std::size_t max_team = team_size ();
7271 if (max_team <= 1 ) {
7372 fn (std::size_t {0 });
7473 return ;
7574 }
76- if (n > max_team) n = max_team;
75+ if (n > max_team)
76+ n = max_team;
7777 if (n <= 1 ) {
7878 fn (std::size_t {0 });
7979 return ;
@@ -87,10 +87,12 @@ class TeamThreadPool {
8787 try {
8888 fn (tid);
8989 } catch (...) {
90- if (!ep) ep = std::current_exception ();
90+ if (!ep)
91+ ep = std::current_exception ();
9192 }
9293 }
93- if (ep) std::rethrow_exception (ep);
94+ if (ep)
95+ std::rethrow_exception (ep);
9496 return ;
9597 }
9698
@@ -126,16 +128,16 @@ class TeamThreadPool {
126128 fn_copy (0 );
127129 } catch (...) {
128130 std::lock_guard<std::mutex> lk (exc_m_);
129- if (!eptr) eptr = std::current_exception ();
131+ if (!eptr)
132+ eptr = std::current_exception ();
130133 }
131134 in_worker_ = false ;
132135
133136 // Wait for workers 1..n-1
134137 {
135138 std::unique_lock<std::mutex> lk (done_m_);
136- done_cv_.wait (lk, [&] {
137- return remaining_.load (std::memory_order_acquire) == 0 ;
138- });
139+ done_cv_.wait (
140+ lk, [&] { return remaining_.load (std::memory_order_acquire) == 0 ; });
139141 }
140142
141143 // Hygiene: deactivate region state
@@ -149,7 +151,8 @@ class TeamThreadPool {
149151 exc_ptr_ = nullptr ;
150152 }
151153
152- if (eptr) std::rethrow_exception (eptr);
154+ if (eptr)
155+ std::rethrow_exception (eptr);
153156 }
154157
155158 static bool in_worker_thread () noexcept { return in_worker_; }
@@ -165,8 +168,10 @@ class TeamThreadPool {
165168 static std::size_t configured_cap_ (std::size_t hw) {
166169 int cfg = configured_threads_.load (std::memory_order_acquire);
167170 std::size_t cap = (cfg > 0 ) ? static_cast <std::size_t >(cfg) : hw;
168- if (cap < 1 ) cap = 1 ;
169- if (cap > hw) cap = hw; // don't exceed hardware threads by default
171+ if (cap < 1 )
172+ cap = 1 ;
173+ if (cap > hw)
174+ cap = hw; // don't exceed hardware threads by default
170175 return cap;
171176 }
172177
@@ -180,7 +185,8 @@ class TeamThreadPool {
180185 exc_ptr_(nullptr ),
181186 ready_count_(0 ) {
182187 unsigned hw_u = std::thread::hardware_concurrency ();
183- if (hw_u == 0 ) hw_u = 2 ;
188+ if (hw_u == 0 )
189+ hw_u = 2 ;
184190 const std::size_t hw = static_cast <std::size_t >(hw_u);
185191
186192 const std::size_t cap = configured_cap_ (hw);
@@ -200,7 +206,8 @@ class TeamThreadPool {
200206
201207 in_worker_ = true ;
202208
203- // Startup barrier: ensure each worker reached the wait loop at least once.
209+ // Startup barrier: ensure each worker reached the wait loop at least
210+ // once.
204211 {
205212 std::lock_guard<std::mutex> lk (wake_m_);
206213 ready_count_.fetch_add (1 , std::memory_order_release);
@@ -216,11 +223,13 @@ class TeamThreadPool {
216223 std::unique_lock<std::mutex> lk (wake_m_);
217224 wake_cv_.wait (lk, [&] {
218225 return stop_.load (std::memory_order_acquire)
219- || wake_gen_.load (std::memory_order_acquire) != seen_gen;
226+ || wake_gen_.load (std::memory_order_acquire) != seen_gen;
220227 });
221- if (stop_.load (std::memory_order_acquire)) break ;
228+ if (stop_.load (std::memory_order_acquire))
229+ break ;
222230
223- // IMPORTANT: update while holding wake_m_ so we can't miss rapid increments
231+ // IMPORTANT: update while holding wake_m_ so we can't miss rapid
232+ // increments
224233 seen_gen = wake_gen_.load (std::memory_order_acquire);
225234 }
226235
@@ -281,7 +290,8 @@ class TeamThreadPool {
281290 }
282291 wake_cv_.notify_all ();
283292 for (auto & th : workers_) {
284- if (th.joinable ()) th.join ();
293+ if (th.joinable ())
294+ th.join ();
285295 }
286296 }
287297
0 commit comments