@@ -213,13 +213,19 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : int
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : threading .Lock
219221 stack : threading .Thread
220222
223+ CONTEXT_LENGTH = 2
224+
221225 def __init__ (self , task , thread_func ):
222226 self .task = task
227+ self .index = task .inst .table .add (self )
228+ self .context = [0 ] * Thread .CONTEXT_LENGTH
223229 self .ready_func = None
224230 self .run_lock = threading .Lock ()
225231 self .run_lock .acquire ()
@@ -230,6 +236,7 @@ def thread_stack_base():
230236 assert (self .resume_lock .locked ())
231237 thread_func (self )
232238 self .task .thread_stop (self )
239+ self .task .inst .table .remove (self .index )
233240 self .resume_lock .release ()
234241 self .stack = threading .Thread (target = thread_stack_base )
235242 self .stack .start ()
@@ -255,6 +262,22 @@ def suspend_until(self, ready_func):
255262 self .resume_lock .release ()
256263 self .run_lock .acquire ()
257264
265+ async def switch_to (self , cancellable , other : Thread ):
266+ # deterministically switch to other, leave this blocked
267+ TODO
268+
269+ def yield_to (self , cancellable , other : Thread ):
270+ # deterministically switch to other, but leave this thread unblocked
271+ TODO
272+
273+ def block (self , cancellable ):
274+ # perform just the first half of switch
275+ TODO
276+
277+ def unblock (self , other : Thread ):
278+ # unblock other, but deterministically keep running here
279+ TODO
280+
258281
259282### Lifting and Lowering Context
260283
@@ -431,22 +454,6 @@ def write(self, vs):
431454 assert (all (v == () for v in vs ))
432455 self .progress += len (vs )
433456
434- #### Context-Local Storage
435-
436- class ContextLocalStorage :
437- LENGTH = 1
438- array : list [int ]
439-
440- def __init__ (self ):
441- self .array = [0 ] * ContextLocalStorage .LENGTH
442-
443- def set (self , i , v ):
444- assert (types_match_values (['i32' ], [v ]))
445- self .array [i ] = v
446-
447- def get (self , i ):
448- return self .array [i ]
449-
450457#### Waitable State
451458
452459class EventCode (IntEnum ):
@@ -457,6 +464,7 @@ class EventCode(IntEnum):
457464 FUTURE_READ = 4
458465 FUTURE_WRITE = 5
459466 TASK_CANCELLED = 6
467+ THREAD_RESUMED = 7
460468
461469EventTuple = tuple [EventCode , int , int ]
462470
@@ -529,11 +537,10 @@ class State(Enum):
529537 ft : FuncType
530538 supertask : Optional [Task ]
531539 on_resolve : OnResolve
532- thread : Optional [Thread ]
540+ threads : list [Thread ]
533541 cancellable : bool
534542 waiting_for_callback : bool
535543 num_borrows : int
536- context : ContextLocalStorage
537544
538545 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
539546 self .state = Task .State .INITIAL
@@ -542,11 +549,10 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
542549 self .ft = ft
543550 self .supertask = supertask
544551 self .on_resolve = on_resolve
545- self .thread = None
552+ self .threads = []
546553 self .cancellable = False
547554 self .waiting_for_callback = False
548555 self .num_borrows = 0
549- self .context = ContextLocalStorage ()
550556
551557 def trap_if_on_the_stack (self , inst ):
552558 c = self .supertask
@@ -558,7 +564,7 @@ def needs_exclusive(self):
558564 return self .opts .sync or self .opts .callback
559565
560566 def enter (self , thread ):
561- assert (thread is self .thread and thread .task is self )
567+ assert (thread in self .threads and thread .task is self )
562568 def has_backpressure ():
563569 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
564570 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -585,11 +591,13 @@ def deliver_cancel(self) -> bool:
585591 def request_cancellation (self ):
586592 assert (self .state == Task .State .INITIAL )
587593 self .state = Task .State .PENDING_CANCEL
594+ # TODO: move cancellability to the Thread and then search
595+ # for a cancellable one here...
588596 if self .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
589- self .thread .resume ()
597+ self .threads [ 0 ] .resume ()
590598
591599 def wait_until (self , ready_func , thread , cancellable , for_callback ):
592- assert (thread is self .thread and thread .task is self )
600+ assert (thread in self .threads and thread .task is self )
593601 if cancellable and self .deliver_cancel ():
594602 return True
595603 assert (not self .cancellable )
@@ -612,13 +620,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
612620 wset .num_waiting += 1
613621 cancelled = self .wait_until (wset .has_pending_event , thread , cancellable , for_callback )
614622 wset .num_waiting -= 1
623+ # TODO: somehow get a THREAD_RESUME event...
615624 if cancelled :
616625 return (EventCode .TASK_CANCELLED , 0 , 0 )
617626 else :
618627 return wset .get_pending_event ()
619628
620629 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
621630 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
631+ # TODO: somehow get a THREAD_RESUME event...
622632 if cancelled :
623633 return (EventCode .TASK_CANCELLED , 0 , 0 )
624634 else :
@@ -628,6 +638,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
628638 wset .num_waiting += 1
629639 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
630640 wset .num_waiting -= 1
641+ # TODO: somehow get a THREAD_RESUME event...
631642 if cancelled :
632643 return (EventCode .TASK_CANCELLED , 0 , 0 )
633644 elif wset .has_pending_event ():
@@ -649,20 +660,21 @@ def cancel(self):
649660 self .state = Task .State .RESOLVED
650661
651662 def exit (self ):
652- assert (self .thread is not None )
663+ assert (len ( self .threads ) > 0 )
653664 if self .needs_exclusive ():
654665 assert (self .inst .exclusive )
655666 self .inst .exclusive = False
656667
657668 def thread_start (self , thread ):
658- assert (self . thread is None and thread .task is self )
659- self .thread = thread
669+ assert (thread not in self . threads and thread .task is self )
670+ self .threads . append ( thread )
660671
661672 def thread_stop (self , thread ):
662- assert (thread is self .thread and thread .task is self )
663- self .thread = None
664- trap_if (self .state != Task .State .RESOLVED )
665- assert (self .num_borrows == 0 )
673+ assert (thread in self .threads and thread .task is self )
674+ self .threads .remove (thread )
675+ if len (self .threads ) == 0 :
676+ trap_if (self .state != Task .State .RESOLVED )
677+ assert (self .num_borrows == 0 )
666678
667679#### Subtask State
668680
@@ -2078,25 +2090,77 @@ def canon_resource_rep(rt, thread, i):
20782090 trap_if (h .rt is not rt )
20792091 return [h .rep ]
20802092
2093+ ### 🧵 `canon thread.index`
2094+
2095+ def canon_thread_index (shared , thread ):
2096+ assert (not shared )
2097+ return [thread .index ]
2098+
2099+ ### 🧵 `canon thread.new`
2100+
2101+ def canon_thread_new (ft , ftbl , thread , i , c ):
2102+ task = thread .task
2103+ trap_if (not task .inst .may_leave )
2104+ f = task .inst .ftbl .get (i )
2105+ trap_if (f .type != ft )
2106+ new_thread = Thread (task , f (c ))
2107+ return [new_thread .index ]
2108+
2109+ ### 🧵 `canon thread.switch-to`
2110+
2111+ def canon_thread_switch_to (thread , cancellable , i ):
2112+ trap_if (not thread .task .inst .may_leave )
2113+ other = thread .task .inst .table .get (i )
2114+ trap_if (not isinstance (other , Thread ))
2115+ cancelled = thread .switch_to (cancellable , other )
2116+ return [ 1 if cancelled else 0 ]
2117+
2118+ ### 🧵 `canon thread.yield-to`
2119+
2120+ def canon_thread_yield_to (thread , cancellable , i ):
2121+ trap_if (not thread .task .inst .may_leave )
2122+ other = thread .task .inst .table .get (i )
2123+ trap_if (not isinstance (other , Thread ))
2124+ other .yield_to (cancellable , other )
2125+ return []
2126+
2127+ ### 🧵 `canon thread.block`
2128+
2129+ def canon_thread_block (thread , cancellable , i ):
2130+ trap_if (not thread .task .inst .may_leave )
2131+ other = thread .task .inst .table .get (i )
2132+ trap_if (not isinstance (other , Thread ))
2133+ cancelled = thread .block (cancellable )
2134+ return [ 1 if cancelled else 0 ]
2135+
2136+ ### 🧵 `canon thread.unblock`
2137+
2138+ def canon_thread_unblock (thread , i ):
2139+ trap_if (not thread .task .inst .may_leave )
2140+ other = thread .task .inst .table .get (i )
2141+ trap_if (not isinstance (other , Thread ))
2142+ thread .unblock ()
2143+ return []
2144+
20812145### 🔀 `canon context.get`
20822146
20832147def canon_context_get (t , i , thread ):
20842148 assert (t == 'i32' )
2085- assert (i < ContextLocalStorage . LENGTH )
2086- return [thread .task . context . get ( i ) ]
2149+ assert (i < Thread . CONTEXT_LENGTH )
2150+ return [thread .context [ i ] ]
20872151
20882152### 🔀 `canon context.set`
20892153
20902154def canon_context_set (t , i , thread , v ):
20912155 assert (t == 'i32' )
2092- assert (i < ContextLocalStorage . LENGTH )
2093- thread .task . context . set ( i , v )
2156+ assert (i < Thread . CONTEXT_LENGTH )
2157+ thread .context [ i ] = v
20942158 return []
20952159
20962160### 🔀 `canon backpressure.set`
20972161
20982162def canon_backpressure_set (thread , flat_args ):
2099- trap_if (thread .task .opts .sync )
2163+ # TODO: remove trap_if(thread.task.opts.sync)
21002164 assert (len (flat_args ) == 1 )
21012165 thread .task .inst .backpressure = bool (flat_args [0 ])
21022166 return []
0 commit comments