@@ -198,7 +198,7 @@ def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve)
198198 def tick (self ):
199199 random .shuffle (self .waiting )
200200 for thread in self .waiting :
201- if thread .ready ():
201+ if thread .ready () and thread . task . inst . may_enter :
202202 thread .resume (Cancelled .FALSE )
203203 return
204204
@@ -285,7 +285,9 @@ class ComponentInstance:
285285 parent : Optional [ComponentInstance ]
286286 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287287 threads : Table [Thread ]
288+ may_enter : bool
288289 may_leave : bool
290+ may_block : bool
289291 backpressure : int
290292 exclusive : Optional [Task ]
291293 num_waiting_to_enter : int
@@ -296,7 +298,9 @@ def __init__(self, store, parent = None):
296298 self .parent = parent
297299 self .handles = Table ()
298300 self .threads = Table ()
301+ self .may_enter = True
299302 self .may_leave = True
303+ self .may_block = True
300304 self .backpressure = 0
301305 self .exclusive = None
302306 self .num_waiting_to_enter = 0
@@ -489,7 +493,8 @@ def resume_later(self):
489493
490494 def resume (self , cancelled ):
491495 assert (self .cancellable or not cancelled )
492- assert (not self .running ())
496+ assert (not self .running () and self .task .inst .may_enter )
497+ self .task .inst .may_enter = False
493498 if self .waiting ():
494499 assert (cancelled or self .ready ())
495500 self .ready_func = None
@@ -506,9 +511,11 @@ def resume(self, cancelled):
506511 break
507512 thread = switch_to_thread
508513 cancelled = Cancelled .FALSE
514+ assert (not self .task .inst .may_enter )
515+ self .task .inst .may_enter = True
509516
510517 def suspend (self , cancellable ) -> Cancelled :
511- assert (self .running () and self .task .may_block () )
518+ assert (self .running () and self .task .inst . may_block )
512519 if self .task .deliver_pending_cancel (cancellable ):
513520 return Cancelled .TRUE
514521 self .cancellable = cancellable
@@ -517,7 +524,7 @@ def suspend(self, cancellable) -> Cancelled:
517524 return cancelled
518525
519526 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
520- assert (self .running () and self .task .may_block () )
527+ assert (self .running () and self .task .inst . may_block )
521528 if self .task .deliver_pending_cancel (cancellable ):
522529 return Cancelled .TRUE
523530 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -528,7 +535,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
528535
529536 def yield_until (self , ready_func , cancellable ) -> Cancelled :
530537 assert (self .running ())
531- if self .task .may_block () :
538+ if self .task .inst . may_block :
532539 return self .wait_until (ready_func , cancellable )
533540 else :
534541 assert (ready_func ())
@@ -683,12 +690,11 @@ def thread_stop(self, thread):
683690 def needs_exclusive (self ):
684691 return not self .opts .async_ or self .opts .callback
685692
686- def may_block (self ):
687- return self .ft .async_ or self .state == Task .State .RESOLVED
688-
689693 def enter (self , thread ):
690694 assert (thread in self .threads and thread .task is self )
691695 if not self .ft .async_ :
696+ assert (self .inst .may_block )
697+ self .inst .may_block = False
692698 return True
693699 def has_backpressure ():
694700 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -739,13 +745,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
739745 def return_ (self , result ):
740746 trap_if (self .state == Task .State .RESOLVED )
741747 trap_if (self .num_borrows > 0 )
748+ if not self .ft .async_ :
749+ assert (not self .inst .may_block )
750+ self .inst .may_block = True
742751 assert (result is not None )
743752 self .on_resolve (result )
744753 self .state = Task .State .RESOLVED
745754
746755 def cancel (self ):
747756 trap_if (self .state != Task .State .CANCEL_DELIVERED )
748757 trap_if (self .num_borrows > 0 )
758+ assert (self .ft .async_ )
749759 self .on_resolve (None )
750760 self .state = Task .State .RESOLVED
751761
@@ -2038,6 +2048,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20382048
20392049def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve ) -> Call :
20402050 trap_if (call_might_be_recursive (caller , inst ))
2051+ assert (inst .may_enter )
2052+
20412053 task = Task (opts , inst , ft , caller , on_resolve )
20422054 def thread_func (thread ):
20432055 if not task .enter (thread ):
@@ -2082,7 +2094,7 @@ def thread_func(thread):
20822094 else :
20832095 event = (EventCode .NONE , 0 , 0 )
20842096 case CallbackCode .WAIT :
2085- trap_if (not task .may_block () )
2097+ trap_if (not inst .may_block )
20862098 wset = inst .handles .get (si )
20872099 trap_if (not isinstance (wset , WaitableSet ))
20882100 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2098,6 +2110,7 @@ def thread_func(thread):
20982110
20992111 thread = Thread (task , thread_func )
21002112 thread .resume (Cancelled .FALSE )
2113+ assert (ft .async_ or task .state == Task .State .RESOLVED )
21012114 return task
21022115
21032116class CallbackCode (IntEnum ):
@@ -2124,7 +2137,7 @@ def call_and_trap_on_throw(callee, thread, args):
21242137
21252138def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
21262139 trap_if (not thread .task .inst .may_leave )
2127- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2140+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21282141
21292142 subtask = Subtask ()
21302143 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2304,7 +2317,7 @@ def canon_waitable_set_new(thread):
23042317
23052318def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
23062319 trap_if (not thread .task .inst .may_leave )
2307- trap_if (not thread .task .may_block () )
2320+ trap_if (not thread .task .inst . may_block )
23082321 wset = thread .task .inst .handles .get (si )
23092322 trap_if (not isinstance (wset , WaitableSet ))
23102323 event = wset .wait (thread , cancellable )
@@ -2355,7 +2368,7 @@ def canon_waitable_join(thread, wi, si):
23552368
23562369def canon_subtask_cancel (async_ , thread , i ):
23572370 trap_if (not thread .task .inst .may_leave )
2358- trap_if (not thread .task .may_block () and not async_ )
2371+ trap_if (not thread .task .inst . may_block and not async_ )
23592372 subtask = thread .task .inst .handles .get (i )
23602373 trap_if (not isinstance (subtask , Subtask ))
23612374 trap_if (subtask .resolve_delivered ())
@@ -2412,7 +2425,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24122425
24132426def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
24142427 trap_if (not thread .task .inst .may_leave )
2415- trap_if (not thread .task .may_block () and not opts .async_ )
2428+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24162429
24172430 e = thread .task .inst .handles .get (i )
24182431 trap_if (not isinstance (e , EndT ))
@@ -2466,7 +2479,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24662479
24672480def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24682481 trap_if (not thread .task .inst .may_leave )
2469- trap_if (not thread .task .may_block () and not opts .async_ )
2482+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24702483
24712484 e = thread .task .inst .handles .get (i )
24722485 trap_if (not isinstance (e , EndT ))
@@ -2518,7 +2531,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25182531
25192532def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
25202533 trap_if (not thread .task .inst .may_leave )
2521- trap_if (not thread .task .may_block () and not async_ )
2534+ trap_if (not thread .task .inst . may_block and not async_ )
25222535 e = thread .task .inst .handles .get (i )
25232536 trap_if (not isinstance (e , EndT ))
25242537 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2595,7 +2608,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25952608
25962609def canon_thread_suspend (cancellable , thread ):
25972610 trap_if (not thread .task .inst .may_leave )
2598- trap_if (not thread .task .may_block () )
2611+ trap_if (not thread .task .inst . may_block )
25992612 cancelled = thread .suspend (cancellable )
26002613 return [cancelled ]
26012614
0 commit comments