@@ -258,7 +258,9 @@ class ComponentInstance:
258258 parent : Optional [ComponentInstance ]
259259 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
260260 threads : Table [Thread ]
261+ may_enter : bool
261262 may_leave : bool
263+ may_block : bool
262264 backpressure : int
263265 exclusive : Optional [Task ]
264266 num_waiting_to_enter : int
@@ -269,7 +271,9 @@ def __init__(self, store, parent = None):
269271 self .parent = parent
270272 self .handles = Table ()
271273 self .threads = Table ()
274+ self .may_enter = True
272275 self .may_leave = True
276+ self .may_block = True
273277 self .backpressure = 0
274278 self .exclusive = None
275279 self .num_waiting_to_enter = 0
@@ -469,6 +473,8 @@ def resume(self, cancelled):
469473 self .ready_func = None
470474 self .task .inst .store .waiting .remove (self )
471475 assert (self .cancellable or not cancelled )
476+ assert (self .task .inst .may_enter )
477+ self .task .inst .may_enter = False
472478 thread = self
473479 while True :
474480 cont = thread .cont
@@ -481,9 +487,11 @@ def resume(self, cancelled):
481487 break
482488 thread = switch_to_thread
483489 cancelled = Cancelled .FALSE
490+ assert (not self .task .inst .may_enter )
491+ self .task .inst .may_enter = True
484492
485493 def suspend (self , cancellable ) -> Cancelled :
486- assert (self .running () and self .task .may_block () )
494+ assert (self .running () and self .task .inst . may_block )
487495 if self .task .deliver_pending_cancel (cancellable ):
488496 return Cancelled .TRUE
489497 self .cancellable = cancellable
@@ -492,7 +500,7 @@ def suspend(self, cancellable) -> Cancelled:
492500 return cancelled
493501
494502 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
495- assert (self .running () and self .task .may_block () )
503+ assert (self .running () and self .task .inst . may_block )
496504 if self .task .deliver_pending_cancel (cancellable ):
497505 return Cancelled .TRUE
498506 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -503,7 +511,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
503511
504512 def yield_until (self , ready_func , cancellable ) -> Cancelled :
505513 assert (self .running ())
506- if self .task .may_block () :
514+ if self .task .inst . may_block :
507515 return self .wait_until (ready_func , cancellable )
508516 else :
509517 assert (ready_func ())
@@ -658,12 +666,11 @@ def thread_stop(self, thread):
658666 def needs_exclusive (self ):
659667 return not self .opts .async_ or self .opts .callback
660668
661- def may_block (self ):
662- return self .ft .async_ or self .state == Task .State .RESOLVED
663-
664669 def enter (self , thread ):
665670 assert (thread in self .threads and thread .task is self )
666671 if not self .ft .async_ :
672+ assert (self .inst .may_block ) # TODO: why
673+ self .inst .may_block = False
667674 return True
668675 def has_backpressure ():
669676 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -714,13 +721,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
714721 def return_ (self , result ):
715722 trap_if (self .state == Task .State .RESOLVED )
716723 trap_if (self .num_borrows > 0 )
724+ if not self .ft .async_ :
725+ assert (not self .inst .may_block )
726+ self .inst .may_block = True
717727 assert (result is not None )
718728 self .on_resolve (result )
719729 self .state = Task .State .RESOLVED
720730
721731 def cancel (self ):
722732 trap_if (self .state != Task .State .CANCEL_DELIVERED )
723733 trap_if (self .num_borrows > 0 )
734+ assert (self .ft .async_ )
724735 self .on_resolve (None )
725736 self .state = Task .State .RESOLVED
726737
@@ -2002,7 +2013,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20022013### `canon lift`
20032014
20042015def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve ) -> Call :
2016+ trap_if (not inst .may_enter )
20052017 trap_if (call_might_be_recursive (caller , inst ))
2018+
20062019 task = Task (opts , inst , ft , caller , on_resolve )
20072020 def thread_func (thread ):
20082021 if not task .enter (thread ):
@@ -2047,7 +2060,7 @@ def thread_func(thread):
20472060 else :
20482061 event = (EventCode .NONE , 0 , 0 )
20492062 case CallbackCode .WAIT :
2050- trap_if (not task .may_block () )
2063+ trap_if (not inst .may_block )
20512064 wset = inst .handles .get (si )
20522065 trap_if (not isinstance (wset , WaitableSet ))
20532066 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2063,6 +2076,7 @@ def thread_func(thread):
20632076
20642077 thread = Thread (task , thread_func )
20652078 thread .resume (Cancelled .FALSE )
2079+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20662080 return task
20672081
20682082class CallbackCode (IntEnum ):
@@ -2089,7 +2103,7 @@ def call_and_trap_on_throw(callee, thread, args):
20892103
20902104def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
20912105 trap_if (not thread .task .inst .may_leave )
2092- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2106+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
20932107
20942108 subtask = Subtask ()
20952109 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2264,7 +2278,7 @@ def canon_waitable_set_new(thread):
22642278
22652279def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
22662280 trap_if (not thread .task .inst .may_leave )
2267- trap_if (not thread .task .may_block () )
2281+ trap_if (not thread .task .inst . may_block )
22682282 wset = thread .task .inst .handles .get (si )
22692283 trap_if (not isinstance (wset , WaitableSet ))
22702284 event = wset .wait (thread , cancellable )
@@ -2315,7 +2329,7 @@ def canon_waitable_join(thread, wi, si):
23152329
23162330def canon_subtask_cancel (async_ , thread , i ):
23172331 trap_if (not thread .task .inst .may_leave )
2318- trap_if (not thread .task .may_block () and not async_ )
2332+ trap_if (not thread .task .inst . may_block and not async_ )
23192333 subtask = thread .task .inst .handles .get (i )
23202334 trap_if (not isinstance (subtask , Subtask ))
23212335 trap_if (subtask .resolve_delivered ())
@@ -2372,7 +2386,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23722386
23732387def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
23742388 trap_if (not thread .task .inst .may_leave )
2375- trap_if (not thread .task .may_block () and not opts .async_ )
2389+ trap_if (not thread .task .inst . may_block and not opts .async_ )
23762390
23772391 e = thread .task .inst .handles .get (i )
23782392 trap_if (not isinstance (e , EndT ))
@@ -2426,7 +2440,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24262440
24272441def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24282442 trap_if (not thread .task .inst .may_leave )
2429- trap_if (not thread .task .may_block () and not opts .async_ )
2443+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24302444
24312445 e = thread .task .inst .handles .get (i )
24322446 trap_if (not isinstance (e , EndT ))
@@ -2478,7 +2492,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24782492
24792493def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
24802494 trap_if (not thread .task .inst .may_leave )
2481- trap_if (not thread .task .may_block () and not async_ )
2495+ trap_if (not thread .task .inst . may_block and not async_ )
24822496 e = thread .task .inst .handles .get (i )
24832497 trap_if (not isinstance (e , EndT ))
24842498 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2555,7 +2569,7 @@ def canon_thread_resume_later(thread, i):
25552569
25562570def canon_thread_suspend (cancellable , thread ):
25572571 trap_if (not thread .task .inst .may_leave )
2558- trap_if (not thread .task .may_block () )
2572+ trap_if (not thread .task .inst . may_block )
25592573 cancelled = thread .suspend (cancellable )
25602574 return [cancelled ]
25612575
0 commit comments