@@ -213,20 +213,31 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : int
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : Optional [threading .Lock ]
219221 stack : threading .Thread
222+ cancellable : bool
223+ waiting_for_callback : bool
224+
225+ CONTEXT_LENGTH = 2
220226
221227 def __init__ (self , task , thread_func ):
222228 self .task = task
229+ self .index = task .inst .table .add (self )
230+ self .context = [0 ] * Thread .CONTEXT_LENGTH
223231 self .ready_func = None
224232 self .run_lock = threading .Lock ()
225233 self .resume_lock = None
234+ self .cancellable = False
235+ self .waiting_for_callback = False
226236 def thread_stack_base ():
227237 self .run_lock .acquire ()
228238 thread_func (self )
229239 self .task .thread_stop (self )
240+ self .task .inst .table .remove (self .index )
230241 self .resume_lock .release ()
231242 self .stack = threading .Thread (target = thread_stack_base )
232243 self .run_lock .acquire ()
@@ -247,14 +258,45 @@ def resume(self):
247258 self .resume_lock .acquire ()
248259 self .resume_lock = None
249260
250- def suspend_until (self , ready_func ):
261+ def block (self , cancellable ):
262+ assert (not self .cancellable )
263+ self .cancellable = cancellable
264+ self .resume_lock .release ()
265+ self .run_lock .acquire ()
266+ self .cancellable = False
267+
268+ def suspend_until (self , ready_func , cancellable = False ):
251269 assert (not self .ready_func )
252270 if not DETERMINISTIC_PROFILE and ready_func ():
253271 return
254272 self .ready_func = ready_func
255273 self .task .inst .store .waiting .append (self )
256- self .resume_lock .release ()
274+ self .block (cancellable )
275+
276+ async def switch_to (self , cancellable , other : Thread ):
277+ assert (self .task .inst is other .task .inst )
278+ if other .ready_func :
279+ other .ready_func = None
280+ other .task .inst .store .waiting .remove (other )
281+ assert (not self .cancellable )
282+ self .cancellable = cancellable
283+ assert (self .resume_lock and not other .resume_lock )
284+ other .resume_lock = self .resume_lock
285+ self .resume_lock = None
286+ other .run_lock .release ()
257287 self .run_lock .acquire ()
288+ self .cancellable = False
289+
290+ def yield_to (self , cancellable , other : Thread ):
291+ assert (not self .ready_func )
292+ self .ready_func = lambda : True
293+ self .task .inst .store .waiting .append (self )
294+ self .switch_to (cancellable , other )
295+
296+ def unblock (self , other : Thread ):
297+ if not other .ready_func :
298+ other .task .inst .store .waiting .append (other )
299+ other .ready_func = lambda : True
258300
259301
260302### Lifting and Lowering Context
@@ -432,22 +474,6 @@ def write(self, vs):
432474 assert (all (v == () for v in vs ))
433475 self .progress += len (vs )
434476
435- #### Context-Local Storage
436-
437- class ContextLocalStorage :
438- LENGTH = 1
439- array : list [int ]
440-
441- def __init__ (self ):
442- self .array = [0 ] * ContextLocalStorage .LENGTH
443-
444- def set (self , i , v ):
445- assert (types_match_values (['i32' ], [v ]))
446- self .array [i ] = v
447-
448- def get (self , i ):
449- return self .array [i ]
450-
451477#### Waitable State
452478
453479class EventCode (IntEnum ):
@@ -458,6 +484,7 @@ class EventCode(IntEnum):
458484 FUTURE_READ = 4
459485 FUTURE_WRITE = 5
460486 TASK_CANCELLED = 6
487+ THREAD_RESUMED = 7
461488
462489EventTuple = tuple [EventCode , int , int ]
463490
@@ -530,11 +557,8 @@ class State(Enum):
530557 ft : FuncType
531558 supertask : Optional [Task ]
532559 on_resolve : OnResolve
533- thread : Optional [Thread ]
534- cancellable : bool
535- waiting_for_callback : bool
560+ threads : list [Thread ]
536561 num_borrows : int
537- context : ContextLocalStorage
538562
539563 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
540564 self .state = Task .State .INITIAL
@@ -543,11 +567,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
543567 self .ft = ft
544568 self .supertask = supertask
545569 self .on_resolve = on_resolve
546- self .thread = None
547- self .cancellable = False
548- self .waiting_for_callback = False
570+ self .threads = []
549571 self .num_borrows = 0
550- self .context = ContextLocalStorage ()
551572
552573 def trap_if_on_the_stack (self , inst ):
553574 c = self .supertask
@@ -559,15 +580,13 @@ def needs_exclusive(self):
559580 return self .opts .sync or self .opts .callback
560581
561582 def enter (self , thread ):
562- assert (thread is self .thread and thread .task is self )
583+ assert (thread in self .threads and thread .task is self )
563584 def has_backpressure ():
564585 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
565586 if has_backpressure () or self .inst .pending_tasks > 0 :
566587 self .inst .pending_tasks += 1
567- self .cancellable = True
568- thread .suspend_until (lambda : not has_backpressure ())
588+ thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
569589 self .inst .pending_tasks -= 1
570- self .cancellable = False
571590 if self .deliver_cancel ():
572591 self .cancel ()
573592 return False
@@ -586,27 +605,28 @@ def deliver_cancel(self) -> bool:
586605 def request_cancellation (self ):
587606 assert (self .state == Task .State .INITIAL )
588607 self .state = Task .State .PENDING_CANCEL
589- if self .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
590- self .thread .resume ()
608+ if not DETERMINISTIC_PROFILE :
609+ random .shuffle (self .threads )
610+ for thread in self .threads :
611+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
612+ thread .resume ()
613+ break
591614
592615 def wait_until (self , ready_func , thread , cancellable , for_callback ):
593- assert (thread is self .thread and thread .task is self )
616+ assert (thread in self .threads and thread .task is self )
594617 if cancellable and self .deliver_cancel ():
595618 return True
596- assert (not self .cancellable )
597- self .cancellable = cancellable
598619 if for_callback :
599620 assert (self .inst .exclusive )
600621 self .inst .exclusive = False
601- self .waiting_for_callback = True
622+ thread .waiting_for_callback = True
602623 def ready_and_allowed ():
603624 return ready_func () and not (for_callback and self .inst .exclusive )
604- thread .suspend_until (ready_and_allowed )
625+ thread .suspend_until (ready_and_allowed , cancellable )
605626 if for_callback :
606627 assert (not self .inst .exclusive )
607628 self .inst .exclusive = True
608- self .waiting_for_callback = False
609- self .cancellable = False
629+ thread .waiting_for_callback = False
610630 if cancellable and self .deliver_cancel ():
611631 return True
612632 return False
@@ -615,13 +635,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
615635 wset .num_waiting += 1
616636 cancelled = self .wait_until (wset .has_pending_event , thread , cancellable , for_callback )
617637 wset .num_waiting -= 1
638+ # TODO: somehow get a THREAD_RESUME event...
618639 if cancelled :
619640 return (EventCode .TASK_CANCELLED , 0 , 0 )
620641 else :
621642 return wset .get_pending_event ()
622643
623644 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
624645 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
646+ # TODO: somehow get a THREAD_RESUME event...
625647 if cancelled :
626648 return (EventCode .TASK_CANCELLED , 0 , 0 )
627649 else :
@@ -631,6 +653,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
631653 wset .num_waiting += 1
632654 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
633655 wset .num_waiting -= 1
656+ # TODO: somehow get a THREAD_RESUME event...
634657 if cancelled :
635658 return (EventCode .TASK_CANCELLED , 0 , 0 )
636659 elif wset .has_pending_event ():
@@ -652,20 +675,21 @@ def cancel(self):
652675 self .state = Task .State .RESOLVED
653676
654677 def exit (self ):
655- assert (self .thread is not None )
678+ assert (len ( self .threads ) > 0 )
656679 if self .needs_exclusive ():
657680 assert (self .inst .exclusive )
658681 self .inst .exclusive = False
659682
660683 def thread_start (self , thread ):
661- assert (self . thread is None and thread .task is self )
662- self .thread = thread
684+ assert (thread not in self . threads and thread .task is self )
685+ self .threads . append ( thread )
663686
664687 def thread_stop (self , thread ):
665- assert (thread is self .thread and thread .task is self )
666- self .thread = None
667- trap_if (self .state != Task .State .RESOLVED )
668- assert (self .num_borrows == 0 )
688+ assert (thread in self .threads and thread .task is self )
689+ self .threads .remove (thread )
690+ if len (self .threads ) == 0 :
691+ trap_if (self .state != Task .State .RESOLVED )
692+ assert (self .num_borrows == 0 )
669693
670694#### Subtask State
671695
@@ -2081,25 +2105,80 @@ def canon_resource_rep(rt, thread, i):
20812105 trap_if (h .rt is not rt )
20822106 return [h .rep ]
20832107
2108+ ### 🧵 `canon thread.index`
2109+
2110+ def canon_thread_index (shared , thread ):
2111+ assert (not shared )
2112+ return [thread .index ]
2113+
2114+ ### 🧵 `canon thread.new`
2115+
2116+ def canon_thread_new (ft , ftbl , thread , i , c ):
2117+ task = thread .task
2118+ trap_if (not task .inst .may_leave )
2119+ f = task .inst .ftbl .get (i )
2120+ trap_if (f .type != ft )
2121+ new_thread = Thread (task , f (c ))
2122+ return [new_thread .index ]
2123+
2124+ ### 🧵 `canon thread.switch-to`
2125+
2126+ def canon_thread_switch_to (thread , cancellable , i ):
2127+ trap_if (not thread .task .inst .may_leave )
2128+ other = thread .task .inst .table .get (i )
2129+ trap_if (not isinstance (other , Thread ))
2130+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2131+ cancelled = thread .switch_to (cancellable , other )
2132+ return [ 1 if cancelled else 0 ]
2133+
2134+ ### 🧵 `canon thread.yield-to`
2135+
2136+ def canon_thread_yield_to (thread , cancellable , i ):
2137+ trap_if (not thread .task .inst .may_leave )
2138+ other = thread .task .inst .table .get (i )
2139+ trap_if (not isinstance (other , Thread ))
2140+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2141+ other .yield_to (cancellable , other )
2142+ return []
2143+
2144+ ### 🧵 `canon thread.unblock`
2145+
2146+ def canon_thread_unblock (thread , i ):
2147+ trap_if (not thread .task .inst .may_leave )
2148+ other = thread .task .inst .table .get (i )
2149+ trap_if (not isinstance (other , Thread ))
2150+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2151+ thread .unblock ()
2152+ return []
2153+
2154+ ### 🧵 `canon thread.block`
2155+
2156+ def canon_thread_block (thread , cancellable , i ):
2157+ trap_if (not thread .task .inst .may_leave )
2158+ other = thread .task .inst .table .get (i )
2159+ trap_if (not isinstance (other , Thread ))
2160+ cancelled = thread .block (cancellable )
2161+ return [ 1 if cancelled else 0 ]
2162+
20842163### 🔀 `canon context.get`
20852164
20862165def canon_context_get (t , i , thread ):
20872166 assert (t == 'i32' )
2088- assert (i < ContextLocalStorage . LENGTH )
2089- return [thread .task . context . get ( i ) ]
2167+ assert (i < Thread . CONTEXT_LENGTH )
2168+ return [thread .context [ i ] ]
20902169
20912170### 🔀 `canon context.set`
20922171
20932172def canon_context_set (t , i , thread , v ):
20942173 assert (t == 'i32' )
2095- assert (i < ContextLocalStorage . LENGTH )
2096- thread .task . context . set ( i , v )
2174+ assert (i < Thread . CONTEXT_LENGTH )
2175+ thread .context [ i ] = v
20972176 return []
20982177
20992178### 🔀 `canon backpressure.set`
21002179
21012180def canon_backpressure_set (thread , flat_args ):
2102- trap_if (thread .task .opts .sync )
2181+ # TODO: remove trap_if(thread.task.opts.sync)
21032182 assert (len (flat_args ) == 1 )
21042183 thread .task .inst .backpressure = bool (flat_args [0 ])
21052184 return []
0 commit comments