@@ -185,21 +185,29 @@ class FutureType(ValType):
185185
186186class Store :
187187 waiting : list [Thread ]
188+ nesting_depth : int
188189
189190 def __init__ (self ):
190191 self .waiting = []
192+ self .nesting_depth = 0
191193
192194 def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193195 host_caller = Supertask ()
194196 host_caller .inst = None
195197 host_caller .supertask = caller
196- return f (host_caller , on_start , on_resolve )
198+ self .nesting_depth += 1
199+ assert (self .nesting_depth == host_caller .num_host_callers ())
200+ call = f (host_caller , on_start , on_resolve )
201+ self .nesting_depth -= 1
197202
198203 def tick (self ):
204+ assert (self .nesting_depth == 0 )
199205 random .shuffle (self .waiting )
200206 for thread in self .waiting :
201207 if thread .ready ():
208+ self .nesting_depth = 1
202209 thread .resume (Cancelled .FALSE )
210+ self .nesting_depth = 0
203211 return
204212
205213FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -211,6 +219,15 @@ class Supertask:
211219 inst : Optional [ComponentInstance ]
212220 supertask : Optional [Supertask ]
213221
222+ def num_host_callers (self ):
223+ n = 0
224+ t = self
225+ while t is not None :
226+ if t .inst is None :
227+ n += 1
228+ t = t .supertask
229+ return n
230+
214231class Call :
215232 request_cancellation : Callable [[], None ]
216233
@@ -285,7 +302,9 @@ class ComponentInstance:
285302 parent : Optional [ComponentInstance ]
286303 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287304 threads : Table [Thread ]
305+ may_enter : bool
288306 may_leave : bool
307+ may_block : bool
289308 backpressure : int
290309 exclusive : Optional [Task ]
291310 num_waiting_to_enter : int
@@ -296,7 +315,9 @@ def __init__(self, store, parent = None):
296315 self .parent = parent
297316 self .handles = Table ()
298317 self .threads = Table ()
318+ self .may_enter = True
299319 self .may_leave = True
320+ self .may_block = True
300321 self .backpressure = 0
301322 self .exclusive = None
302323 self .num_waiting_to_enter = 0
@@ -489,7 +510,8 @@ def resume_later(self):
489510
490511 def resume (self , cancelled ):
491512 assert (self .cancellable or not cancelled )
492- assert (not self .running ())
513+ assert (not self .running () and self .task .inst .may_enter )
514+ self .task .inst .may_enter = False
493515 if self .waiting ():
494516 assert (cancelled or self .ready ())
495517 self .ready_func = None
@@ -506,9 +528,11 @@ def resume(self, cancelled):
506528 break
507529 thread = switch_to_thread
508530 cancelled = Cancelled .FALSE
531+ assert (not self .task .inst .may_enter )
532+ self .task .inst .may_enter = True
509533
510534 def suspend (self , cancellable ) -> Cancelled :
511- assert (self .running () and self .task .may_block () )
535+ assert (self .running () and self .task .inst . may_block )
512536 if self .task .deliver_pending_cancel (cancellable ):
513537 return Cancelled .TRUE
514538 self .cancellable = cancellable
@@ -517,7 +541,7 @@ def suspend(self, cancellable) -> Cancelled:
517541 return cancelled
518542
519543 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
520- assert (self .running () and self .task .may_block () )
544+ assert (self .running () and self .task .inst . may_block )
521545 if self .task .deliver_pending_cancel (cancellable ):
522546 return Cancelled .TRUE
523547 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -528,7 +552,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
528552
529553 def yield_until (self , ready_func , cancellable ) -> Cancelled :
530554 assert (self .running ())
531- if self .task .may_block () :
555+ if self .task .inst . may_block :
532556 return self .wait_until (ready_func , cancellable )
533557 else :
534558 assert (ready_func ())
@@ -683,12 +707,11 @@ def thread_stop(self, thread):
683707 def needs_exclusive (self ):
684708 return not self .opts .async_ or self .opts .callback
685709
686- def may_block (self ):
687- return self .ft .async_ or self .state == Task .State .RESOLVED
688-
689710 def enter (self , thread ):
690711 assert (thread in self .threads and thread .task is self )
691712 if not self .ft .async_ :
713+ assert (self .inst .may_block )
714+ self .inst .may_block = False
692715 return True
693716 def has_backpressure ():
694717 return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
@@ -739,13 +762,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
739762 def return_ (self , result ):
740763 trap_if (self .state == Task .State .RESOLVED )
741764 trap_if (self .num_borrows > 0 )
765+ if not self .ft .async_ :
766+ assert (not self .inst .may_block )
767+ self .inst .may_block = True
742768 assert (result is not None )
743769 self .on_resolve (result )
744770 self .state = Task .State .RESOLVED
745771
746772 def cancel (self ):
747773 trap_if (self .state != Task .State .CANCEL_DELIVERED )
748774 trap_if (self .num_borrows > 0 )
775+ assert (self .ft .async_ )
749776 self .on_resolve (None )
750777 self .state = Task .State .RESOLVED
751778
@@ -2038,6 +2065,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20382065
20392066def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve ) -> Call :
20402067 trap_if (call_might_be_recursive (caller , inst ))
2068+ assert (inst .may_enter ) # is it actually guaranteed by `call_might_be_recursive`?
2069+
20412070 task = Task (opts , inst , ft , caller , on_resolve )
20422071 def thread_func (thread ):
20432072 if not task .enter (thread ):
@@ -2082,7 +2111,7 @@ def thread_func(thread):
20822111 else :
20832112 event = (EventCode .NONE , 0 , 0 )
20842113 case CallbackCode .WAIT :
2085- trap_if (not task .may_block () )
2114+ trap_if (not inst .may_block )
20862115 wset = inst .handles .get (si )
20872116 trap_if (not isinstance (wset , WaitableSet ))
20882117 event = wset .wait_until (lambda : not inst .exclusive , thread , cancellable = True )
@@ -2098,6 +2127,7 @@ def thread_func(thread):
20982127
20992128 thread = Thread (task , thread_func )
21002129 thread .resume (Cancelled .FALSE )
2130+ assert (ft .async_ or task .state == Task .State .RESOLVED )
21012131 return task
21022132
21032133class CallbackCode (IntEnum ):
@@ -2124,7 +2154,7 @@ def call_and_trap_on_throw(callee, thread, args):
21242154
21252155def canon_lower (opts , ft , callee : FuncInst , thread , flat_args ):
21262156 trap_if (not thread .task .inst .may_leave )
2127- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
2157+ trap_if (not thread .task .inst . may_block and ft .async_ and not opts .async_ )
21282158
21292159 subtask = Subtask ()
21302160 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2304,7 +2334,7 @@ def canon_waitable_set_new(thread):
23042334
23052335def canon_waitable_set_wait (cancellable , mem , thread , si , ptr ):
23062336 trap_if (not thread .task .inst .may_leave )
2307- trap_if (not thread .task .may_block () )
2337+ trap_if (not thread .task .inst . may_block )
23082338 wset = thread .task .inst .handles .get (si )
23092339 trap_if (not isinstance (wset , WaitableSet ))
23102340 event = wset .wait (thread , cancellable )
@@ -2355,7 +2385,7 @@ def canon_waitable_join(thread, wi, si):
23552385
23562386def canon_subtask_cancel (async_ , thread , i ):
23572387 trap_if (not thread .task .inst .may_leave )
2358- trap_if (not thread .task .may_block () and not async_ )
2388+ trap_if (not thread .task .inst . may_block and not async_ )
23592389 subtask = thread .task .inst .handles .get (i )
23602390 trap_if (not isinstance (subtask , Subtask ))
23612391 trap_if (subtask .resolve_delivered ())
@@ -2412,7 +2442,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24122442
24132443def stream_copy (EndT , BufferT , event_code , stream_t , opts , thread , i , ptr , n ):
24142444 trap_if (not thread .task .inst .may_leave )
2415- trap_if (not thread .task .may_block () and not opts .async_ )
2445+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24162446
24172447 e = thread .task .inst .handles .get (i )
24182448 trap_if (not isinstance (e , EndT ))
@@ -2466,7 +2496,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24662496
24672497def future_copy (EndT , BufferT , event_code , future_t , opts , thread , i , ptr ):
24682498 trap_if (not thread .task .inst .may_leave )
2469- trap_if (not thread .task .may_block () and not opts .async_ )
2499+ trap_if (not thread .task .inst . may_block and not opts .async_ )
24702500
24712501 e = thread .task .inst .handles .get (i )
24722502 trap_if (not isinstance (e , EndT ))
@@ -2518,7 +2548,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25182548
25192549def cancel_copy (EndT , event_code , stream_or_future_t , async_ , thread , i ):
25202550 trap_if (not thread .task .inst .may_leave )
2521- trap_if (not thread .task .may_block () and not async_ )
2551+ trap_if (not thread .task .inst . may_block and not async_ )
25222552 e = thread .task .inst .handles .get (i )
25232553 trap_if (not isinstance (e , EndT ))
25242554 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2595,7 +2625,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25952625
25962626def canon_thread_suspend (cancellable , thread ):
25972627 trap_if (not thread .task .inst .may_leave )
2598- trap_if (not thread .task .may_block () )
2628+ trap_if (not thread .task .inst . may_block )
25992629 cancelled = thread .suspend (cancellable )
26002630 return [cancelled ]
26012631
0 commit comments