@@ -119,15 +119,20 @@ final_suspend_continue(Context *context, frame_type<Context> *parent) noexcept -
   // the stack so we must prepare it for release now.
   auto release_key = context->allocator().prepare_release();
 
+  // TODO: we could add an `if (owner)` around acquire below, then we could
+  // define that acquire is always called with null or not-self.
+
   // Register with the parent that we have completed this child task.
   if (parent->atomic_joins().fetch_sub(1, std::memory_order_release) == 1) {
     // Parent has reached join and we are the last child task to complete. We
     // are the exclusive owner of the parent and therefore we must continue
     // the parent. As we won the race, acquire all writes before resuming.
     std::atomic_thread_fence(std::memory_order_acquire);
 
-    // In case of scenario (2) we must acquire the parent's stack.
-    context->allocator().acquire(std::as_const(parent->stack_ckpt));
+    if (!owner) {
+      // In case of scenario (2) we must acquire the parent's stack.
+      context->allocator().acquire(std::as_const(parent->stack_ckpt));
+    }
 
     // Must reset the parent's control block before resuming the parent.
     parent->reset_counters();
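For readers of this hunk, here is a minimal, self-contained sketch (toy names, not the library's types) of the join-counter protocol used above: every completing child decrements the parent's counter with release semantics, and the child whose decrement hits zero wins the race, issues an acquire fence so it observes all of its siblings' writes, and is the only one allowed to resume the parent.

```cpp
#include <atomic>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

constexpr int n_children = 4;

struct parent_frame {
  std::atomic<int> joins{n_children};
  int results[n_children]{};  // written by the children, read only by the winner
};

void child_completes(parent_frame &parent, int idx) {
  parent.results[idx] = idx * idx;  // the child's work, published by the release below
  if (parent.joins.fetch_sub(1, std::memory_order_release) == 1) {
    // We are the last child: acquire every sibling's writes, then "continue the parent".
    std::atomic_thread_fence(std::memory_order_acquire);
    int sum = 0;
    for (int r : parent.results) { sum += r; }
    std::printf("parent resumed by the last child, sum = %d\n", sum);
  }
}

int main() {
  parent_frame parent;
  std::vector<std::jthread> workers;
  for (int i = 0; i < n_children; ++i) {
    workers.emplace_back(child_completes, std::ref(parent), i);
  }
}
```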
@@ -192,11 +197,11 @@ constexpr auto final_suspend(frame_type<Context> *frame) noexcept -> coro<> {
     return parent->handle();
   case category::fork:
 
-    Context *context = not_null(thread_context<Context>);
+    Context *context = get_context<Context>();
 
     if (frame_handle last_pushed = context->pop()) {
       // No-one stole the continuation, we are the exclusive owner of the parent -> just keep ripping!
-      LF_ASSUME(last_pushed == frame_handle{key, parent});
+      LF_ASSUME(last_pushed == frame_handle{key(), parent});
       // This is not a join point so no state (i.e. counters) is guaranteed.
       return parent->handle();
     }
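For context on the fork case: a hedged sketch, using a toy LIFO container rather than the library's real work-stealing deque, of the ownership test performed above. If the handle we pushed before running the child is still there when we pop, nobody stole the parent's continuation and we may resume it inline; an empty pop means a thief now owns it.

```cpp
#include <deque>
#include <optional>

struct frame;  // stand-in for frame_type<Context>

struct frame_handle {
  frame *ptr = nullptr;
  friend bool operator==(frame_handle, frame_handle) = default;
};

struct toy_context {
  std::deque<frame_handle> tasks;  // the real thing is a lock-free work-stealing deque

  void push(frame_handle h) { tasks.push_back(h); }

  std::optional<frame_handle> pop() {
    if (tasks.empty()) { return std::nullopt; }  // our continuation was stolen
    frame_handle h = tasks.back();
    tasks.pop_back();
    return h;
  }
};

// True if we still own the parent and may resume it without going through the join.
bool owns_parent_after_fork(toy_context &ctx, frame *parent) {
  if (std::optional last_pushed = ctx.pop()) {
    // No-one stole the continuation, so the popped handle must be the one we pushed.
    return *last_pushed == frame_handle{parent};
  }
  return false;  // a thief has resumed (or will resume) the parent instead
}
```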
@@ -243,6 +248,8 @@ constexpr void stash_current_exception(frame_type<Context> *frame) noexcept {
 template <category Cat, worker_context Context>
 struct awaitable : std::suspend_always {
 
+  static_assert(Cat == category::call || Cat == category::fork, "Invalid category for awaitable");
+
   frame_type<Context> *child;
 
   /**
@@ -274,16 +281,21 @@ struct awaitable : std::suspend_always {
     // Propagate parent->child relationships
     self.child->parent.frame = &parent.promise().frame;
     self.child->cancel = parent.promise().frame.cancel;
-    self.child->stack_ckpt = not_null(thread_context<Context>)->allocator().checkpoint();
-    self.child->kind = Cat;
+
+    if constexpr (Cat == category::call) {
+      // Should be the default
+      LF_ASSUME(self.child->kind == category::call);
+    } else {
+      self.child->kind = Cat;
+    }
 
     if constexpr (Cat == category::fork) {
       // It is critical to pass self by value here; after the call to push()
       // the object `*this` may be destroyed, so if passed by reference it
       // would be a use-after-free to then access self on the following line
       // to fetch the handle.
       LF_TRY {
-        not_null(thread_context<Context>)->push(frame_handle{key, &parent.promise().frame});
+        get_context<Context>()->push(frame_handle{key(), &parent.promise().frame});
       } LF_CATCH_ALL {
         return self.stash_and_resume(parent), parent;
       }
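The by-value `self` comment above can be made concrete with a compact sketch (hypothetical names, C++23 deducing this, not the library's real awaitable): once the parent's continuation has been published for stealing, another worker may resume and destroy the parent's coroutine frame, and the awaitable lives inside that frame, so only a by-value copy of it is safe to touch afterwards.

```cpp
#include <coroutine>

struct fork_awaitable : std::suspend_always {
  std::coroutine_handle<> child;

  // `self` is an independent copy: it stays valid even if the parent frame
  // (and with it the original awaitable object) is destroyed during the publish.
  auto await_suspend(this fork_awaitable self, std::coroutine_handle<> parent)
      -> std::coroutine_handle<> {
    publish_for_stealing(parent);  // hypothetical stand-in for context->push(...)
    return self.child;             // symmetric transfer into the child task
  }

  static void publish_for_stealing(std::coroutine_handle<>) noexcept {}  // stub
};
```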
@@ -303,7 +315,7 @@ struct join_awaitable {
   frame_type<Context> *frame;
 
   constexpr auto take_stack_and_reset(this join_awaitable self) noexcept -> void {
-    Context *context = not_null(thread_context<Context>);
+    Context *context = get_context<Context>();
     LF_ASSUME(self.frame->stack_ckpt != context->allocator().checkpoint());
     context->allocator().acquire(std::as_const(self.frame->stack_ckpt));
     self.frame->reset_counters();
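To make the take_stack_and_reset logic above easier to follow, here is a toy sketch (assumed semantics, invented names) of what a checkpoint/acquire pair could look like: a checkpoint records the allocation position of a stack at the moment the parent frame was created, and acquire() makes the calling worker's allocator adopt that position, so a worker that lost its own continuation but won the join keeps allocating on the parent's stack.

```cpp
#include <cassert>
#include <cstddef>

struct checkpoint_t {
  std::byte *stack = nullptr;  // which stack the frame lives on
  std::byte *top = nullptr;    // how much of it was in use at that point
};

class toy_allocator {
  std::byte *stack_ = nullptr;
  std::byte *top_ = nullptr;

 public:
  auto checkpoint() const noexcept -> checkpoint_t { return {stack_, top_}; }

  // Adopt someone else's stack at the recorded position.
  auto acquire(checkpoint_t const &ckpt) noexcept -> void {
    assert(ckpt.stack != stack_ && "acquire is only needed when switching stacks");
    stack_ = ckpt.stack;
    top_ = ckpt.top;
  }
};
```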
@@ -437,14 +449,12 @@ struct mixin_frame {
   // --- Allocation
 
   static auto operator new(std::size_t sz) -> void * {
-    void *ptr = not_null(thread_context<Ctx>)->allocator().push(sz);
+    void *ptr = get_allocator<Ctx>().push(sz);
     LF_ASSUME(is_aligned<k_new_align>(ptr));
     return std::assume_aligned<k_new_align>(ptr);
   }
 
-  static auto operator delete(void *p, std::size_t sz) noexcept -> void {
-    not_null(thread_context<Ctx>)->allocator().pop(p, sz);
-  }
+  static auto operator delete(void *p, std::size_t sz) noexcept -> void { get_allocator<Ctx>().pop(p, sz); }
 
   // --- Await transformations
 
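The allocation hunk above relies on class-scope operator new/delete routing coroutine-frame allocations to the worker's stack allocator. Below is a self-contained sketch of that pattern with an illustrative bump arena standing in for the real allocator; `bump_arena` and this `k_new_align` value are assumptions, not the library's API.

```cpp
#include <cstddef>
#include <memory>
#include <new>

inline constexpr std::size_t k_new_align = 2 * sizeof(void *);

struct bump_arena {
  alignas(k_new_align) static inline std::byte storage[1 << 16];
  static inline std::byte *top = storage;

  static auto push(std::size_t sz) -> void * {
    void *p = top;
    top += (sz + k_new_align - 1) & ~(k_new_align - 1);  // keep the next frame aligned
    return p;                                            // no exhaustion check in this sketch
  }
  static auto pop(void *p, std::size_t /*sz*/) noexcept -> void {
    top = static_cast<std::byte *>(p);  // frames are freed strictly LIFO
  }
};

struct frame_like {
  // Class-scope operator new/delete: every allocation of this type (and, in
  // the real code, every coroutine-frame allocation for the promise) goes to
  // the arena instead of the global heap.
  static auto operator new(std::size_t sz) -> void * {
    return std::assume_aligned<k_new_align>(bump_arena::push(sz));
  }
  static auto operator delete(void *p, std::size_t sz) noexcept -> void { bump_arena::pop(p, sz); }

  char payload[48];
};
```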
@@ -506,7 +516,10 @@ struct mixin_frame {
 template <worker_context Context>
 struct promise_type<void, Context> : mixin_frame<Context> {
 
-  frame_type<Context> frame;
+  // Putting the initializer here means:
+  // 1. The frame does not need to know about the checkpoint type.
+  // 2. The compiler can merge the double read of the thread-local here and in the allocator.
+  frame_type<Context> frame{get_allocator<Context>().checkpoint()};
 
   constexpr auto get_return_object() noexcept -> task<void> { return access::task(this); }
 
@@ -518,7 +531,10 @@ struct promise_type<void, Context> : mixin_frame<Context> {
 template <returnable T, worker_context Context>
 struct promise_type : mixin_frame<Context> {
 
-  frame_type<Context> frame;
+  // Putting the initializer here means:
+  // 1. The frame does not need to know about the checkpoint type.
+  // 2. The compiler can merge the double read of the thread-local here and in the allocator.
+  frame_type<Context> frame{get_allocator<Context>().checkpoint()};
   T *return_address;
 
   constexpr auto get_return_object() noexcept -> task<T> { return access::task(this); }
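A small sketch of the rationale in the comment above, with toy names (`toy_allocator`, `get_allocator<void>` are illustrative stand-ins): because the promise's operator new already reads the worker's thread-local allocator, initialising `frame` from the same call in a default member initializer keeps the second thread-local read right next to the first, where the optimizer can coalesce them, and the frame just stores whatever checkpoint value it is handed rather than reaching into the allocator itself.

```cpp
#include <cstdio>

struct toy_allocator {
  int position = 0;
  int checkpoint() const noexcept { return position; }  // stand-in checkpoint
};

thread_local toy_allocator tls_allocator;  // stand-in for the worker's context

template <typename Ctx>
auto get_allocator() noexcept -> toy_allocator & { return tls_allocator; }  // hypothetical helper

template <typename Ckpt>
struct toy_frame {
  Ckpt stack_ckpt;  // the frame only stores the checkpoint it is handed
};

struct toy_promise {
  // The thread-local read happens here, adjacent to the one operator new
  // would perform when allocating the coroutine frame.
  toy_frame<int> frame{get_allocator<void>().checkpoint()};
};

int main() { std::printf("%d\n", toy_promise{}.frame.stack_ckpt); }
```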