@@ -259,6 +259,7 @@ class ComponentInstance:
259259 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
260260 threads : Table [Thread ]
261261 may_leave : bool
262+ may_block : bool
262263 backpressure : int
263264 exclusive : bool
264265 num_waiting_to_enter : int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271 self .handles = Table ()
271272 self .threads = Table ()
272273 self .may_leave = True
274+ self .may_block = True
273275 self .backpressure = 0
274276 self .exclusive = False
275277 self .num_waiting_to_enter = 0
@@ -490,7 +492,7 @@ def resume(self, cancelled):
490492 cancelled = Cancelled .FALSE
491493
492494 def suspend (self , cancellable ) -> Cancelled :
493- assert (self .running () and self .task .may_block () )
495+ assert (self .running () and self .task .inst . may_block )
494496 if self .task .deliver_pending_cancel (cancellable ):
495497 return Cancelled .TRUE
496498 self .cancellable = cancellable
@@ -503,7 +505,7 @@ def suspend(self, cancellable) -> Cancelled:
503505 return cancelled
504506
505507 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
506- assert (self .running () and self .task .may_block () )
508+ assert (self .running () and self .task .inst . may_block )
507509 if self .task .deliver_pending_cancel (cancellable ):
508510 return Cancelled .TRUE
509511 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -514,7 +516,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
514516
515517 def yield_until (self , ready_func , cancellable ) -> Cancelled :
516518 assert (self .running ())
517- if self .task .may_block () :
519+ if self .task .inst . may_block :
518520 return self .wait_until (ready_func , cancellable )
519521 else :
520522 assert (ready_func ())
@@ -672,12 +674,14 @@ def thread_stop(self, thread):
672674 def needs_exclusive (self ):
673675 return not self .opts .async_ or self .opts .callback
674676
675- def may_block (self ):
676- return self .ft .async_ or self .state == Task .State .RESOLVED
677-
678677 def enter (self , thread ):
679678 assert (thread in self .threads and thread .task is self )
680679 if not self .ft .async_ :
680+ # TODO: what makes this true? where is the assert or trap_if? specifically
681+ # for sibling reentrance. maybe need to add back may_enter and only clear
682+ # at cooperative yield points (based on type)
683+ assert (self .inst .may_block )
684+ self .inst .may_block = False
681685 return True
682686 def has_backpressure ():
683687 return self .inst .backpressure > 0 or (self .needs_exclusive () and self .inst .exclusive )
@@ -696,6 +700,7 @@ def has_backpressure():
696700 def exit (self ):
697701 assert (len (self .threads ) > 0 )
698702 if not self .ft .async_ :
703+ assert (self .inst .may_block )
699704 return
700705 if self .needs_exclusive ():
701706 assert (self .inst .exclusive )
@@ -720,13 +725,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
720725 def return_ (self , result ):
721726 trap_if (self .state == Task .State .RESOLVED )
722727 trap_if (self .num_borrows > 0 )
728+ if not self .ft .async_ :
729+ assert (not self .inst .may_block )
730+ self .inst .may_block = True
723731 assert (result is not None )
724732 self .on_resolve (result )
725733 self .state = Task .State .RESOLVED
726734
727735 def cancel (self ):
728736 trap_if (self .state != Task .State .CANCEL_DELIVERED )
729737 trap_if (self .num_borrows > 0 )
738+ assert (self .ft .async_ )
730739 self .on_resolve (None )
731740 self .state = Task .State .RESOLVED
732741
@@ -2053,7 +2062,7 @@ def thread_func(thread):
20532062 else :
20542063 event = (EventCode .NONE , 0 , 0 )
20552064 case CallbackCode .WAIT :
2056- trap_if (not task .may_block () )
2065+ trap_if (not inst .may_block )
20572066 wset = inst .handles .get (si )
20582067 trap_if (not isinstance (wset , WaitableSet ))
20592068 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2069,6 +2078,7 @@ def thread_func(thread):
20692078
20702079 thread = Thread (task , thread_func )
20712080 thread .resume (Cancelled .FALSE )
2081+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20722082 return task
20732083
20742084class CallbackCode (IntEnum ):
@@ -2095,7 +2105,7 @@ def call_and_trap_on_throw(callee, thread, args):
20952105
20962106def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20972107 trap_if (not thread .task .inst .may_leave )
2098- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2108+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
20992109
21002110 subtask = Subtask ()
21012111 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2272,7 +2282,7 @@ def canon_waitable_set_new(thread):
22722282
22732283def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22742284 trap_if (not thread .task .inst .may_leave )
2275- trap_if (not thread .task .may_block () )
2285+ trap_if (not thread .task .inst . may_block )
22762286 wset = thread .task .inst .handles .get (si )
22772287 trap_if (not isinstance (wset , WaitableSet ))
22782288 event = wset .wait (thread , cancellable )
@@ -2323,7 +2333,7 @@ def canon_waitable_join(thread, wi, si):
23232333
23242334def canon_subtask_cancel (async_ , thread , i ):
23252335 trap_if (not thread .task .inst .may_leave )
2326- trap_if (not thread .task .may_block () and not async_ )
2336+ trap_if (not thread .task .inst . may_block and not async_ )
23272337 subtask = thread .task .inst .handles .get (i )
23282338 trap_if (not isinstance (subtask , Subtask ))
23292339 trap_if (subtask .resolve_delivered ())
@@ -2380,7 +2390,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23802390
23812391def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23822392 trap_if (not thread .task .inst .may_leave )
2383- trap_if (not thread .task .may_block () and not opts .async_ )
2393+ trap_if (not thread .task .inst . may_block and not opts .async_ )
23842394
23852395 e = thread .task .inst .handles .get (i )
23862396 trap_if (not isinstance (e , EndT ))
@@ -2434,7 +2444,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24342444
24352445def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24362446 trap_if (not thread .task .inst .may_leave )
2437- trap_if (not thread .task .may_block () and not opts .async_ )
2447+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24382448
24392449 e = thread .task .inst .handles .get (i )
24402450 trap_if (not isinstance (e , EndT ))
@@ -2486,7 +2496,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24862496
24872497def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24882498 trap_if (not thread .task .inst .may_leave )
2489- trap_if (not thread .task .may_block () and not async_ )
2499+ trap_if (not thread .task .inst . may_block and not async_ )
24902500 e = thread .task .inst .handles .get (i )
24912501 trap_if (not isinstance (e , EndT ))
24922502 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2563,7 +2573,7 @@ def canon_thread_resume_later(thread, i):
25632573
25642574def canon_thread_suspend (cancellable , thread ):
25652575 trap_if (not thread .task .inst .may_leave )
2566- trap_if (not thread .task .may_block () )
2576+ trap_if (not thread .task .inst . may_block )
25672577 cancelled = thread .suspend (cancellable )
25682578 return [cancelled ]
25692579
0 commit comments