@@ -185,21 +185,29 @@ class FutureType(ValType):
185185
186186class Store :
187187 waiting : list [Thread ]
188+ nesting_depth : int
188189
189190 def __init__ (self ):
190191 self .waiting = []
192+ self .nesting_depth = 0
191193
192194 def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193195 host_caller = Supertask ()
194196 host_caller .inst = None
195197 host_caller .supertask = caller
196- return f (host_caller , on_start , on_resolve )
198+ self .nesting_depth += 1
199+ assert (self .nesting_depth == host_caller .num_host_callers ())
200+ call = f (host_caller , on_start , on_resolve )
201+ self .nesting_depth -= 1
197202
198203 def tick (self ):
204+ assert (self .nesting_depth == 0 )
199205 random .shuffle (self .waiting )
200206 for thread in self .waiting :
201207 if thread .ready ():
208+ self .nesting_depth = 1
202209 thread .resume (Cancelled .FALSE )
210+ self .nesting_depth = 0
203211 return
204212
205213FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -211,6 +219,15 @@ class Supertask:
211219 inst : Optional [ComponentInstance ]
212220 supertask : Optional [Supertask ]
213221
222+ def num_host_callers (self ):
223+ n = 0
224+ t = self
225+ while t is not None :
226+ if t .inst is None :
227+ n += 1
228+ t = t .supertask
229+ return n
230+
214231class Call :
215232 request_cancellation : Callable [[], None ]
216233
@@ -286,6 +303,7 @@ class ComponentInstance:
286303 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287304 threads : Table [Thread ]
288305 may_leave : bool
306+ may_block : bool
289307 backpressure : int
290308 exclusive : Optional [Task ]
291309 num_waiting_to_enter : int
@@ -297,6 +315,7 @@ def __init__(self, store, parent = None):
297315 self .handles = Table ()
298316 self .threads = Table ()
299317 self .may_leave = True
318+ self .may_block = True
300319 self .backpressure = 0
301320 self .exclusive = None
302321 self .num_waiting_to_enter = 0
@@ -504,7 +523,7 @@ def resume(self, cancelled):
504523 cancelled = Cancelled .FALSE
505524
506525 def suspend (self , cancellable ) -> Cancelled :
507- assert (self .running () and self .task .may_block () )
526+ assert (self .running () and self .task .inst . may_block )
508527 if self .task .deliver_pending_cancel (cancellable ):
509528 return Cancelled .TRUE
510529 self .cancellable = cancellable
@@ -513,7 +532,7 @@ def suspend(self, cancellable) -> Cancelled:
513532 return cancelled
514533
515534 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
516- assert (self .running () and self .task .may_block () )
535+ assert (self .running () and self .task .inst . may_block )
517536 if self .task .deliver_pending_cancel (cancellable ):
518537 return Cancelled .TRUE
519538 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -524,7 +543,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
524543
525544 def yield_until (self , ready_func , cancellable ) -> Cancelled :
526545 assert (self .running ())
527- if self .task .may_block () :
546+ if self .task .inst . may_block :
528547 return self .wait_until (ready_func , cancellable )
529548 else :
530549 assert (ready_func ())
@@ -679,12 +698,11 @@ def thread_stop(self, thread):
679698 def needs_exclusive (self ):
680699 return not self .opts .async_ or self .opts .callback
681700
682- def may_block (self ):
683- return self .ft .async_ or self .state == Task .State .RESOLVED
684-
685701 def enter (self , thread ):
686702 assert (thread in self .threads and thread .task is self )
687703 if not self .ft .async_ :
704+ assert (self .inst .may_block )
705+ self .inst .may_block = False
688706 return True
689707 def has_backpressure ():
690708 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -735,13 +753,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
735753 def return_ (self , result ):
736754 trap_if (self .state == Task .State .RESOLVED )
737755 trap_if (self .num_borrows > 0 )
756+ if not self .ft .async_ :
757+ assert (not self .inst .may_block )
758+ self .inst .may_block = True
738759 assert (result is not None )
739760 self .on_resolve (result )
740761 self .state = Task .State .RESOLVED
741762
742763 def cancel (self ):
743764 trap_if (self .state != Task .State .CANCEL_DELIVERED )
744765 trap_if (self .num_borrows > 0 )
766+ assert (self .ft .async_ )
745767 self .on_resolve (None )
746768 self .state = Task .State .RESOLVED
747769
@@ -2078,7 +2100,7 @@ def thread_func(thread):
20782100 else :
20792101 event = (EventCode .NONE , 0 , 0 )
20802102 case CallbackCode .WAIT :
2081- trap_if (not task .may_block () )
2103+ trap_if (not inst .may_block )
20822104 wset = inst .handles .get (si )
20832105 trap_if (not isinstance (wset , WaitableSet ))
20842106 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2094,6 +2116,7 @@ def thread_func(thread):
20942116
20952117 thread = Thread (task , thread_func )
20962118 thread .resume (Cancelled .FALSE )
2119+ assert (ft .async_ or task .state == Task .State .RESOLVED )
20972120 return task
20982121
20992122class CallbackCode (IntEnum ):
@@ -2120,7 +2143,7 @@ def call_and_trap_on_throw(callee, thread, args):
21202143
21212144def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
21222145 trap_if (not thread .task .inst .may_leave )
2123- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2146+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21242147
21252148 subtask = Subtask ()
21262149 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2300,7 +2323,7 @@ def canon_waitable_set_new(thread):
23002323
23012324def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
23022325 trap_if (not thread .task .inst .may_leave )
2303- trap_if (not thread .task .may_block () )
2326+ trap_if (not thread .task .inst . may_block )
23042327 wset = thread .task .inst .handles .get (si )
23052328 trap_if (not isinstance (wset , WaitableSet ))
23062329 event = wset .wait (thread , cancellable )
@@ -2351,7 +2374,7 @@ def canon_waitable_join(thread, wi, si):
23512374
23522375def canon_subtask_cancel (async_ , thread , i ):
23532376 trap_if (not thread .task .inst .may_leave )
2354- trap_if (not thread .task .may_block () and not async_ )
2377+ trap_if (not thread .task .inst . may_block and not async_ )
23552378 subtask = thread .task .inst .handles .get (i )
23562379 trap_if (not isinstance (subtask , Subtask ))
23572380 trap_if (subtask .resolve_delivered ())
@@ -2408,7 +2431,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24082431
24092432def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
24102433 trap_if (not thread .task .inst .may_leave )
2411- trap_if (not thread .task .may_block () and not opts .async_ )
2434+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24122435
24132436 e = thread .task .inst .handles .get (i )
24142437 trap_if (not isinstance (e , EndT ))
@@ -2462,7 +2485,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24622485
24632486def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24642487 trap_if (not thread .task .inst .may_leave )
2465- trap_if (not thread .task .may_block () and not opts .async_ )
2488+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24662489
24672490 e = thread .task .inst .handles .get (i )
24682491 trap_if (not isinstance (e , EndT ))
@@ -2514,7 +2537,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25142537
25152538def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
25162539 trap_if (not thread .task .inst .may_leave )
2517- trap_if (not thread .task .may_block () and not async_ )
2540+ trap_if (not thread .task .inst . may_block and not async_ )
25182541 e = thread .task .inst .handles .get (i )
25192542 trap_if (not isinstance (e , EndT ))
25202543 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2591,7 +2614,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25912614
25922615def canon_thread_suspend (cancellable , thread ):
25932616 trap_if (not thread .task .inst .may_leave )
2594- trap_if (not thread .task .may_block () )
2617+ trap_if (not thread .task .inst . may_block )
25952618 cancelled = thread .suspend (cancellable )
25962619 return [cancelled ]
25972620
0 commit comments