@@ -259,6 +259,7 @@ class ComponentInstance:
259259 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
260260 threads : Table [Thread ]
261261 may_leave : bool
262+ may_block : bool
262263 backpressure : int
263264 exclusive : bool
264265 num_waiting_to_enter : int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271 self .handles = Table ()
271272 self .threads = Table ()
272273 self .may_leave = True
274+ self .may_block = True
273275 self .backpressure = 0
274276 self .exclusive = False
275277 self .num_waiting_to_enter = 0
@@ -485,7 +487,7 @@ def resume(self, cancelled):
485487 cancelled = Cancelled .FALSE
486488
487489 def suspend (self , cancellable ) -> Cancelled :
488- assert (self .running () and self .task .may_block () )
490+ assert (self .running () and self .task .inst . may_block )
489491 if self .task .deliver_pending_cancel (cancellable ):
490492 return Cancelled .TRUE
491493 self .cancellable = cancellable
@@ -495,7 +497,7 @@ def suspend(self, cancellable) -> Cancelled:
495497 return cancelled
496498
497499 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
498- assert (self .running () and self .task .may_block () )
500+ assert (self .running () and self .task .inst . may_block )
499501 if self .task .deliver_pending_cancel (cancellable ):
500502 return Cancelled .TRUE
501503 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -506,7 +508,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
506508
507509 def yield_until (self , ready_func , cancellable ) -> Cancelled :
508510 assert (self .running ())
509- if self .task .may_block () :
511+ if self .task .inst . may_block :
510512 return self .wait_until (ready_func , cancellable )
511513 else :
512514 assert (ready_func ())
@@ -661,12 +663,14 @@ def thread_stop(self, thread):
661663 def needs_exclusive (self ):
662664 return not self .opts .async_ or self .opts .callback
663665
664- def may_block (self ):
665- return self .ft .async_ or self .state == Task .State .RESOLVED
666-
667666 def enter (self , thread ):
668667 assert (thread in self .threads and thread .task is self )
669668 if not self .ft .async_ :
669+ # TODO: what makes this true? where is the assert or trap_if? specifically
670+ # for sibling reentrance. maybe need to add back may_enter and only clear
671+ # at cooperative yield points (based on type)
672+ assert (self .inst .may_block )
673+ self .inst .may_block = False
670674 return True
671675 def has_backpressure ():
672676 return self .inst .backpressure > 0 or (self .needs_exclusive () and self .inst .exclusive )
@@ -685,6 +689,7 @@ def has_backpressure():
685689 def exit (self ):
686690 assert (len (self .threads ) > 0 )
687691 if not self .ft .async_ :
692+ assert (self .inst .may_block )
688693 return
689694 if self .needs_exclusive ():
690695 assert (self .inst .exclusive )
@@ -709,13 +714,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
709714 def return_ (self , result ):
710715 trap_if (self .state == Task .State .RESOLVED )
711716 trap_if (self .num_borrows > 0 )
717+ if not self .ft .async_ :
718+ assert (not self .inst .may_block )
719+ self .inst .may_block = True
712720 assert (result is not None )
713721 self .on_resolve (result )
714722 self .state = Task .State .RESOLVED
715723
716724 def cancel (self ):
717725 trap_if (self .state != Task .State .CANCEL_DELIVERED )
718726 trap_if (self .num_borrows > 0 )
727+ assert (self .ft .async_ )
719728 self .on_resolve (None )
720729 self .state = Task .State .RESOLVED
721730
@@ -2042,7 +2051,7 @@ def thread_func(thread):
20422051 else :
20432052 event = (EventCode .NONE , 0 , 0 )
20442053 case CallbackCode .WAIT :
2045- trap_if (not task .may_block () )
2054+ trap_if (not inst .may_block )
20462055 wset = inst .handles .get (si )
20472056 trap_if (not isinstance (wset , WaitableSet ))
20482057 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2058,6 +2067,7 @@ def thread_func(thread):
20582067
20592068 thread = Thread (task , thread_func )
20602069 thread .resume (Cancelled .FALSE )
2070+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20612071 return task
20622072
20632073class CallbackCode (IntEnum ):
@@ -2084,7 +2094,7 @@ def call_and_trap_on_throw(callee, thread, args):
20842094
20852095def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20862096 trap_if (not thread .task .inst .may_leave )
2087- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2097+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
20882098
20892099 subtask = Subtask ()
20902100 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2261,7 +2271,7 @@ def canon_waitable_set_new(thread):
22612271
22622272def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22632273 trap_if (not thread .task .inst .may_leave )
2264- trap_if (not thread .task .may_block () )
2274+ trap_if (not thread .task .inst . may_block )
22652275 wset = thread .task .inst .handles .get (si )
22662276 trap_if (not isinstance (wset , WaitableSet ))
22672277 event = wset .wait (thread , cancellable )
@@ -2312,7 +2322,7 @@ def canon_waitable_join(thread, wi, si):
23122322
23132323def canon_subtask_cancel (async_ , thread , i ):
23142324 trap_if (not thread .task .inst .may_leave )
2315- trap_if (not thread .task .may_block () and not async_ )
2325+ trap_if (not thread .task .inst . may_block and not async_ )
23162326 subtask = thread .task .inst .handles .get (i )
23172327 trap_if (not isinstance (subtask , Subtask ))
23182328 trap_if (subtask .resolve_delivered ())
@@ -2369,7 +2379,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23692379
23702380def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23712381 trap_if (not thread .task .inst .may_leave )
2372- trap_if (not thread .task .may_block () and not opts .async_ )
2382+ trap_if (not thread .task .inst . may_block and not opts .async_ )
23732383
23742384 e = thread .task .inst .handles .get (i )
23752385 trap_if (not isinstance (e , EndT ))
@@ -2423,7 +2433,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24232433
24242434def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24252435 trap_if (not thread .task .inst .may_leave )
2426- trap_if (not thread .task .may_block () and not opts .async_ )
2436+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24272437
24282438 e = thread .task .inst .handles .get (i )
24292439 trap_if (not isinstance (e , EndT ))
@@ -2475,7 +2485,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24752485
24762486def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24772487 trap_if (not thread .task .inst .may_leave )
2478- trap_if (not thread .task .may_block () and not async_ )
2488+ trap_if (not thread .task .inst . may_block and not async_ )
24792489 e = thread .task .inst .handles .get (i )
24802490 trap_if (not isinstance (e , EndT ))
24812491 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2552,7 +2562,7 @@ def canon_thread_resume_later(thread, i):
25522562
25532563def canon_thread_suspend (cancellable , thread ):
25542564 trap_if (not thread .task .inst .may_leave )
2555- trap_if (not thread .task .may_block () )
2565+ trap_if (not thread .task .inst . may_block )
25562566 cancelled = thread .suspend (cancellable )
25572567 return [cancelled ]
25582568
0 commit comments