@@ -185,22 +185,31 @@ class FutureType(ValType):
185185
186186class Store :
187187 waiting : list [Thread ]
188+ nesting_depth : int
188189
189190 def __init__ (self ):
190191 self .waiting = []
192+ self .nesting_depth = 0
191193
192194 def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193195 host_caller = Supertask ()
194196 host_caller .inst = None
195197 host_caller .supertask = caller
196- return f (host_caller , on_start , on_resolve )
198+ self .nesting_depth += 1
199+ assert (self .nesting_depth == host_caller .num_host_callers ())
200+ call = f (host_caller , on_start , on_resolve )
201+ self .nesting_depth -= 1
202+ return call
197203
198204 def tick (self ):
205+ assert (self .nesting_depth == 0 )
206+ self .nesting_depth = 1
199207 random .shuffle (self .waiting )
200208 for thread in self .waiting :
201209 if thread .ready ():
202210 thread .resume (Cancelled .FALSE )
203- return
211+ break
212+ self .nesting_depth = 0
204213
205214FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
206215
@@ -211,6 +220,15 @@ class Supertask:
211220 inst : Optional [ComponentInstance ]
212221 supertask : Optional [Supertask ]
213222
223+ def num_host_callers (self ):
224+ n = 0
225+ t = self
226+ while t is not None :
227+ if t .inst is None :
228+ n += 1
229+ t = t .supertask
230+ return n
231+
214232class Call :
215233 request_cancellation : Callable [[], None ]
216234
@@ -286,6 +304,7 @@ class ComponentInstance:
286304 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287305 threads : Table [Thread ]
288306 may_leave : bool
307+ may_block : bool
289308 backpressure : int
290309 exclusive : Optional [Task ]
291310 num_waiting_to_enter : int
@@ -297,6 +316,7 @@ def __init__(self, store, parent = None):
297316 self .handles = Table ()
298317 self .threads = Table ()
299318 self .may_leave = True
319+ self .may_block = True
300320 self .backpressure = 0
301321 self .exclusive = None
302322 self .num_waiting_to_enter = 0
@@ -504,7 +524,7 @@ def resume(self, cancelled):
504524 cancelled = Cancelled .FALSE
505525
506526 def suspend (self , cancellable ) -> Cancelled :
507- assert (self .running () and self .task .may_block () )
527+ assert (self .running () and self .task .inst . may_block )
508528 if self .task .deliver_pending_cancel (cancellable ):
509529 return Cancelled .TRUE
510530 self .cancellable = cancellable
@@ -513,7 +533,7 @@ def suspend(self, cancellable) -> Cancelled:
513533 return cancelled
514534
515535 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
516- assert (self .running () and self .task .may_block () )
536+ assert (self .running () and self .task .inst . may_block )
517537 if self .task .deliver_pending_cancel (cancellable ):
518538 return Cancelled .TRUE
519539 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -524,7 +544,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
524544
525545 def yield_until (self , ready_func , cancellable ) -> Cancelled :
526546 assert (self .running ())
527- if self .task .may_block () :
547+ if self .task .inst . may_block :
528548 return self .wait_until (ready_func , cancellable )
529549 else :
530550 assert (ready_func ())
@@ -669,9 +689,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
669689 def needs_exclusive (self ):
670690 return not self .opts .async_ or self .opts .callback
671691
672- def may_block (self ):
673- return self .ft .async_ or self .state == Task .State .RESOLVED
674-
675692 def enter (self ):
676693 thread = current_thread ()
677694 if self .ft .async_ :
@@ -689,6 +706,9 @@ def has_backpressure():
689706 if self .needs_exclusive ():
690707 assert (self .inst .exclusive is None )
691708 self .inst .exclusive = self
709+ else :
710+ assert (self .inst .may_block )
711+ self .inst .may_block = False
692712 self .register_thread (thread )
693713 return True
694714
@@ -738,13 +758,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
738758 def return_ (self , result ):
739759 trap_if (self .state == Task .State .RESOLVED )
740760 trap_if (self .num_borrows > 0 )
761+ if not self .ft .async_ :
762+ assert (not self .inst .may_block )
763+ self .inst .may_block = True
741764 assert (result is not None )
742765 self .on_resolve (result )
743766 self .state = Task .State .RESOLVED
744767
745768 def cancel (self ):
746769 trap_if (self .state != Task .State .CANCEL_DELIVERED )
747770 trap_if (self .num_borrows > 0 )
771+ assert (self .ft .async_ )
748772 self .on_resolve (None )
749773 self .state = Task .State .RESOLVED
750774
@@ -2078,7 +2102,7 @@ def thread_func():
20782102 else :
20792103 event = (EventCode .NONE , 0 , 0 )
20802104 case CallbackCode .WAIT :
2081- trap_if (not task .may_block () )
2105+ trap_if (not inst .may_block )
20822106 wset = inst .handles .get (si )
20832107 trap_if (not isinstance (wset , WaitableSet ))
20842108 event = wset .wait_until (lambda : not inst .exclusive , cancellable = True )
@@ -2094,6 +2118,7 @@ def thread_func():
20942118
20952119 thread = Thread (task , thread_func )
20962120 thread .resume (Cancelled .FALSE )
2121+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20972122 return task
20982123
20992124class CallbackCode (IntEnum ):
@@ -2121,7 +2146,7 @@ def call_and_trap_on_throw(callee, args):
21212146def canon_lower (opts , ft , callee : FuncInst , flat_args ):
21222147 thread = current_thread ()
21232148 trap_if (not thread .task .inst .may_leave )
2124- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2149+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21252150
21262151 subtask = Subtask ()
21272152 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2309,7 +2334,7 @@ def canon_waitable_set_new():
23092334def canon_waitable_set_wait (cancellable , mem , si , ptr ):
23102335 thread = current_thread ()
23112336 trap_if (not thread .task .inst .may_leave )
2312- trap_if (not thread .task .may_block () )
2337+ trap_if (not thread .task .inst . may_block )
23132338 wset = thread .task .inst .handles .get (si )
23142339 trap_if (not isinstance (wset , WaitableSet ))
23152340 event = wset .wait (cancellable )
@@ -2364,7 +2389,7 @@ def canon_waitable_join(wi, si):
23642389def canon_subtask_cancel (async_ , i ):
23652390 thread = current_thread ()
23662391 trap_if (not thread .task .inst .may_leave )
2367- trap_if (not thread .task .may_block () and not async_ )
2392+ trap_if (not thread .task .inst . may_block and not async_ )
23682393 subtask = thread .task .inst .handles .get (i )
23692394 trap_if (not isinstance (subtask , Subtask ))
23702395 trap_if (subtask .resolve_delivered ())
@@ -2425,7 +2450,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24252450def stream_copy (EndT , BufferT , event_code , stream_t , opts , i , ptr , n ):
24262451 thread = current_thread ()
24272452 trap_if (not thread .task .inst .may_leave )
2428- trap_if (not thread .task .may_block () and not opts .async_ )
2453+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24292454
24302455 e = thread .task .inst .handles .get (i )
24312456 trap_if (not isinstance (e , EndT ))
@@ -2480,7 +2505,7 @@ def canon_future_write(future_t, opts, i, ptr):
24802505def future_copy (EndT , BufferT , event_code , future_t , opts , i , ptr ):
24812506 thread = current_thread ()
24822507 trap_if (not thread .task .inst .may_leave )
2483- trap_if (not thread .task .may_block () and not opts .async_ )
2508+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24842509
24852510 e = thread .task .inst .handles .get (i )
24862511 trap_if (not isinstance (e , EndT ))
@@ -2533,7 +2558,7 @@ def canon_future_cancel_write(future_t, async_, i):
25332558def cancel_copy (EndT , event_code , stream_or_future_t , async_ , i ):
25342559 thread = current_thread ()
25352560 trap_if (not thread .task .inst .may_leave )
2536- trap_if (not thread .task .may_block () and not async_ )
2561+ trap_if (not thread .task .inst . may_block and not async_ )
25372562 e = thread .task .inst .handles .get (i )
25382563 trap_if (not isinstance (e , EndT ))
25392564 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2616,7 +2641,7 @@ def canon_thread_switch_to(cancellable, i):
26162641def canon_thread_suspend (cancellable ):
26172642 thread = current_thread ()
26182643 trap_if (not thread .task .inst .may_leave )
2619- trap_if (not thread .task .may_block () )
2644+ trap_if (not thread .task .inst . may_block )
26202645 cancelled = thread .suspend (cancellable )
26212646 return [cancelled ]
26222647
0 commit comments