@@ -215,21 +215,28 @@ def tick(self):
215215
216216class Thread :
217217 task : Task
218+ index : int
218219 future : Optional [asyncio .Future ]
219220 on_resume : Optional [asyncio .Future ]
220221 on_suspend_or_exit : Optional [asyncio .Future ]
222+ context : list [int ]
223+
224+ CONTEXT_LENGTH = 2
221225
222226 def __init__ (self , task , thread_func ):
223227 self .task = task
228+ self .index = task .inst .table .add (self )
224229 self .future = None
225230 self .on_resume = asyncio .Future ()
226231 self .on_suspend_or_exit = None
232+ self .context = [0 ] * Thread .CONTEXT_LENGTH
227233 async def thread_start ():
228234 await self .on_resume
229235 self .on_resume = None
230236 await thread_func (task , self )
231237 self .on_suspend_or_exit .set_result (None )
232238 self .task .thread = None
239+ self .task .inst .table .remove (self .index )
233240 asyncio .create_task (thread_start ())
234241
235242 async def resume (self ):
@@ -254,6 +261,29 @@ async def suspend(self, future):
254261 await self .on_resume
255262 self .on_resume = None
256263
264+ async def switch (self , other : Thread ):
265+ assert (not self .future and not other .future )
266+ assert (self .on_suspend_or_exit and not other .on_suspend_or_exit )
267+ other .on_suspend_or_exit = self .on_suspend_or_exit
268+ self .on_suspend_or_exit = None
269+ other .on_resume .set_result (Cancelled .FALSE )
270+ assert (not self .on_resume )
271+ self .on_resume = asyncio .Future ()
272+ await self .on_resume
273+ self .on_resume = None
274+
275+ def yield_to (self , other : Thread ):
276+ # deterministically switch to other, but leave this thread unblocked
277+ TODO
278+
279+ def block (self ):
280+ # perform just the first half of switch
281+ TODO
282+
283+ def unblock (self , other : Thread ):
284+ # unblock other, but deterministically keep running here
285+ TODO
286+
257287
258288### Lifting and Lowering Context
259289
@@ -431,22 +461,6 @@ def write(self, vs):
431461 assert (all (v == () for v in vs ))
432462 self .progress += len (vs )
433463
434- #### Context-Local Storage
435-
436- class ContextLocalStorage :
437- LENGTH = 1
438- array : list [int ]
439-
440- def __init__ (self ):
441- self .array = [0 ] * ContextLocalStorage .LENGTH
442-
443- def set (self , i , v ):
444- assert (types_match_values (['i32' ], [v ]))
445- self .array [i ] = v
446-
447- def get (self , i ):
448- return self .array [i ]
449-
450464#### Waitable State
451465
452466class EventCode (IntEnum ):
@@ -457,6 +471,7 @@ class EventCode(IntEnum):
457471 FUTURE_READ = 4
458472 FUTURE_WRITE = 5
459473 TASK_CANCELLED = 6
474+ THREAD_RESUMED = 7
460475
461476EventTuple = tuple [EventCode , int , int ]
462477
@@ -546,7 +561,6 @@ class State(Enum):
546561 thread : Thread
547562 cancellable : bool
548563 num_borrows : int
549- context : ContextLocalStorage
550564
551565 def __init__ (self , opts , inst , ft , supertask , on_resolve , thread_func ):
552566 self .state = Task .State .INITIAL
@@ -558,7 +572,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
558572 self .thread = Thread (self , thread_func )
559573 self .cancellable = False
560574 self .num_borrows = 0
561- self .context = ContextLocalStorage ()
562575
563576 def trap_if_on_the_stack (self , inst ):
564577 c = self .supertask
@@ -638,6 +651,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
638651 waitable_set .num_waiting += 1
639652 e = None
640653 while not e :
654+ # TODO: somehow get a THREAD_RESUME event...
641655 maybe_event = waitable_set .maybe_has_pending_event .wait ()
642656 await self .block_on (thread , maybe_event , cancellable , unlock )
643657 if self .deliver_cancel ():
@@ -650,6 +664,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
650664 assert (self .thread is thread and self is thread .task )
651665 if cancellable and self .deliver_cancel ():
652666 return (EventCode .TASK_CANCELLED , 0 , 0 )
667+ # TODO: somehow get a THREAD_RESUME event...
653668 await self .block_on (thread , asyncio .sleep (0 ), cancellable , unlock )
654669 if cancellable and self .deliver_cancel ():
655670 return (EventCode .TASK_CANCELLED , 0 , 0 )
@@ -681,7 +696,7 @@ def cancel(self):
681696 self .state = Task .State .RESOLVED
682697
683698 def exit (self ):
684- trap_if (self .state != Task .State .RESOLVED )
699+ trap_if (self .state != Task .State .RESOLVED ) # TODO: move this to empty-threads case
685700 assert (self .num_borrows == 0 )
686701 if self .needs_lock ():
687702 self .inst .lock .release ()
@@ -2102,25 +2117,76 @@ async def canon_resource_rep(rt, thread, i):
21022117 trap_if (h .rt is not rt )
21032118 return [h .rep ]
21042119
2120+ ### 🧵 `canon thread.index`
2121+
2122+ async def canon_thread_index (shared , thread ):
2123+ assert (not shared )
2124+ return [thread .index ]
2125+
2126+ ### 🧵 `canon thread.new_indirect`
2127+
2128+ async def canon_thread_new_indirect (ft , ftbl , thread , i , c ):
2129+ trap_if (not thread .task .inst .may_leave )
2130+ f = thread .task .inst .ftbl .get (i )
2131+ trap_if (f .type != ft )
2132+ thread = Thread (thread .task , f (c ))
2133+ return [thread .index ]
2134+
2135+ ### 🧵 `canon thread.switch`
2136+
2137+ async def canon_thread_switch (thread , i ):
2138+ trap_if (not thread .task .inst .may_leave )
2139+ other = thread .task .inst .table .get (i )
2140+ trap_if (not isinstance (other , Thread ))
2141+ cancelled = await thread .switch (other )
2142+ return [ 1 if cancelled else 0 ]
2143+
2144+ ### 🧵 `canon thread.yield-to`
2145+
2146+ async def canon_thread_yield_to (thread , i ):
2147+ trap_if (not thread .task .inst .may_leave )
2148+ other = thread .task .inst .table .get (i )
2149+ trap_if (not isinstance (other , Thread ))
2150+ other .yield_to (other )
2151+ return []
2152+
2153+ ### 🧵 `canon thread.block`
2154+
2155+ async def canon_thread_block (thread , i ):
2156+ trap_if (not thread .task .inst .may_leave )
2157+ other = thread .task .inst .table .get (i )
2158+ trap_if (not isinstance (other , Thread ))
2159+ cancelled = await thread .block ()
2160+ return [ 1 if cancelled else 0 ]
2161+
2162+ ### 🧵 `canon thread.unblock`
2163+
2164+ async def canon_thread_unblock (thread , i ):
2165+ trap_if (not thread .task .inst .may_leave )
2166+ other = thread .task .inst .table .get (i )
2167+ trap_if (not isinstance (other , Thread ))
2168+ thread .unblock ()
2169+ return []
2170+
21052171### 🔀 `canon context.get`
21062172
21072173async def canon_context_get (t , i , thread ):
21082174 assert (t == 'i32' )
2109- assert (i < ContextLocalStorage . LENGTH )
2110- return [thread .task . context . get ( i ) ]
2175+ assert (i < Thread . CONTEXT_LENGTH )
2176+ return [thread .context [ i ] ]
21112177
21122178### 🔀 `canon context.set`
21132179
21142180async def canon_context_set (t , i , thread , v ):
21152181 assert (t == 'i32' )
2116- assert (i < ContextLocalStorage . LENGTH )
2117- thread .task . context . set ( i , v )
2182+ assert (i < Thread . CONTEXT_LENGTH )
2183+ thread .context [ i ] = v
21182184 return []
21192185
21202186### 🔀 `canon backpressure.set`
21212187
21222188async def canon_backpressure_set (thread , flat_args ):
2123- trap_if (thread .task .opts .sync )
2189+ # TODO: remove trap_if(thread.task.opts.sync)
21242190 assert (len (flat_args ) == 1 )
21252191 if flat_args [0 ] == 0 :
21262192 thread .task .inst .no_backpressure .set ()
0 commit comments