@@ -185,21 +185,29 @@ class FutureType(ValType):
185185
186186class Store :
187187 waiting : list [Thread ]
188+ nesting_depth : int
188189
189190 def __init__ (self ):
190191 self .waiting = []
192+ self .nesting_depth = 0
191193
192194 def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193195 host_caller = Supertask ()
194196 host_caller .inst = None
195197 host_caller .supertask = caller
196- return f (host_caller , on_start , on_resolve )
198+ self .nesting_depth += 1
199+ assert (self .nesting_depth == host_caller .num_host_callers ())
200+ call = f (host_caller , on_start , on_resolve )
201+ self .nesting_depth -= 1
197202
198203 def tick (self ):
204+ assert (self .nesting_depth == 0 )
199205 random .shuffle (self .waiting )
200206 for thread in self .waiting :
201207 if thread .ready ():
208+ self .nesting_depth = 1
202209 thread .resume (Cancelled .FALSE )
210+ self .nesting_depth = 0
203211 return
204212
205213FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -211,6 +219,15 @@ class Supertask:
211219 inst : Optional [ComponentInstance ]
212220 supertask : Optional [Supertask ]
213221
222+ def num_host_callers (self ):
223+ n = 0
224+ t = self
225+ while t is not None :
226+ if t .inst is None :
227+ n += 1
228+ t = t .supertask
229+ return n
230+
214231class Call :
215232 request_cancellation : Callable [[], None ]
216233
@@ -286,6 +303,7 @@ class ComponentInstance:
286303 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287304 threads : Table [Thread ]
288305 may_leave : bool
306+ may_block : bool
289307 backpressure : int
290308 exclusive : Optional [Task ]
291309 num_waiting_to_enter : int
@@ -297,6 +315,7 @@ def __init__(self, store, parent = None):
297315 self .handles = Table ()
298316 self .threads = Table ()
299317 self .may_leave = True
318+ self .may_block = True
300319 self .backpressure = 0
301320 self .exclusive = None
302321 self .num_waiting_to_enter = 0
@@ -509,7 +528,7 @@ def resume(self, cancelled):
509528 cancelled = Cancelled .FALSE
510529
511530 def suspend (self , cancellable ) -> Cancelled :
512- assert (self .running () and self .task .may_block () )
531+ assert (self .running () and self .task .inst . may_block )
513532 if self .task .deliver_pending_cancel (cancellable ):
514533 return Cancelled .TRUE
515534 self .cancellable = cancellable
@@ -518,7 +537,7 @@ def suspend(self, cancellable) -> Cancelled:
518537 return cancelled
519538
520539 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
521- assert (self .running () and self .task .may_block () )
540+ assert (self .running () and self .task .inst . may_block )
522541 if self .task .deliver_pending_cancel (cancellable ):
523542 return Cancelled .TRUE
524543 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -529,7 +548,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
529548
530549 def yield_until (self , ready_func , cancellable ) -> Cancelled :
531550 assert (self .running ())
532- if self .task .may_block () :
551+ if self .task .inst . may_block :
533552 return self .wait_until (ready_func , cancellable )
534553 else :
535554 assert (ready_func ())
@@ -684,13 +703,12 @@ def thread_stop(self, thread):
684703 def needs_exclusive (self ):
685704 return not self .opts .async_ or self .opts .callback
686705
687- def may_block (self ):
688- return self .ft .async_ or self .state == Task .State .RESOLVED
689-
690706 def enter (self ):
691707 thread = current_thread ()
692708 assert (thread in self .threads and thread .task is self )
693709 if not self .ft .async_ :
710+ assert (self .inst .may_block )
711+ self .inst .may_block = False
694712 return True
695713 def has_backpressure ():
696714 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -741,13 +759,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
741759 def return_ (self , result ):
742760 trap_if (self .state == Task .State .RESOLVED )
743761 trap_if (self .num_borrows > 0 )
762+ if not self .ft .async_ :
763+ assert (not self .inst .may_block )
764+ self .inst .may_block = True
744765 assert (result is not None )
745766 self .on_resolve (result )
746767 self .state = Task .State .RESOLVED
747768
748769 def cancel (self ):
749770 trap_if (self .state != Task .State .CANCEL_DELIVERED )
750771 trap_if (self .num_borrows > 0 )
772+ assert (self .ft .async_ )
751773 self .on_resolve (None )
752774 self .state = Task .State .RESOLVED
753775
@@ -2084,7 +2106,7 @@ def thread_func():
20842106 else :
20852107 event = (EventCode .NONE , 0 , 0 )
20862108 case CallbackCode .WAIT :
2087- trap_if (not task .may_block () )
2109+ trap_if (not inst .may_block )
20882110 wset = inst .handles .get (si )
20892111 trap_if (not isinstance (wset , WaitableSet ))
20902112 event = wset .wait_until (lambda : not inst .exclusive , cancellable = True )
@@ -2100,6 +2122,7 @@ def thread_func():
21002122
21012123 thread = Thread (task , thread_func )
21022124 thread .resume (Cancelled .FALSE )
2125+ assert (ft .async_ or task .state == Task .State .RESOLVED )
21032126 return task
21042127
21052128class CallbackCode (IntEnum ):
@@ -2127,7 +2150,7 @@ def call_and_trap_on_throw(callee, args):
21272150def canon_lower (opts , ft , callee : FuncInst , flat_args ):
21282151 thread = current_thread ()
21292152 trap_if (not thread .task .inst .may_leave )
2130- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2153+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21312154
21322155 subtask = Subtask ()
21332156 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2313,7 +2336,7 @@ def canon_waitable_set_new():
23132336def canon_waitable_set_wait (cancellable , mem , si , ptr ):
23142337 thread = current_thread ()
23152338 trap_if (not thread .task .inst .may_leave )
2316- trap_if (not thread .task .may_block () )
2339+ trap_if (not thread .task .inst . may_block )
23172340 wset = thread .task .inst .handles .get (si )
23182341 trap_if (not isinstance (wset , WaitableSet ))
23192342 event = wset .wait (cancellable )
@@ -2368,7 +2391,7 @@ def canon_waitable_join(wi, si):
23682391def canon_subtask_cancel (async_ , i ):
23692392 thread = current_thread ()
23702393 trap_if (not thread .task .inst .may_leave )
2371- trap_if (not thread .task .may_block () and not async_ )
2394+ trap_if (not thread .task .inst . may_block and not async_ )
23722395 subtask = thread .task .inst .handles .get (i )
23732396 trap_if (not isinstance (subtask , Subtask ))
23742397 trap_if (subtask .resolve_delivered ())
@@ -2429,7 +2452,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24292452def stream_copy (EndT , BufferT , event_code , stream_t , opts , i , ptr , n ):
24302453 thread = current_thread ()
24312454 trap_if (not thread .task .inst .may_leave )
2432- trap_if (not thread .task .may_block () and not opts .async_ )
2455+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24332456
24342457 e = thread .task .inst .handles .get (i )
24352458 trap_if (not isinstance (e , EndT ))
@@ -2484,7 +2507,7 @@ def canon_future_write(future_t, opts, i, ptr):
24842507def future_copy (EndT , BufferT , event_code , future_t , opts , i , ptr ):
24852508 thread = current_thread ()
24862509 trap_if (not thread .task .inst .may_leave )
2487- trap_if (not thread .task .may_block () and not opts .async_ )
2510+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24882511
24892512 e = thread .task .inst .handles .get (i )
24902513 trap_if (not isinstance (e , EndT ))
@@ -2537,7 +2560,7 @@ def canon_future_cancel_write(future_t, async_, i):
25372560def cancel_copy (EndT , event_code , stream_or_future_t , async_ , i ):
25382561 thread = current_thread ()
25392562 trap_if (not thread .task .inst .may_leave )
2540- trap_if (not thread .task .may_block () and not async_ )
2563+ trap_if (not thread .task .inst . may_block and not async_ )
25412564 e = thread .task .inst .handles .get (i )
25422565 trap_if (not isinstance (e , EndT ))
25432566 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2619,7 +2642,7 @@ def canon_thread_switch_to(cancellable, i):
26192642def canon_thread_suspend (cancellable ):
26202643 thread = current_thread ()
26212644 trap_if (not thread .task .inst .may_leave )
2622- trap_if (not thread .task .may_block () )
2645+ trap_if (not thread .task .inst . may_block )
26232646 cancelled = thread .suspend (cancellable )
26242647 return [cancelled ]
26252648
0 commit comments