@@ -188,6 +188,7 @@ class ComponentInstance:
188188 parent : Optional [ComponentInstance ]
189189 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
190190 threads : Table [Thread ]
191+ may_enter : bool
191192 may_leave : bool
192193 backpressure : int
193194 num_waiting_to_enter : int
@@ -199,41 +200,36 @@ def __init__(self, store, parent = None):
199200 self .parent = parent
200201 self .handles = Table ()
201202 self .threads = Table ()
203+ self .may_enter = True
202204 self .may_leave = True
203205 self .backpressure = 0
204206 self .num_waiting_to_enter = 0
205207 self .exclusive = None
206208
207- def reflexive_ancestors (self ) -> set [ComponentInstance ]:
209+ def enter_from (self , caller : Optional [ComponentInstance ]):
210+ for inst in self .entering (caller ):
211+ trap_if (not inst .may_enter )
212+ inst .may_enter = False
213+
214+ def leave_to (self , caller : Optional [ComponentInstance ]):
215+ for inst in self .entering (caller ):
216+ assert (not inst .may_enter )
217+ inst .may_enter = True
218+
219+ def entering (self , caller : Optional [ComponentInstance ]):
220+ if caller :
221+ return self .self_and_ancestors () - caller .self_and_ancestors ()
222+ else :
223+ return self .self_and_ancestors ()
224+
225+ def self_and_ancestors (self ) -> set [ComponentInstance ]:
208226 s = set ()
209227 inst = self
210228 while inst is not None :
211229 s .add (inst )
212230 inst = inst .parent
213231 return s
214232
215- def is_reflexive_ancestor_of (self , other ):
216- while other is not None :
217- if self is other :
218- return True
219- other = other .parent
220- return False
221-
222- class Supertask :
223- inst : Optional [ComponentInstance ]
224- supertask : Optional [Supertask ]
225-
226- def call_might_be_recursive (caller : Supertask , callee_inst : ComponentInstance ):
227- if caller .inst is None :
228- while caller is not None :
229- if caller .inst and caller .inst .reflexive_ancestors () & callee_inst .reflexive_ancestors ():
230- return True
231- caller = caller .supertask
232- return False
233- else :
234- return (caller .inst .is_reflexive_ancestor_of (callee_inst ) or
235- callee_inst .is_reflexive_ancestor_of (caller .inst ))
236-
237233## Concurrency
238234
239235### Stack Switching
@@ -415,9 +411,9 @@ def yield_to(self, cancellable, other: Thread) -> Cancelled:
415411OnStart = Callable [[], list [any ]]
416412OnResolve = Callable [[Optional [list [any ]]], None ]
417413OnCancel = Callable [[], None ]
418- FuncInst = Callable [[Supertask , OnStart , OnResolve ], OnCancel ]
414+ FuncInst = Callable [[OnStart , OnResolve , Optional [ ComponentInstance ] ], OnCancel ]
419415
420- class Task ( Supertask ) :
416+ class Task :
421417 class State (Enum ):
422418 INITIAL = 1
423419 STARTED = 2
@@ -428,19 +424,17 @@ class State(Enum):
428424 ft : FuncType
429425 opts : CanonicalOptions
430426 inst : ComponentInstance
431- supertask : Supertask
432427 on_start : OnStart
433428 on_resolve : OnResolve
434429 state : State
435430 num_borrows : int
436431 waiting_to_enter : Optional [Thread ]
437432 threads : list [Thread ]
438433
439- def __init__ (self , ft , opts , inst , supertask , on_start , on_resolve ):
434+ def __init__ (self , ft , opts , inst , on_start , on_resolve ):
440435 self .ft = ft
441436 self .opts = opts
442437 self .inst = inst
443- self .supertask = supertask
444438 self .on_start = on_start
445439 self .on_resolve = on_resolve
446440 self .state = Task .State .INITIAL
@@ -541,34 +535,42 @@ def cancel(self):
541535
542536class Store :
543537 waiting : list [Thread ]
538+ nesting_depth : int
544539
545540 def __init__ (self ):
546541 self .waiting = []
542+ self .nesting_depth = 0
547543
548- def invoke (self , f : FuncInst , caller : Optional [ Supertask ], on_start , on_resolve ) -> OnCancel :
549- host_caller = Supertask ()
550- host_caller . inst = None
551- host_caller . supertask = caller
552- return f ( host_caller , on_start , on_resolve )
544+ def invoke (self , f : FuncInst , on_start : OnStart , on_resolve : OnResolve ) -> OnCancel :
545+ self . nesting_depth += 1
546+ on_cancel = f ( on_start , on_resolve , caller = None )
547+ self . nesting_depth -= 1
548+ return on_cancel
553549
554550 def lift (self , f : CoreFuncInst , ft : FuncType , opts : CanonicalOptions , inst : ComponentInstance ) -> FuncInst :
555- def func_inst (caller : Supertask , on_start : OnStart , on_resolve : OnResolve ) -> OnCancel :
556- trap_if (call_might_be_recursive (caller , inst ))
557- return canon_lift (f , ft , opts , inst , caller , on_start , on_resolve )
551+ def func_inst (on_start : OnStart , on_resolve : OnResolve , caller : Optional [ComponentInstance ]) -> OnCancel :
552+ inst .enter_from (caller )
553+ on_cancel = canon_lift (f , ft , opts , inst , on_start , on_resolve )
554+ inst .leave_to (caller )
555+ return on_cancel
558556 return func_inst
559557
560558 def lower (self , f : FuncInst , ft : FuncType , opts : CanonicalOptions , inst : ComponentInstance ) -> CoreFuncInst :
561559 def core_func_inst (args : list [CoreValType ]) -> list [CoreValType ]:
562- assert (current_instance () is inst )
560+ assert (current_instance () is inst and self . nesting_depth > 0 )
563561 return canon_lower (f , ft , opts , args )
564562 return core_func_inst
565563
566564 def tick (self ):
567- random .shuffle (self .waiting )
568- for thread in self .waiting :
569- if thread .ready ():
570- thread .resume ()
571- return
565+ assert (self .nesting_depth == 0 )
566+ candidates = { t for t in self .waiting if t .ready () }
567+ if candidates :
568+ thread = random .choice (list (candidates ))
569+ self .nesting_depth += 1
570+ thread .task .inst .enter_from (None )
571+ thread .resume ()
572+ thread .task .inst .leave_to (None )
573+ self .nesting_depth -= 1
572574
573575## Lifting and Lowering Context
574576
@@ -2072,7 +2074,7 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20722074
20732075### `canon lift`
20742076
2075- def canon_lift (callee , ft , opts , inst , caller , on_start , on_resolve ) -> OnCancel :
2077+ def canon_lift (callee , ft , opts , inst , on_start , on_resolve ) -> OnCancel :
20762078 def thread_func ():
20772079 if not task .enter_implicit_thread ():
20782080 return
@@ -2128,7 +2130,7 @@ def thread_func():
21282130 task .exit_implicit_thread ()
21292131 return
21302132
2131- task = Task (ft , opts , inst , caller , on_start , on_resolve )
2133+ task = Task (ft , opts , inst , on_start , on_resolve )
21322134 thread = Thread (task , thread_func )
21332135 thread .resume ()
21342136 return task .request_cancellation
@@ -2198,7 +2200,7 @@ def on_resolve(result):
21982200 nonlocal flat_results
21992201 flat_results = lower_flat_values (cx , max_flat_results , result , ft .result_type (), flat_args )
22002202
2201- subtask .on_cancel = callee (thread . task , on_start , on_resolve )
2203+ subtask .on_cancel = callee (on_start , on_resolve , caller = thread . task . inst )
22022204 assert (ft .async_ or subtask .state == Subtask .State .RETURNED )
22032205
22042206 if not opts .async_ :
@@ -2244,17 +2246,13 @@ def canon_resource_drop(rt, i):
22442246 trap_if (h .num_lends != 0 )
22452247 if h .own :
22462248 assert (h .borrow_scope is None )
2247- if inst is rt .impl :
2248- if rt .dtor :
2249- rt .dtor (h .rep )
2250- else :
2251- caller_opts = CanonicalOptions (async_ = False )
2252- callee_opts = CanonicalOptions (async_ = rt .dtor_async , callback = rt .dtor_callback )
2253- ft = FuncType ([U32Type ()],[], async_ = False )
2254- dtor = rt .dtor or (lambda rep : [])
2255- callee = inst .store .lift (dtor , ft , callee_opts , rt .impl )
2256- caller = inst .store .lower (callee , ft , caller_opts , inst )
2257- caller ([h .rep ])
2249+ caller_opts = CanonicalOptions (async_ = False )
2250+ callee_opts = CanonicalOptions (async_ = rt .dtor_async , callback = rt .dtor_callback )
2251+ ft = FuncType ([U32Type ()], [], async_ = False )
2252+ dtor = rt .dtor or (lambda rep : [])
2253+ callee = inst .store .lift (dtor , ft , callee_opts , rt .impl )
2254+ caller = inst .store .lower (callee , ft , caller_opts , inst )
2255+ caller ([h .rep ])
22582256 else :
22592257 h .borrow_scope .num_borrows -= 1
22602258 return []
0 commit comments