@@ -185,21 +185,29 @@ class FutureType(ValType):
185185
186186class Store :
187187 waiting : list [Thread ]
188+ nesting_depth : int
188189
189190 def __init__ (self ):
190191 self .waiting = []
192+ self .nesting_depth = 0
191193
192194 def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193195 host_caller = Supertask ()
194196 host_caller .inst = None
195197 host_caller .supertask = caller
196- return f (host_caller , on_start , on_resolve )
198+ self .nesting_depth += 1
199+ assert (self .nesting_depth == host_caller .num_host_callers ())
200+ call = f (host_caller , on_start , on_resolve )
201+ self .nesting_depth -= 1
197202
198203 def tick (self ):
204+ assert (self .nesting_depth == 0 )
199205 random .shuffle (self .waiting )
200206 for thread in self .waiting :
201207 if thread .ready ():
208+ self .nesting_depth = 1
202209 thread .resume (Cancelled .FALSE )
210+ self .nesting_depth = 0
203211 return
204212
205213FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -211,6 +219,15 @@ class Supertask:
211219 inst : Optional [ComponentInstance ]
212220 supertask : Optional [Supertask ]
213221
222+ def num_host_callers (self ):
223+ n = 0
224+ t = self
225+ while t is not None :
226+ if t .inst is None :
227+ n += 1
228+ t = t .supertask
229+ return n
230+
214231class Call :
215232 request_cancellation : Callable [[], None ]
216233
@@ -286,6 +303,7 @@ class ComponentInstance:
286303 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287304 threads : Table [Thread ]
288305 may_leave : bool
306+ may_block : bool
289307 backpressure : int
290308 exclusive : Optional [Task ]
291309 num_waiting_to_enter : int
@@ -297,6 +315,7 @@ def __init__(self, store, parent = None):
297315 self .handles = Table ()
298316 self .threads = Table ()
299317 self .may_leave = True
318+ self .may_block = True
300319 self .backpressure = 0
301320 self .exclusive = None
302321 self .num_waiting_to_enter = 0
@@ -508,7 +527,7 @@ def resume(self, cancelled):
508527 cancelled = Cancelled .FALSE
509528
510529 def suspend (self , cancellable ) -> Cancelled :
511- assert (self .running () and self .task .may_block () )
530+ assert (self .running () and self .task .inst . may_block )
512531 if self .task .deliver_pending_cancel (cancellable ):
513532 return Cancelled .TRUE
514533 self .cancellable = cancellable
@@ -517,7 +536,7 @@ def suspend(self, cancellable) -> Cancelled:
517536 return cancelled
518537
519538 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
520- assert (self .running () and self .task .may_block () )
539+ assert (self .running () and self .task .inst . may_block )
521540 if self .task .deliver_pending_cancel (cancellable ):
522541 return Cancelled .TRUE
523542 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -528,7 +547,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
528547
529548 def yield_until (self , ready_func , cancellable ) -> Cancelled :
530549 assert (self .running ())
531- if self .task .may_block () :
550+ if self .task .inst . may_block :
532551 return self .wait_until (ready_func , cancellable )
533552 else :
534553 assert (ready_func ())
@@ -683,12 +702,11 @@ def thread_stop(self, thread):
683702 def needs_exclusive (self ):
684703 return not self .opts .async_ or self .opts .callback
685704
686- def may_block (self ):
687- return self .ft .async_ or self .state == Task .State .RESOLVED
688-
689705 def enter (self , thread ):
690706 assert (thread in self .threads and thread .task is self )
691707 if not self .ft .async_ :
708+ assert (self .inst .may_block )
709+ self .inst .may_block = False
692710 return True
693711 def has_backpressure ():
694712 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -739,13 +757,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
739757 def return_ (self , result ):
740758 trap_if (self .state == Task .State .RESOLVED )
741759 trap_if (self .num_borrows > 0 )
760+ if not self .ft .async_ :
761+ assert (not self .inst .may_block )
762+ self .inst .may_block = True
742763 assert (result is not None )
743764 self .on_resolve (result )
744765 self .state = Task .State .RESOLVED
745766
746767 def cancel (self ):
747768 trap_if (self .state != Task .State .CANCEL_DELIVERED )
748769 trap_if (self .num_borrows > 0 )
770+ assert (self .ft .async_ )
749771 self .on_resolve (None )
750772 self .state = Task .State .RESOLVED
751773
@@ -2082,7 +2104,7 @@ def thread_func(thread):
20822104 else :
20832105 event = (EventCode .NONE , 0 , 0 )
20842106 case CallbackCode .WAIT :
2085- trap_if (not task .may_block () )
2107+ trap_if (not inst .may_block )
20862108 wset = inst .handles .get (si )
20872109 trap_if (not isinstance (wset , WaitableSet ))
20882110 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2098,6 +2120,7 @@ def thread_func(thread):
20982120
20992121 thread = Thread (task , thread_func )
21002122 thread .resume (Cancelled .FALSE )
2123+ assert (ft .async_ or task .state == Task .State .RESOLVED )
21012124 return task
21022125
21032126class CallbackCode (IntEnum ):
@@ -2124,7 +2147,7 @@ def call_and_trap_on_throw(callee, thread, args):
21242147
21252148def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
21262149 trap_if (not thread .task .inst .may_leave )
2127- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2150+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21282151
21292152 subtask = Subtask ()
21302153 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2304,7 +2327,7 @@ def canon_waitable_set_new(thread):
23042327
23052328def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
23062329 trap_if (not thread .task .inst .may_leave )
2307- trap_if (not thread .task .may_block () )
2330+ trap_if (not thread .task .inst . may_block )
23082331 wset = thread .task .inst .handles .get (si )
23092332 trap_if (not isinstance (wset , WaitableSet ))
23102333 event = wset .wait (thread , cancellable )
@@ -2355,7 +2378,7 @@ def canon_waitable_join(thread, wi, si):
23552378
23562379def canon_subtask_cancel (async_ , thread , i ):
23572380 trap_if (not thread .task .inst .may_leave )
2358- trap_if (not thread .task .may_block () and not async_ )
2381+ trap_if (not thread .task .inst . may_block and not async_ )
23592382 subtask = thread .task .inst .handles .get (i )
23602383 trap_if (not isinstance (subtask , Subtask ))
23612384 trap_if (subtask .resolve_delivered ())
@@ -2412,7 +2435,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24122435
24132436def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
24142437 trap_if (not thread .task .inst .may_leave )
2415- trap_if (not thread .task .may_block () and not opts .async_ )
2438+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24162439
24172440 e = thread .task .inst .handles .get (i )
24182441 trap_if (not isinstance (e , EndT ))
@@ -2466,7 +2489,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24662489
24672490def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24682491 trap_if (not thread .task .inst .may_leave )
2469- trap_if (not thread .task .may_block () and not opts .async_ )
2492+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24702493
24712494 e = thread .task .inst .handles .get (i )
24722495 trap_if (not isinstance (e , EndT ))
@@ -2518,7 +2541,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25182541
25192542def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
25202543 trap_if (not thread .task .inst .may_leave )
2521- trap_if (not thread .task .may_block () and not async_ )
2544+ trap_if (not thread .task .inst . may_block and not async_ )
25222545 e = thread .task .inst .handles .get (i )
25232546 trap_if (not isinstance (e , EndT ))
25242547 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2595,7 +2618,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25952618
25962619def canon_thread_suspend (cancellable , thread ):
25972620 trap_if (not thread .task .inst .may_leave )
2598- trap_if (not thread .task .may_block () )
2621+ trap_if (not thread .task .inst . may_block )
25992622 cancelled = thread .suspend (cancellable )
26002623 return [cancelled ]
26012624
0 commit comments