@@ -213,13 +213,19 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : int
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : threading .Lock
219221 stack : threading .Thread
220222
223+ CONTEXT_LENGTH = 2
224+
221225 def __init__ (self , task , thread_func ):
222226 self .task = task
227+ self .index = task .inst .table .add (self )
228+ self .context = [0 ] * Thread .CONTEXT_LENGTH
223229 self .ready_func = None
224230 self .run_lock = threading .Lock ()
225231 self .run_lock .acquire ()
@@ -229,6 +235,7 @@ def thread_stack_base():
229235 self .run_lock .acquire ()
230236 thread_func (self )
231237 self .task .thread_stop (self )
238+ self .task .inst .table .remove (self .index )
232239 self .resume_lock .release ()
233240 self .stack = threading .Thread (target = thread_stack_base )
234241 self .stack .start ()
@@ -252,6 +259,22 @@ def suspend_until(self, ready_func):
252259 self .resume_lock .release ()
253260 self .run_lock .acquire ()
254261
262+ async def switch_to (self , cancellable , other : Thread ):
263+ # deterministically switch to other, leave this blocked
264+ TODO
265+
266+ def yield_to (self , cancellable , other : Thread ):
267+ # deterministically switch to other, but leave this thread unblocked
268+ TODO
269+
270+ def block (self , cancellable ):
271+ # perform just the first half of switch
272+ TODO
273+
274+ def unblock (self , other : Thread ):
275+ # unblock other, but deterministically keep running here
276+ TODO
277+
255278
256279### Lifting and Lowering Context
257280
@@ -428,22 +451,6 @@ def write(self, vs):
428451 assert (all (v == () for v in vs ))
429452 self .progress += len (vs )
430453
431- #### Context-Local Storage
432-
433- class ContextLocalStorage :
434- LENGTH = 1
435- array : list [int ]
436-
437- def __init__ (self ):
438- self .array = [0 ] * ContextLocalStorage .LENGTH
439-
440- def set (self , i , v ):
441- assert (types_match_values (['i32' ], [v ]))
442- self .array [i ] = v
443-
444- def get (self , i ):
445- return self .array [i ]
446-
447454#### Waitable State
448455
449456class EventCode (IntEnum ):
@@ -454,6 +461,7 @@ class EventCode(IntEnum):
454461 FUTURE_READ = 4
455462 FUTURE_WRITE = 5
456463 TASK_CANCELLED = 6
464+ THREAD_RESUMED = 7
457465
458466EventTuple = tuple [EventCode , int , int ]
459467
@@ -526,11 +534,10 @@ class State(Enum):
526534 ft : FuncType
527535 supertask : Optional [Task ]
528536 on_resolve : OnResolve
529- thread : Optional [Thread ]
537+ threads : list [Thread ]
530538 cancellable : bool
531539 waiting_for_callback : bool
532540 num_borrows : int
533- context : ContextLocalStorage
534541
535542 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
536543 self .state = Task .State .INITIAL
@@ -539,11 +546,10 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
539546 self .ft = ft
540547 self .supertask = supertask
541548 self .on_resolve = on_resolve
542- self .thread = None
549+ self .threads = []
543550 self .cancellable = False
544551 self .waiting_for_callback = False
545552 self .num_borrows = 0
546- self .context = ContextLocalStorage ()
547553
548554 def trap_if_on_the_stack (self , inst ):
549555 c = self .supertask
@@ -555,7 +561,7 @@ def needs_exclusive(self):
555561 return self .opts .sync or self .opts .callback
556562
557563 def enter (self , thread ):
558- assert (thread is self .thread and thread .task is self )
564+ assert (thread in self .threads and thread .task is self )
559565 def has_backpressure ():
560566 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
561567 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -582,11 +588,13 @@ def deliver_cancel(self) -> bool:
582588 def request_cancellation (self ):
583589 assert (self .state == Task .State .INITIAL )
584590 self .state = Task .State .PENDING_CANCEL
591+ # TODO: move cancellability to the Thread and then search
592+ # for a cancellable one here...
585593 if self .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
586- self .thread .resume ()
594+ self .threads [ 0 ] .resume ()
587595
588596 def wait_until (self , ready_func , thread , cancellable , for_callback ):
589- assert (thread is self .thread and thread .task is self )
597+ assert (thread in self .threads and thread .task is self )
590598 if cancellable and self .deliver_cancel ():
591599 return True
592600 assert (not self .cancellable )
@@ -609,13 +617,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
609617 wset .num_waiting += 1
610618 cancelled = self .wait_until (wset .has_pending_event , thread , cancellable , for_callback )
611619 wset .num_waiting -= 1
620+ # TODO: somehow get a THREAD_RESUME event...
612621 if cancelled :
613622 return (EventCode .TASK_CANCELLED , 0 , 0 )
614623 else :
615624 return wset .get_pending_event ()
616625
617626 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
618627 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
628+ # TODO: somehow get a THREAD_RESUME event...
619629 if cancelled :
620630 return (EventCode .TASK_CANCELLED , 0 , 0 )
621631 else :
@@ -625,6 +635,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
625635 wset .num_waiting += 1
626636 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
627637 wset .num_waiting -= 1
638+ # TODO: somehow get a THREAD_RESUME event...
628639 if cancelled :
629640 return (EventCode .TASK_CANCELLED , 0 , 0 )
630641 elif wset .has_pending_event ():
@@ -646,20 +657,21 @@ def cancel(self):
646657 self .state = Task .State .RESOLVED
647658
648659 def exit (self ):
649- assert (self .thread is not None )
660+ assert (len ( self .threads ) > 0 )
650661 if self .needs_exclusive ():
651662 assert (self .inst .exclusive )
652663 self .inst .exclusive = False
653664
654665 def thread_start (self , thread ):
655- assert (self . thread is None and thread .task is self )
656- self .thread = thread
666+ assert (thread not in self . threads and thread .task is self )
667+ self .threads . append ( thread )
657668
658669 def thread_stop (self , thread ):
659- assert (thread is self .thread and thread .task is self )
660- self .thread = None
661- trap_if (self .state != Task .State .RESOLVED )
662- assert (self .num_borrows == 0 )
670+ assert (thread in self .threads and thread .task is self )
671+ self .threads .remove (thread )
672+ if len (self .threads ) == 0 :
673+ trap_if (self .state != Task .State .RESOLVED )
674+ assert (self .num_borrows == 0 )
663675
664676#### Subtask State
665677
@@ -2075,25 +2087,77 @@ def canon_resource_rep(rt, thread, i):
20752087 trap_if (h .rt is not rt )
20762088 return [h .rep ]
20772089
2090+ ### 🧵 `canon thread.index`
2091+
2092+ def canon_thread_index (shared , thread ):
2093+ assert (not shared )
2094+ return [thread .index ]
2095+
2096+ ### 🧵 `canon thread.new`
2097+
2098+ def canon_thread_new (ft , ftbl , thread , i , c ):
2099+ task = thread .task
2100+ trap_if (not task .inst .may_leave )
2101+ f = task .inst .ftbl .get (i )
2102+ trap_if (f .type != ft )
2103+ new_thread = Thread (task , f (c ))
2104+ return [new_thread .index ]
2105+
2106+ ### 🧵 `canon thread.switch-to`
2107+
2108+ def canon_thread_switch_to (thread , cancellable , i ):
2109+ trap_if (not thread .task .inst .may_leave )
2110+ other = thread .task .inst .table .get (i )
2111+ trap_if (not isinstance (other , Thread ))
2112+ cancelled = thread .switch_to (cancellable , other )
2113+ return [ 1 if cancelled else 0 ]
2114+
2115+ ### 🧵 `canon thread.yield-to`
2116+
2117+ def canon_thread_yield_to (thread , cancellable , i ):
2118+ trap_if (not thread .task .inst .may_leave )
2119+ other = thread .task .inst .table .get (i )
2120+ trap_if (not isinstance (other , Thread ))
2121+ other .yield_to (cancellable , other )
2122+ return []
2123+
2124+ ### 🧵 `canon thread.block`
2125+
2126+ def canon_thread_block (thread , cancellable , i ):
2127+ trap_if (not thread .task .inst .may_leave )
2128+ other = thread .task .inst .table .get (i )
2129+ trap_if (not isinstance (other , Thread ))
2130+ cancelled = thread .block (cancellable )
2131+ return [ 1 if cancelled else 0 ]
2132+
2133+ ### 🧵 `canon thread.unblock`
2134+
2135+ def canon_thread_unblock (thread , i ):
2136+ trap_if (not thread .task .inst .may_leave )
2137+ other = thread .task .inst .table .get (i )
2138+ trap_if (not isinstance (other , Thread ))
2139+ thread .unblock ()
2140+ return []
2141+
20782142### 🔀 `canon context.get`
20792143
20802144def canon_context_get (t , i , thread ):
20812145 assert (t == 'i32' )
2082- assert (i < ContextLocalStorage . LENGTH )
2083- return [thread .task . context . get ( i ) ]
2146+ assert (i < Thread . CONTEXT_LENGTH )
2147+ return [thread .context [ i ] ]
20842148
20852149### 🔀 `canon context.set`
20862150
20872151def canon_context_set (t , i , thread , v ):
20882152 assert (t == 'i32' )
2089- assert (i < ContextLocalStorage . LENGTH )
2090- thread .task . context . set ( i , v )
2153+ assert (i < Thread . CONTEXT_LENGTH )
2154+ thread .context [ i ] = v
20912155 return []
20922156
20932157### 🔀 `canon backpressure.set`
20942158
20952159def canon_backpressure_set (thread , flat_args ):
2096- trap_if (thread .task .opts .sync )
2160+ # TODO: remove trap_if(thread.task.opts.sync)
20972161 assert (len (flat_args ) == 1 )
20982162 thread .task .inst .backpressure = bool (flat_args [0 ])
20992163 return []
0 commit comments