@@ -354,22 +354,6 @@ def write(self, vs):
354354 assert (all (v == () for v in vs ))
355355 self .progress += len (vs )
356356
357- #### Context-Local Storage
358-
359- class ContextLocalStorage :
360- LENGTH = 1
361- array : list [int ]
362-
363- def __init__ (self ):
364- self .array = [0 ] * ContextLocalStorage .LENGTH
365-
366- def set (self , i , v ):
367- assert (types_match_values (['i32' ], [v ]))
368- self .array [i ] = v
369-
370- def get (self , i ):
371- return self .array [i ]
372-
373357#### Waitable State
374358
375359class EventCode (IntEnum ):
@@ -380,6 +364,7 @@ class EventCode(IntEnum):
380364 FUTURE_READ = 4
381365 FUTURE_WRITE = 5
382366 TASK_CANCELLED = 6
367+ THREAD_RESUMED = 7
383368
384369EventTuple = tuple [EventCode , int , int ]
385370
@@ -475,7 +460,6 @@ class State(Enum):
475460 on_resolve : OnResolve
476461 thread : Thread
477462 num_borrows : int
478- context : ContextLocalStorage
479463
480464 def __init__ (self , opts , inst , ft , supertask , on_resolve , thread_func ):
481465 self .state = Task .State .INITIAL
@@ -486,7 +470,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
486470 self .on_resolve = on_resolve
487471 self .thread = Thread (self , thread_func )
488472 self .num_borrows = 0
489- self .context = ContextLocalStorage ()
490473
491474 def trap_if_on_the_stack (self , inst ):
492475 c = self .supertask
@@ -526,13 +509,16 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
526509 return Cancelled .FALSE
527510
528511 if unlock and (self .opts .sync or self .opts .callback ):
512+ # assert(thread is thread.task.main_thread)
529513 self .inst .lock .release ()
530514
515+ # TODO: maybe pass 'cancellable' into 'suspend'
531516 cancelled = await thread .suspend (f )
532517 if cancelled and not cancellable :
533518 assert (await thread .suspend (f ) == Cancelled .FALSE )
534519
535520 if unlock and (self .opts .sync or self .opts .callback ):
521+ # assert(thread is thread.task.main_thread)
536522 acquired = asyncio .create_task (self .inst .lock .acquire ())
537523 if await thread .suspend (acquired ) == Cancelled .TRUE :
538524 assert (thread .suspend (acquired ) == Cancelled .FALSE )
@@ -559,6 +545,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
559545 e = None
560546 while not e :
561547 maybe_event = waitable_set .maybe_has_pending_event .wait ()
548+ # TODO: return EventCode.THREAD_RESUME
562549 if await self .block_on (thread , maybe_event , cancellable , unlock ) == Cancelled .TRUE :
563550 return (EventCode .TASK_CANCELLED , 0 , 0 )
564551 e = waitable_set .poll ()
@@ -570,6 +557,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
570557 if self .state == Task .State .PENDING_CANCEL and cancellable :
571558 self .state = Task .State .CANCEL_DELIVERED
572559 return (EventCode .TASK_CANCELLED , 0 , 0 )
560+ # TODO: return EventCode.THREAD_RESUME
573561 elif await self .block_on (thread , asyncio .sleep (0 ), cancellable , unlock ) == Cancelled .TRUE :
574562 return (EventCode .TASK_CANCELLED , 0 , 0 )
575563 else :
@@ -601,7 +589,7 @@ def cancel(self):
601589 self .state = Task .State .RESOLVED
602590
603591 def exit (self ):
604- trap_if (self .state != Task .State .RESOLVED )
592+ trap_if (self .state != Task .State .RESOLVED ) # TODO: move this to empty-threads case
605593 assert (self .num_borrows == 0 )
606594 if self .opts .sync or self .opts .callback :
607595 self .inst .lock .release ()
@@ -882,21 +870,28 @@ def drop(self):
882870
883871class Thread :
884872 task : Task
873+ index : int
885874 future : Optional [asyncio .Future ]
886875 on_resume : Optional [asyncio .Future ]
887876 on_suspend_or_exit : Optional [asyncio .Future ]
877+ context : list [int ]
878+
879+ CONTEXT_LENGTH = 1
888880
889881 def __init__ (self , task , thread_func ):
890882 self .task = task
883+ self .index = task .inst .table .add (self )
891884 self .future = None
892885 self .on_resume = asyncio .Future ()
893886 self .on_suspend_or_exit = None
887+ self .context = [0 ] * Thread .CONTEXT_LENGTH
894888 async def thread_start ():
895889 assert (await self .on_resume == Cancelled .FALSE )
896890 self .on_resume = None
897891 await thread_func (task , self )
898892 self .on_suspend_or_exit .set_result (None )
899893 self .task .thread = None
894+ self .task .inst .table .remove (self .index )
900895 asyncio .create_task (thread_start ())
901896
902897 async def resume (self , cancelled = Cancelled .FALSE ):
@@ -923,6 +918,30 @@ async def suspend(self, future) -> Cancelled:
923918 self .on_resume = None
924919 return cancelled
925920
921+ async def switch (self , other : Thread ) -> Cancelled :
922+ assert (not self .future and not other .future )
923+ assert (self .on_suspend_or_exit and not other .on_suspend_or_exit )
924+ other .on_suspend_or_exit = self .on_suspend_or_exit
925+ self .on_suspend_or_exit = None
926+ other .on_resume .set_result (Cancelled .FALSE )
927+ assert (not self .on_resume )
928+ self .on_resume = asyncio .Future ()
929+ cancelled = await self .on_resume
930+ self .on_resume = None
931+ return cancelled
932+
933+ def yield_to (self , other : Thread ) -> Cancelled :
934+ # deterministically switch to other, but leave this thread unblocked
935+ TODO
936+
937+ def block (self ) -> Cancelled :
938+ # perform just the first half of switch
939+ TODO
940+
941+ def unblock (self , other : Thread ):
942+ # unblock other, but deterministically keep running here
943+ TODO
944+
926945#### Store State / Embedding API
927946
928947class Store :
@@ -2095,25 +2114,76 @@ async def canon_resource_rep(rt, thread, i):
20952114 trap_if (h .rt is not rt )
20962115 return [h .rep ]
20972116
2117+ ### 🧵 `canon thread.index`
2118+
2119+ async def canon_thread_index (shared , thread ):
2120+ assert (not shared )
2121+ return [thread .index ]
2122+
2123+ ### 🧵 `canon thread.new_indirect`
2124+
2125+ async def canon_thread_new_indirect (ft , ftbl , thread , i , c ):
2126+ trap_if (not thread .task .inst .may_leave )
2127+ f = thread .task .inst .ftbl .get (i )
2128+ trap_if (f .type != ft )
2129+ thread = Thread (thread .task , f (c ))
2130+ return [thread .index ]
2131+
2132+ ### 🧵 `canon thread.switch`
2133+
2134+ async def canon_thread_switch (thread , i ):
2135+ trap_if (not thread .task .inst .may_leave )
2136+ other = thread .task .inst .table .get (i )
2137+ trap_if (not isinstance (other , Thread ))
2138+ cancelled = await thread .switch (other )
2139+ return [ 1 if cancelled else 0 ]
2140+
2141+ ### 🧵 `canon thread.yield-to`
2142+
2143+ async def canon_thread_yield_to (thread , i ):
2144+ trap_if (not thread .task .inst .may_leave )
2145+ other = thread .task .inst .table .get (i )
2146+ trap_if (not isinstance (other , Thread ))
2147+ other .yield_to (other )
2148+ return []
2149+
2150+ ### 🧵 `canon thread.block`
2151+
2152+ async def canon_thread_block (thread , i ):
2153+ trap_if (not thread .task .inst .may_leave )
2154+ other = thread .task .inst .table .get (i )
2155+ trap_if (not isinstance (other , Thread ))
2156+ cancelled = await thread .block ()
2157+ return [ 1 if cancelled else 0 ]
2158+
2159+ ### 🧵 `canon thread.unblock`
2160+
2161+ async def canon_thread_unblock (thread , i ):
2162+ trap_if (not thread .task .inst .may_leave )
2163+ other = thread .task .inst .table .get (i )
2164+ trap_if (not isinstance (other , Thread ))
2165+ thread .unblock ()
2166+ return []
2167+
20982168### 🔀 `canon context.get`
20992169
21002170async def canon_context_get (t , i , thread ):
21012171 assert (t == 'i32' )
2102- assert (i < ContextLocalStorage . LENGTH )
2103- return [thread .task . context . get ( i ) ]
2172+ assert (i < Thread . CONTEXT_LENGTH )
2173+ return [thread .context [ i ] ]
21042174
21052175### 🔀 `canon context.set`
21062176
21072177async def canon_context_set (t , i , thread , v ):
21082178 assert (t == 'i32' )
2109- assert (i < ContextLocalStorage . LENGTH )
2110- thread .task . context . set ( i , v )
2179+ assert (i < Thread . CONTEXT_LENGTH )
2180+ thread .context [ i ] = v
21112181 return []
21122182
21132183### 🔀 `canon backpressure.set`
21142184
21152185async def canon_backpressure_set (thread , flat_args ):
2116- trap_if (thread .task .opts .sync )
2186+ # TODO: remove trap_if(thread.task.opts.sync)
21172187 assert (len (flat_args ) == 1 )
21182188 if flat_args [0 ] == 0 :
21192189 thread .task .inst .no_backpressure .set ()
0 commit comments