@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221221class ComponentInstance :
222222 table : Table
223223 may_leave : bool
224- backpressure : bool
225- calling_sync_export : bool
226- calling_sync_import : bool
227- pending_tasks : list [tuple [Task , asyncio .Future ]]
228- starting_pending_task : bool
229- async_waiting_tasks : asyncio .Condition
224+ no_backpressure : asyncio .Event
225+ num_backpressure_waiters : int
226+ lock : asyncio .Lock
230227
231228 def __init__ (self ):
232229 self .table = Table ()
233230 self .may_leave = True
234- self .backpressure = False
235- self .calling_sync_export = False
236- self .calling_sync_import = False
237- self .pending_tasks = []
238- self .starting_pending_task = False
239- self .async_waiting_tasks = asyncio .Condition (scheduler )
231+ self .no_backpressure = asyncio .Event ()
232+ self .no_backpressure .set ()
233+ self .num_backpressure_waiters = 0
234+ self .lock = asyncio .Lock ()
240235
241236#### Table State
242237
@@ -464,7 +459,7 @@ class Cancelled(IntEnum):
464459
465460OnStart = Callable [[], list [any ]]
466461OnResolve = Callable [[Optional [list [any ]]], None ]
467- OnBlock = Callable [[Awaitable ], Awaitable [Cancelled ]]
462+ OnBlock = Callable [[asyncio . Future ], Awaitable [Cancelled ]]
468463
469464class Task :
470465 class State (Enum ):
@@ -497,67 +492,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497492 async def enter (self ):
498493 assert (scheduler .locked ())
499494 self .trap_if_on_the_stack (self .inst )
500- if not self .may_enter (self ) or self .inst .pending_tasks :
501- f = asyncio .Future ()
502- self .inst .pending_tasks .append ((self , f ))
503- if await self .on_block (f ) == Cancelled .TRUE :
504- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505- self .inst .pending_tasks .pop (i )
506- self .on_resolve (None )
507- return Cancelled .FALSE
508- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509- self .inst .starting_pending_task = False
510- if self .opts .sync :
511- self .inst .calling_sync_export = True
512- return True
495+ if self .opts .sync or self .opts .callback :
496+ if self .inst .lock .locked ():
497+ acquired = asyncio .create_task (self .inst .lock .acquire ())
498+ cancelled = await self .wait_on (acquired , cancellable = True , for_callback = False )
499+ if cancelled :
500+ if acquired .done ():
501+ self .inst .lock .release ()
502+ else :
503+ acquired .cancel ()
504+ return Cancelled .TRUE
505+ else :
506+ await self .inst .lock .acquire ()
507+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
508+ while True :
509+ self .inst .num_backpressure_waiters += 1
510+ maybe_go = self .inst .no_backpressure .wait ()
511+ cancelled = await self .wait_on (maybe_go , cancellable = True , for_callback = False )
512+ self .inst .num_backpressure_waiters -= 1
513+ if cancelled :
514+ return Cancelled .TRUE
515+ if self .inst .no_backpressure .is_set ():
516+ break
517+ return Cancelled .FALSE
513518
514519 def trap_if_on_the_stack (self , inst ):
515520 c = self .supertask
516521 while c is not None :
517522 trap_if (c .inst is inst )
518523 c = c .supertask
519524
520- def may_enter (self , pending_task ):
521- return not self .inst .backpressure and \
522- not self .inst .calling_sync_import and \
523- not (self .inst .calling_sync_export and pending_task .opts .sync )
524-
525- def maybe_start_pending_task (self ):
526- if self .inst .starting_pending_task :
527- return
528- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529- if self .may_enter (pending_task ):
530- self .inst .pending_tasks .pop (i )
531- self .inst .starting_pending_task = True
532- pending_future .set_result (None )
533- return
525+ async def wait_on (self , awaitable , cancellable = False , for_callback = False ) -> Cancelled :
526+ f = asyncio .ensure_future (awaitable )
527+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
528+ return Cancelled .FALSE
534529
535- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536- if sync :
537- assert (not self .inst .calling_sync_import )
538- self .inst .calling_sync_import = True
539- else :
540- self .maybe_start_pending_task ()
530+ if for_callback :
531+ self .inst .lock .release ()
541532
542- awaitable = asyncio .ensure_future (awaitable )
543- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544- cancelled = Cancelled .FALSE
545- else :
546- cancelled = await self .on_block (awaitable )
547- if cancelled and not cancellable :
548- assert (self .state == Task .State .INITIAL )
549- self .state = Task .State .PENDING_CANCEL
550- cancelled = await self .on_block (awaitable )
551- assert (not cancelled )
533+ cancelled = await self .on_block (f )
534+ if cancelled and not cancellable :
535+ assert (await self .on_block (f ) == Cancelled .FALSE )
552536
553- if sync :
554- self .inst .calling_sync_import = False
555- self .inst .async_waiting_tasks .notify_all ()
556- else :
557- while self .inst .calling_sync_import :
558- await self .inst .async_waiting_tasks .wait ()
537+ if for_callback :
538+ acquired = asyncio .create_task (self .inst .lock .acquire ())
539+ cancelled |= await self .on_block (acquired )
540+ if cancelled :
541+ assert (self .on_block (acquired ) == Cancelled .FALSE )
559542
560- return cancelled
543+ if cancelled :
544+ assert (self .state == Task .State .INITIAL )
545+ if not cancellable :
546+ self .state = Task .State .PENDING_CANCEL
547+ return Cancelled .FALSE
548+ else :
549+ self .state = Task .State .CANCEL_DELIVERED
550+ return Cancelled .TRUE
551+ else :
552+ return Cancelled .FALSE
561553
562554 async def call_sync (self , callee , on_start , on_return ):
563555 async def sync_on_block (awaitable ):
@@ -567,42 +559,36 @@ async def sync_on_block(awaitable):
567559 assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568560 return Cancelled .FALSE
569561
570- assert (not self .inst .calling_sync_import )
571- self .inst .calling_sync_import = True
572562 await callee (self , on_start , on_return , sync_on_block )
573- self .inst .calling_sync_import = False
574- self .inst .async_waiting_tasks .notify_all ()
575563
576- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577- if self .state == Task .State .PENDING_CANCEL :
564+ async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
565+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578566 self .state = Task .State .CANCEL_DELIVERED
579567 return (EventCode .TASK_CANCELLED , 0 , 0 )
580568 else :
581569 waitable_set .num_waiting += 1
582570 e = None
583571 while not e :
584572 maybe_event = waitable_set .maybe_has_pending_event .wait ()
585- if await self .wait_on (maybe_event , sync , cancellable = True ):
586- assert (self .state == Task .State .INITIAL )
587- self .state = Task .State .CANCEL_DELIVERED
573+ if await self .wait_on (maybe_event , cancellable , for_callback ) == Cancelled .TRUE :
588574 return (EventCode .TASK_CANCELLED , 0 , 0 )
589575 e = waitable_set .poll ()
590576 waitable_set .num_waiting -= 1
591577 return e
592578
593- async def yield_ (self , sync ) -> EventTuple :
594- if self .state == Task .State .PENDING_CANCEL :
579+ async def yield_ (self , cancellable , for_callback ) -> EventTuple :
580+ if self .state == Task .State .PENDING_CANCEL and cancellable :
595581 self .state = Task .State .CANCEL_DELIVERED
596582 return (EventCode .TASK_CANCELLED , 0 , 0 )
597- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598- assert (self .state == Task .State .INITIAL )
599- self .state = Task .State .CANCEL_DELIVERED
583+ elif await self .wait_on (asyncio .sleep (0 ), cancellable , for_callback ) == Cancelled .TRUE :
600584 return (EventCode .TASK_CANCELLED , 0 , 0 )
601585 else :
602586 return (EventCode .NONE , 0 , 0 )
603587
604- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605- event_code ,_ ,_ = e = await self .yield_ (sync )
588+ async def poll_for_event (self , waitable_set , cancellable , for_callback ) -> Optional [EventTuple ]:
589+ waitable_set .num_waiting += 1
590+ event_code ,_ ,_ = e = await self .yield_ (cancellable , for_callback )
591+ waitable_set .num_waiting -= 1
606592 if event_code == EventCode .TASK_CANCELLED :
607593 return e
608594 elif (e := waitable_set .poll ()):
@@ -624,13 +610,10 @@ def cancel(self):
624610 self .state = Task .State .RESOLVED
625611
626612 def exit (self ):
627- assert (scheduler .locked ())
628613 trap_if (self .state != Task .State .RESOLVED )
629614 assert (self .num_borrows == 0 )
630- if self .opts .sync :
631- assert (self .inst .calling_sync_export )
632- self .inst .calling_sync_export = False
633- self .maybe_start_pending_task ()
615+ if self .opts .sync or self .opts .callback :
616+ self .inst .lock .release ()
634617
635618#### Subtask State
636619
@@ -1932,7 +1915,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321915
19331916async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
19341917 task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935- if not await task .enter ():
1918+ if await task .enter () == Cancelled .TRUE :
1919+ task .cancel ()
1920+ task .exit ()
19361921 return
19371922
19381923 cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1952,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671952 task .exit ()
19681953 return
19691954 case CallbackCode .YIELD :
1970- e = await task .yield_ (sync = False )
1955+ e = await task .yield_ (cancellable = True , for_callback = True )
19711956 case CallbackCode .WAIT :
19721957 s = task .inst .table .get (si )
19731958 trap_if (not isinstance (s , WaitableSet ))
1974- e = await task .wait_for_event (s , sync = False )
1959+ e = await task .wait_for_event (s , cancellable = True , for_callback = True )
19751960 case CallbackCode .POLL :
19761961 s = task .inst .table .get (si )
19771962 trap_if (not isinstance (s , WaitableSet ))
1978- e = await task .poll_for_event (s , sync = False )
1963+ e = await task .poll_for_event (s , cancellable = True , for_callback = True )
19791964 event_code , p1 , p2 = e
19801965 [packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
19811966
@@ -2114,8 +2099,11 @@ async def canon_context_set(t, i, task, v):
21142099### 🔀 `canon backpressure.set`
21152100
21162101async def canon_backpressure_set (task , flat_args ):
2117- trap_if (task .opts .sync )
2118- task .inst .backpressure = bool (flat_args [0 ])
2102+ assert (len (flat_args ) == 1 )
2103+ if flat_args [0 ] == 0 :
2104+ task .inst .no_backpressure .set ()
2105+ else :
2106+ task .inst .no_backpressure .clear ()
21192107 return []
21202108
21212109### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
21402128
21412129### 🔀 `canon yield`
21422130
2143- async def canon_yield (sync , task ):
2131+ async def canon_yield (cancellable , task ):
21442132 trap_if (not task .inst .may_leave )
2145- event_code ,_ ,_ = await task .yield_ (sync )
2133+ event_code ,_ ,_ = await task .yield_ (cancellable , for_callback = False )
21462134 match event_code :
21472135 case EventCode .NONE :
21482136 return [0 ]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
21572145
21582146### 🔀 `canon waitable-set.wait`
21592147
2160- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2148+ async def canon_waitable_set_wait (cancellable , mem , task , si , ptr ):
21612149 trap_if (not task .inst .may_leave )
21622150 s = task .inst .table .get (si )
21632151 trap_if (not isinstance (s , WaitableSet ))
2164- e = await task .wait_for_event (s , sync )
2152+ e = await task .wait_for_event (s , cancellable , for_callback = False )
21652153 return unpack_event (mem , task , ptr , e )
21662154
21672155def unpack_event (mem , task , ptr , e : EventTuple ):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732161
21742162### 🔀 `canon waitable-set.poll`
21752163
2176- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2164+ async def canon_waitable_set_poll (cancellable , mem , task , si , ptr ):
21772165 trap_if (not task .inst .may_leave )
21782166 s = task .inst .table .get (si )
21792167 trap_if (not isinstance (s , WaitableSet ))
2180- e = await task .poll_for_event (s , sync )
2168+ e = await task .poll_for_event (s , cancellable , for_callback = False )
21812169 return unpack_event (mem , task , ptr , e )
21822170
21832171### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
22202208 while not subtask .resolved ():
22212209 if subtask .has_pending_event ():
22222210 _ = subtask .get_event ()
2223- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2211+ await task .wait_on (subtask .wait_for_pending_event ())
22242212 else :
22252213 if not subtask .resolved ():
22262214 return [BLOCKED ]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
22962284 e .copy (task .inst , buffer , on_copy , on_copy_done )
22972285
22982286 if opts .sync and not e .has_pending_event ():
2299- await task .wait_on (e .wait_for_pending_event (), sync = True )
2287+ await task .wait_on (e .wait_for_pending_event ())
23002288
23012289 if e .has_pending_event ():
23022290 code ,index ,payload = e .get_event ()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
23422330 e .copy (task .inst , buffer , on_copy_done )
23432331
23442332 if opts .sync and not e .has_pending_event ():
2345- await task .wait_on (e .wait_for_pending_event (), sync = True )
2333+ await task .wait_on (e .wait_for_pending_event ())
23462334
23472335 if e .has_pending_event ():
23482336 code ,index ,payload = e .get_event ()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752363 e .shared .cancel ()
23762364 if not e .has_pending_event ():
23772365 if sync :
2378- await task .wait_on (e .wait_for_pending_event (), sync = True )
2366+ await task .wait_on (e .wait_for_pending_event ())
23792367 else :
23802368 return [BLOCKED ]
23812369 code ,index ,payload = e .get_event ()
0 commit comments