@@ -215,21 +215,28 @@ def tick(self):
215215
216216class Thread :
217217 task : Task
218+ index : int
218219 future : Optional [asyncio .Future ]
219220 on_resume : Optional [asyncio .Future ]
220221 on_suspend_or_exit : Optional [asyncio .Future ]
222+ context : list [int ]
223+
224+ CONTEXT_LENGTH = 2
221225
222226 def __init__ (self , task , thread_func ):
223227 self .task = task
228+ self .index = task .inst .table .add (self )
224229 self .future = None
225230 self .on_resume = asyncio .Future ()
226231 self .on_suspend_or_exit = None
232+ self .context = [0 ] * Thread .CONTEXT_LENGTH
227233 async def thread_start ():
228234 await self .on_resume
229235 self .on_resume = None
230236 await thread_func (task , self )
231237 self .on_suspend_or_exit .set_result (None )
232- self .task .thread = None
238+ self .task .inst .table .remove (self .index )
239+ self .task .thread_return (self )
233240 asyncio .create_task (thread_start ())
234241
235242 async def resume (self ):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254261 await self .on_resume
255262 self .on_resume = None
256263
264+ async def switch (self , other : Thread ):
265+ assert (not self .future and not other .future )
266+ assert (self .on_suspend_or_exit and not other .on_suspend_or_exit )
267+ other .on_suspend_or_exit = self .on_suspend_or_exit
268+ self .on_suspend_or_exit = None
269+ other .on_resume .set_result (Cancelled .FALSE )
270+ assert (not self .on_resume )
271+ self .on_resume = asyncio .Future ()
272+ await self .on_resume
273+ self .on_resume = None
274+
275+ def yield_to (self , other : Thread ):
276+ # deterministically switch to other, but leave this thread unblocked
277+ TODO
278+
279+ def block (self ):
280+ # perform just the first half of switch
281+ TODO
282+
283+ def unblock (self , other : Thread ):
284+ # unblock other, but deterministically keep running here
285+ TODO
286+
257287
258288### Lifting and Lowering Context
259289
@@ -431,22 +461,6 @@ def write(self, vs):
431461 assert (all (v == () for v in vs ))
432462 self .progress += len (vs )
433463
434- #### Context-Local Storage
435-
436- class ContextLocalStorage :
437- LENGTH = 1
438- array : list [int ]
439-
440- def __init__ (self ):
441- self .array = [0 ] * ContextLocalStorage .LENGTH
442-
443- def set (self , i , v ):
444- assert (types_match_values (['i32' ], [v ]))
445- self .array [i ] = v
446-
447- def get (self , i ):
448- return self .array [i ]
449-
450464#### Waitable State
451465
452466class EventCode (IntEnum ):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457471 FUTURE_READ = 4
458472 FUTURE_WRITE = 5
459473 TASK_CANCELLED = 6
474+ THREAD_RESUMED = 7
460475
461476EventTuple = tuple [EventCode , int , int ]
462477
@@ -543,10 +558,9 @@ class State(Enum):
543558 ft : FuncType
544559 supertask : Optional [Task ]
545560 on_resolve : OnResolve
546- thread : Thread
561+ threads : list [ Thread ]
547562 cancellable : bool
548563 num_borrows : int
549- context : ContextLocalStorage
550564
551565 def __init__ (self , opts , inst , ft , supertask , on_resolve , thread_func ):
552566 self .state = Task .State .INITIAL
@@ -555,10 +569,13 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
555569 self .ft = ft
556570 self .supertask = supertask
557571 self .on_resolve = on_resolve
558- self .thread = Thread (self , thread_func )
572+ self .threads = [ Thread (self , thread_func )]
559573 self .cancellable = False
560574 self .num_borrows = 0
561- self .context = ContextLocalStorage ()
575+
576+ async def start (self ):
577+ assert (len (self .threads ) == 1 )
578+ await self .threads [0 ].resume ()
562579
563580 def trap_if_on_the_stack (self , inst ):
564581 c = self .supertask
@@ -569,8 +586,10 @@ def trap_if_on_the_stack(self, inst):
569586 async def request_cancellation (self ):
570587 assert (self .state == Task .State .INITIAL )
571588 self .state = Task .State .PENDING_CANCEL
589+ # TODO: move cancellability to the Thread and then search
590+ # for a cancellable one here...
572591 if self .cancellable :
573- await self .thread .resume ()
592+ await self .threads [ 0 ] .resume ()
574593
575594 def deliver_cancel (self ) -> bool :
576595 if self .state == Task .State .PENDING_CANCEL :
@@ -583,7 +602,7 @@ def needs_lock(self):
583602 return self .opts .sync or self .opts .callback
584603
585604 async def enter (self , thread ):
586- assert (thread is self .thread and thread .task is self )
605+ assert (thread in self .threads and thread .task is self )
587606 if (self .inst .no_backpressure .is_set () and
588607 self .inst .num_pending == 0 and
589608 (not self .needs_lock () or not self .inst .lock .locked ())):
@@ -599,6 +618,7 @@ async def enter(self, thread):
599618 self .inst .num_pending -= 1
600619 if self .deliver_cancel ():
601620 self .on_resolve (None )
621+ self .state = Task .State .RESOLVED
602622 return False
603623 if not self .inst .no_backpressure .is_set ():
604624 continue
@@ -611,6 +631,7 @@ async def enter(self, thread):
611631 else :
612632 acquired .cancel ()
613633 self .on_resolve (None )
634+ self .state = Task .State .RESOLVED
614635 return False
615636 if not self .inst .no_backpressure .is_set ():
616637 self .inst .lock .release ()
@@ -619,7 +640,7 @@ async def enter(self, thread):
619640 return True
620641
621642 async def block_on (self , thread , awaitable , cancellable = False , unlock = False ) -> Cancelled :
622- assert (thread is self .thread and thread .task is self )
643+ assert (thread in self .threads and thread .task is self )
623644 f = asyncio .ensure_future (awaitable )
624645 if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
625646 return
@@ -633,12 +654,13 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
633654 await thread .suspend (asyncio .create_task (self .inst .lock .acquire ()))
634655
635656 async def wait_for_event (self , thread , waitable_set , cancellable , unlock ) -> EventTuple :
636- assert (thread is self .thread and thread .task is self )
657+ assert (thread in self .threads and thread .task is self )
637658 if cancellable and self .deliver_cancel ():
638659 return (EventCode .TASK_CANCELLED , 0 , 0 )
639660 waitable_set .num_waiting += 1
640661 e = None
641662 while not e :
663+ # TODO: somehow get a THREAD_RESUME event...
642664 maybe_event = waitable_set .maybe_has_pending_event .wait ()
643665 await self .block_on (thread , maybe_event , cancellable , unlock )
644666 if self .deliver_cancel ():
@@ -648,16 +670,17 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
648670 return e
649671
650672 async def yield_ (self , thread , cancellable , unlock ) -> EventTuple :
651- assert (thread is self .thread and thread .task is self )
673+ assert (thread in self .threads and thread .task is self )
652674 if cancellable and self .deliver_cancel ():
653675 return (EventCode .TASK_CANCELLED , 0 , 0 )
676+ # TODO: somehow get a THREAD_RESUME event...
654677 await self .block_on (thread , asyncio .sleep (0 ), cancellable , unlock )
655678 if cancellable and self .deliver_cancel ():
656679 return (EventCode .TASK_CANCELLED , 0 , 0 )
657680 return (EventCode .NONE , 0 , 0 )
658681
659682 async def poll_for_event (self , thread , waitable_set , cancellable , unlock ) -> Optional [EventTuple ]:
660- assert (thread is self .thread and thread .task is self )
683+ assert (thread in self .threads and thread .task is self )
661684 waitable_set .num_waiting += 1
662685 event_code ,_ ,_ = e = await self .yield_ (thread , cancellable , unlock )
663686 waitable_set .num_waiting -= 1
@@ -682,11 +705,16 @@ def cancel(self):
682705 self .state = Task .State .RESOLVED
683706
684707 def exit (self ):
685- trap_if (self .state != Task .State .RESOLVED )
686- assert (self .num_borrows == 0 )
687708 if self .needs_lock ():
688709 self .inst .lock .release ()
689710
711+ def thread_return (self , thread ):
712+ assert (thread in self .threads and thread .task is self )
713+ self .threads .remove (thread )
714+ if len (self .threads ) == 0 :
715+ trap_if (self .state != Task .State .RESOLVED )
716+ assert (self .num_borrows == 0 )
717+
690718#### Subtask State
691719
692720class Subtask (Waitable ):
@@ -1965,7 +1993,7 @@ async def thread_func(task, thread):
19651993 [packed ] = await call_and_trap_on_throw (opts .callback , thread , [event_code , p1 , p2 ])
19661994
19671995 task = Task (opts , inst , ft , caller , on_resolve , thread_func )
1968- await task .thread . resume ()
1996+ await task .start ()
19691997 return task
19701998
19711999class CallbackCode (IntEnum ):
@@ -2103,25 +2131,76 @@ async def canon_resource_rep(rt, thread, i):
21032131 trap_if (h .rt is not rt )
21042132 return [h .rep ]
21052133
2134+ ### 🧵 `canon thread.index`
2135+
2136+ async def canon_thread_index (shared , thread ):
2137+ assert (not shared )
2138+ return [thread .index ]
2139+
2140+ ### 🧵 `canon thread.new_indirect`
2141+
2142+ async def canon_thread_new_indirect (ft , ftbl , thread , i , c ):
2143+ trap_if (not thread .task .inst .may_leave )
2144+ f = thread .task .inst .ftbl .get (i )
2145+ trap_if (f .type != ft )
2146+ thread = Thread (thread .task , f (c ))
2147+ return [thread .index ]
2148+
2149+ ### 🧵 `canon thread.switch`
2150+
2151+ async def canon_thread_switch (thread , i ):
2152+ trap_if (not thread .task .inst .may_leave )
2153+ other = thread .task .inst .table .get (i )
2154+ trap_if (not isinstance (other , Thread ))
2155+ cancelled = await thread .switch (other )
2156+ return [ 1 if cancelled else 0 ]
2157+
2158+ ### 🧵 `canon thread.yield-to`
2159+
2160+ async def canon_thread_yield_to (thread , i ):
2161+ trap_if (not thread .task .inst .may_leave )
2162+ other = thread .task .inst .table .get (i )
2163+ trap_if (not isinstance (other , Thread ))
2164+ other .yield_to (other )
2165+ return []
2166+
2167+ ### 🧵 `canon thread.block`
2168+
2169+ async def canon_thread_block (thread , i ):
2170+ trap_if (not thread .task .inst .may_leave )
2171+ other = thread .task .inst .table .get (i )
2172+ trap_if (not isinstance (other , Thread ))
2173+ cancelled = await thread .block ()
2174+ return [ 1 if cancelled else 0 ]
2175+
2176+ ### 🧵 `canon thread.unblock`
2177+
2178+ async def canon_thread_unblock (thread , i ):
2179+ trap_if (not thread .task .inst .may_leave )
2180+ other = thread .task .inst .table .get (i )
2181+ trap_if (not isinstance (other , Thread ))
2182+ thread .unblock ()
2183+ return []
2184+
21062185### 🔀 `canon context.get`
21072186
21082187async def canon_context_get (t , i , thread ):
21092188 assert (t == 'i32' )
2110- assert (i < ContextLocalStorage . LENGTH )
2111- return [thread .task . context . get ( i ) ]
2189+ assert (i < Thread . CONTEXT_LENGTH )
2190+ return [thread .context [ i ] ]
21122191
21132192### 🔀 `canon context.set`
21142193
21152194async def canon_context_set (t , i , thread , v ):
21162195 assert (t == 'i32' )
2117- assert (i < ContextLocalStorage . LENGTH )
2118- thread .task . context . set ( i , v )
2196+ assert (i < Thread . CONTEXT_LENGTH )
2197+ thread .context [ i ] = v
21192198 return []
21202199
21212200### 🔀 `canon backpressure.set`
21222201
21232202async def canon_backpressure_set (thread , flat_args ):
2124- trap_if (thread .task .opts .sync )
2203+ # TODO: remove trap_if(thread.task.opts.sync)
21252204 assert (len (flat_args ) == 1 )
21262205 if flat_args [0 ] == 0 :
21272206 thread .task .inst .no_backpressure .set ()
0 commit comments