@@ -259,6 +259,7 @@ class ComponentInstance:
259259 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
260260 threads : Table [Thread ]
261261 may_leave : bool
262+ may_block : bool
262263 backpressure : int
263264 exclusive : bool
264265 num_waiting_to_enter : int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271 self .handles = Table ()
271272 self .threads = Table ()
272273 self .may_leave = True
274+ self .may_block = True
273275 self .backpressure = 0
274276 self .exclusive = False
275277 self .num_waiting_to_enter = 0
@@ -489,7 +491,7 @@ def resume(self, cancelled):
489491 cancelled = Cancelled .FALSE
490492
491493 def suspend (self , cancellable ) -> Cancelled :
492- assert (self .running () and self .task .may_block () )
494+ assert (self .running () and self .task .inst . may_block )
493495 if self .task .deliver_pending_cancel (cancellable ):
494496 return Cancelled .TRUE
495497 self .cancellable = cancellable
@@ -502,7 +504,7 @@ def suspend(self, cancellable) -> Cancelled:
502504 return cancelled
503505
504506 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
505- assert (self .running () and self .task .may_block () )
507+ assert (self .running () and self .task .inst . may_block )
506508 if self .task .deliver_pending_cancel (cancellable ):
507509 return Cancelled .TRUE
508510 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -513,7 +515,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
513515
514516 def yield_until (self , ready_func , cancellable ) -> Cancelled :
515517 assert (self .running ())
516- if self .task .may_block () :
518+ if self .task .inst . may_block :
517519 return self .wait_until (ready_func , cancellable )
518520 else :
519521 assert (ready_func ())
@@ -671,12 +673,14 @@ def thread_stop(self, thread):
671673 def needs_exclusive (self ):
672674 return not self .opts .async_ or self .opts .callback
673675
674- def may_block (self ):
675- return self .ft .async_ or self .state == Task .State .RESOLVED
676-
677676 def enter (self , thread ):
678677 assert (thread in self .threads and thread .task is self )
679678 if not self .ft .async_ :
679+ # TODO: what makes this true? where is the assert or trap_if? specifically
680+ # for sibling reentrance. maybe need to add back may_enter and only clear
681+ # at cooperative yield points (based on type)
682+ assert (self .inst .may_block )
683+ self .inst .may_block = False
680684 return True
681685 def has_backpressure ():
682686 return self .inst .backpressure > 0 or (self .needs_exclusive () and self .inst .exclusive )
@@ -695,6 +699,7 @@ def has_backpressure():
695699 def exit (self ):
696700 assert (len (self .threads ) > 0 )
697701 if not self .ft .async_ :
702+ assert (self .inst .may_block )
698703 return
699704 if self .needs_exclusive ():
700705 assert (self .inst .exclusive )
@@ -719,13 +724,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
719724 def return_ (self , result ):
720725 trap_if (self .state == Task .State .RESOLVED )
721726 trap_if (self .num_borrows > 0 )
727+ if not self .ft .async_ :
728+ assert (not self .inst .may_block )
729+ self .inst .may_block = True
722730 assert (result is not None )
723731 self .on_resolve (result )
724732 self .state = Task .State .RESOLVED
725733
726734 def cancel (self ):
727735 trap_if (self .state != Task .State .CANCEL_DELIVERED )
728736 trap_if (self .num_borrows > 0 )
737+ assert (self .ft .async_ )
729738 self .on_resolve (None )
730739 self .state = Task .State .RESOLVED
731740
@@ -2052,7 +2061,7 @@ def thread_func(thread):
20522061 else :
20532062 event = (EventCode .NONE , 0 , 0 )
20542063 case CallbackCode .WAIT :
2055- trap_if (not task .may_block () )
2064+ trap_if (not inst .may_block )
20562065 wset = inst .handles .get (si )
20572066 trap_if (not isinstance (wset , WaitableSet ))
20582067 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2068,6 +2077,7 @@ def thread_func(thread):
20682077
20692078 thread = Thread (task , thread_func )
20702079 thread .resume (Cancelled .FALSE )
2080+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20712081 return task
20722082
20732083class CallbackCode (IntEnum ):
@@ -2094,7 +2104,7 @@ def call_and_trap_on_throw(callee, thread, args):
20942104
20952105def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20962106 trap_if (not thread .task .inst .may_leave )
2097- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2107+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
20982108
20992109 subtask = Subtask ()
21002110 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2271,7 +2281,7 @@ def canon_waitable_set_new(thread):
22712281
22722282def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22732283 trap_if (not thread .task .inst .may_leave )
2274- trap_if (not thread .task .may_block () )
2284+ trap_if (not thread .task .inst . may_block )
22752285 wset = thread .task .inst .handles .get (si )
22762286 trap_if (not isinstance (wset , WaitableSet ))
22772287 event = wset .wait (thread , cancellable )
@@ -2322,7 +2332,7 @@ def canon_waitable_join(thread, wi, si):
23222332
23232333def canon_subtask_cancel (async_ , thread , i ):
23242334 trap_if (not thread .task .inst .may_leave )
2325- trap_if (not thread .task .may_block () and not async_ )
2335+ trap_if (not thread .task .inst . may_block and not async_ )
23262336 subtask = thread .task .inst .handles .get (i )
23272337 trap_if (not isinstance (subtask , Subtask ))
23282338 trap_if (subtask .resolve_delivered ())
@@ -2379,7 +2389,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23792389
23802390def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23812391 trap_if (not thread .task .inst .may_leave )
2382- trap_if (not thread .task .may_block () and not opts .async_ )
2392+ trap_if (not thread .task .inst . may_block and not opts .async_ )
23832393
23842394 e = thread .task .inst .handles .get (i )
23852395 trap_if (not isinstance (e , EndT ))
@@ -2433,7 +2443,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24332443
24342444def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24352445 trap_if (not thread .task .inst .may_leave )
2436- trap_if (not thread .task .may_block () and not opts .async_ )
2446+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24372447
24382448 e = thread .task .inst .handles .get (i )
24392449 trap_if (not isinstance (e , EndT ))
@@ -2485,7 +2495,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24852495
24862496def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24872497 trap_if (not thread .task .inst .may_leave )
2488- trap_if (not thread .task .may_block () and not async_ )
2498+ trap_if (not thread .task .inst . may_block and not async_ )
24892499 e = thread .task .inst .handles .get (i )
24902500 trap_if (not isinstance (e , EndT ))
24912501 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2562,7 +2572,7 @@ def canon_thread_resume_later(thread, i):
25622572
25632573def canon_thread_suspend (cancellable , thread ):
25642574 trap_if (not thread .task .inst .may_leave )
2565- trap_if (not thread .task .may_block () )
2575+ trap_if (not thread .task .inst . may_block )
25662576 cancelled = thread .suspend (cancellable )
25672577 return [cancelled ]
25682578
0 commit comments