@@ -213,20 +213,32 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : Optional [int ]
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : Optional [threading .Lock ]
219221 stack : threading .Thread
222+ cancellable : bool
223+ waiting_for_callback : bool
224+
225+ CONTEXT_LENGTH = 2
220226
221227 def __init__ (self , task , thread_func ):
222228 self .task = task
229+ self .index = None
230+ self .context = [0 ] * Thread .CONTEXT_LENGTH
223231 self .ready_func = None
224232 self .run_lock = threading .Lock ()
225233 self .resume_lock = None
234+ self .cancellable = False
235+ self .waiting_for_callback = False
226236 def thread_stack_base ():
227237 self .run_lock .acquire ()
228238 thread_func (self )
229239 self .task .thread_stop (self )
240+ if self .index is not None :
241+ self .task .inst .table .remove (self .index )
230242 self .resume_lock .release ()
231243 self .stack = threading .Thread (target = thread_stack_base )
232244 self .run_lock .acquire ()
@@ -247,14 +259,45 @@ def resume(self):
247259 self .resume_lock .acquire ()
248260 self .resume_lock = None
249261
250- def suspend_until (self , ready_func ):
262+ def block (self , cancellable ):
263+ assert (not self .cancellable )
264+ self .cancellable = cancellable
265+ self .resume_lock .release ()
266+ self .run_lock .acquire ()
267+ self .cancellable = False
268+
269+ def suspend_until (self , ready_func , cancellable = False ):
251270 assert (not self .ready_func )
252271 if not DETERMINISTIC_PROFILE and ready_func ():
253272 return
254273 self .ready_func = ready_func
255274 self .task .inst .store .waiting .append (self )
256- self .resume_lock .release ()
275+ self .block (cancellable )
276+
277+ async def switch_to (self , cancellable , other : Thread ):
278+ assert (self .task .inst is other .task .inst )
279+ if other .ready_func :
280+ other .ready_func = None
281+ other .task .inst .store .waiting .remove (other )
282+ assert (not self .cancellable )
283+ self .cancellable = cancellable
284+ assert (self .resume_lock and not other .resume_lock )
285+ other .resume_lock = self .resume_lock
286+ self .resume_lock = None
287+ other .run_lock .release ()
257288 self .run_lock .acquire ()
289+ self .cancellable = False
290+
291+ def yield_to (self , cancellable , other : Thread ):
292+ assert (not self .ready_func )
293+ self .ready_func = lambda : True
294+ self .task .inst .store .waiting .append (self )
295+ self .switch_to (cancellable , other )
296+
297+ def unblock (self , other : Thread ):
298+ if not other .ready_func :
299+ other .task .inst .store .waiting .append (other )
300+ other .ready_func = lambda : True
258301
259302
260303### Lifting and Lowering Context
@@ -432,22 +475,6 @@ def write(self, vs):
432475 assert (all (v == () for v in vs ))
433476 self .progress += len (vs )
434477
435- #### Context-Local Storage
436-
437- class ContextLocalStorage :
438- LENGTH = 1
439- array : list [int ]
440-
441- def __init__ (self ):
442- self .array = [0 ] * ContextLocalStorage .LENGTH
443-
444- def set (self , i , v ):
445- assert (types_match_values (['i32' ], [v ]))
446- self .array [i ] = v
447-
448- def get (self , i ):
449- return self .array [i ]
450-
451478#### Waitable State
452479
453480class EventCode (IntEnum ):
@@ -458,6 +485,7 @@ class EventCode(IntEnum):
458485 FUTURE_READ = 4
459486 FUTURE_WRITE = 5
460487 TASK_CANCELLED = 6
488+ THREAD_RESUMED = 7
461489
462490EventTuple = tuple [EventCode , int , int ]
463491
@@ -530,11 +558,8 @@ class State(Enum):
530558 ft : FuncType
531559 supertask : Optional [Task ]
532560 on_resolve : OnResolve
533- thread : Optional [Thread ]
534- cancellable : bool
535- waiting_for_callback : bool
561+ threads : list [Thread ]
536562 num_borrows : int
537- context : ContextLocalStorage
538563
539564 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
540565 self .state = Task .State .INITIAL
@@ -543,11 +568,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
543568 self .ft = ft
544569 self .supertask = supertask
545570 self .on_resolve = on_resolve
546- self .thread = None
547- self .cancellable = False
548- self .waiting_for_callback = False
571+ self .threads = []
549572 self .num_borrows = 0
550- self .context = ContextLocalStorage ()
551573
552574 def trap_if_on_the_stack (self , inst ):
553575 c = self .supertask
@@ -559,15 +581,13 @@ def needs_exclusive(self):
559581 return self .opts .sync or self .opts .callback
560582
561583 def enter (self , thread ):
562- assert (thread is self .thread and thread .task is self )
584+ assert (thread in self .threads and thread .task is self )
563585 def has_backpressure ():
564586 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
565587 if has_backpressure () or self .inst .pending_tasks > 0 :
566588 self .inst .pending_tasks += 1
567- self .cancellable = True
568- thread .suspend_until (lambda : not has_backpressure ())
589+ thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
569590 self .inst .pending_tasks -= 1
570- self .cancellable = False
571591 if self .deliver_cancel ():
572592 self .cancel ()
573593 return False
@@ -586,27 +606,28 @@ def deliver_cancel(self) -> bool:
586606 def request_cancellation (self ):
587607 assert (self .state == Task .State .INITIAL )
588608 self .state = Task .State .PENDING_CANCEL
589- if self .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
590- self .thread .resume ()
609+ if not DETERMINISTIC_PROFILE :
610+ random .shuffle (self .threads )
611+ for thread in self .threads :
612+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
613+ thread .resume ()
614+ break
591615
592616 def wait_until (self , ready_func , thread , cancellable , for_callback ):
593- assert (thread is self .thread and thread .task is self )
617+ assert (thread in self .threads and thread .task is self )
594618 if cancellable and self .deliver_cancel ():
595619 return True
596- assert (not self .cancellable )
597- self .cancellable = cancellable
598620 if for_callback :
599621 assert (self .inst .exclusive )
600622 self .inst .exclusive = False
601- self .waiting_for_callback = True
623+ thread .waiting_for_callback = True
602624 def ready_and_allowed ():
603625 return ready_func () and not (for_callback and self .inst .exclusive )
604- thread .suspend_until (ready_and_allowed )
626+ thread .suspend_until (ready_and_allowed , cancellable )
605627 if for_callback :
606628 assert (not self .inst .exclusive )
607629 self .inst .exclusive = True
608- self .waiting_for_callback = False
609- self .cancellable = False
630+ thread .waiting_for_callback = False
610631 if cancellable and self .deliver_cancel ():
611632 return True
612633 return False
@@ -615,13 +636,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
615636 wset .num_waiting += 1
616637 cancelled = self .wait_until (wset .has_pending_event , thread , cancellable , for_callback )
617638 wset .num_waiting -= 1
639+ # TODO: somehow get a THREAD_RESUME event...
618640 if cancelled :
619641 return (EventCode .TASK_CANCELLED , 0 , 0 )
620642 else :
621643 return wset .get_pending_event ()
622644
623645 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
624646 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
647+ # TODO: somehow get a THREAD_RESUME event...
625648 if cancelled :
626649 return (EventCode .TASK_CANCELLED , 0 , 0 )
627650 else :
@@ -631,6 +654,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
631654 wset .num_waiting += 1
632655 cancelled = self .wait_until (lambda : True , thread , cancellable , for_callback )
633656 wset .num_waiting -= 1
657+ # TODO: somehow get a THREAD_RESUME event...
634658 if cancelled :
635659 return (EventCode .TASK_CANCELLED , 0 , 0 )
636660 elif wset .has_pending_event ():
@@ -652,20 +676,21 @@ def cancel(self):
652676 self .state = Task .State .RESOLVED
653677
654678 def exit (self ):
655- assert (self .thread is not None )
679+ assert (len ( self .threads ) > 0 )
656680 if self .needs_exclusive ():
657681 assert (self .inst .exclusive )
658682 self .inst .exclusive = False
659683
660684 def thread_start (self , thread ):
661- assert (self . thread is None and thread .task is self )
662- self .thread = thread
685+ assert (thread not in self . threads and thread .task is self )
686+ self .threads . append ( thread )
663687
664688 def thread_stop (self , thread ):
665- assert (thread is self .thread and thread .task is self )
666- self .thread = None
667- trap_if (self .state != Task .State .RESOLVED )
668- assert (self .num_borrows == 0 )
689+ assert (thread in self .threads and thread .task is self )
690+ self .threads .remove (thread )
691+ if len (self .threads ) == 0 :
692+ trap_if (self .state != Task .State .RESOLVED )
693+ assert (self .num_borrows == 0 )
669694
670695#### Subtask State
671696
@@ -1901,6 +1926,9 @@ def thread_func(thread):
19011926 if not task .enter (thread ):
19021927 return
19031928
1929+ assert (thread .index is None )
1930+ thread .index = thread .task .inst .table .add (thread )
1931+
19041932 cx = LiftLowerContext (opts , inst , task )
19051933 args = on_start ()
19061934 flat_args = lower_flat_values (cx , MAX_FLAT_PARAMS , args , ft .param_types ())
@@ -2081,25 +2109,82 @@ def canon_resource_rep(rt, thread, i):
20812109 trap_if (h .rt is not rt )
20822110 return [h .rep ]
20832111
2112+ ### 🧵 `canon thread.index`
2113+
2114+ def canon_thread_index (shared , thread ):
2115+ assert (not shared )
2116+ assert (thread .index is not None )
2117+ return [thread .index ]
2118+
2119+ ### 🧵 `canon thread.new`
2120+
2121+ def canon_thread_new (ft , ftbl , thread , i , c ):
2122+ task = thread .task
2123+ trap_if (not task .inst .may_leave )
2124+ f = task .inst .ftbl .get (i )
2125+ trap_if (f .type != ft )
2126+ thread_func = partial (f , c )
2127+ i = task .inst .table .add (Thread (task , thread_func ))
2128+ return [i ]
2129+
2130+ ### 🧵 `canon thread.switch-to`
2131+
2132+ def canon_thread_switch_to (thread , cancellable , i ):
2133+ trap_if (not thread .task .inst .may_leave )
2134+ other = thread .task .inst .table .get (i )
2135+ trap_if (not isinstance (other , Thread ))
2136+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2137+ cancelled = thread .switch_to (cancellable , other )
2138+ return [ 1 if cancelled else 0 ]
2139+
2140+ ### 🧵 `canon thread.yield-to`
2141+
2142+ def canon_thread_yield_to (thread , cancellable , i ):
2143+ trap_if (not thread .task .inst .may_leave )
2144+ other = thread .task .inst .table .get (i )
2145+ trap_if (not isinstance (other , Thread ))
2146+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2147+ other .yield_to (cancellable , other )
2148+ return []
2149+
2150+ ### 🧵 `canon thread.unblock`
2151+
2152+ def canon_thread_unblock (thread , i ):
2153+ trap_if (not thread .task .inst .may_leave )
2154+ other = thread .task .inst .table .get (i )
2155+ trap_if (not isinstance (other , Thread ))
2156+ trap_if (not other .cancellable ) # TODO: what about waiting_for_callback and exclusive
2157+ thread .unblock ()
2158+ return []
2159+
2160+ ### 🧵 `canon thread.block`
2161+
2162+ def canon_thread_block (thread , cancellable , i ):
2163+ trap_if (not thread .task .inst .may_leave )
2164+ other = thread .task .inst .table .get (i )
2165+ trap_if (not isinstance (other , Thread ))
2166+ cancelled = thread .block (cancellable )
2167+ return [ 1 if cancelled else 0 ]
2168+
20842169### 🔀 `canon context.get`
20852170
20862171def canon_context_get (t , i , thread ):
20872172 assert (t == 'i32' )
2088- assert (i < ContextLocalStorage . LENGTH )
2089- return [thread .task . context . get ( i ) ]
2173+ assert (i < Thread . CONTEXT_LENGTH )
2174+ return [thread .context [ i ] ]
20902175
20912176### 🔀 `canon context.set`
20922177
20932178def canon_context_set (t , i , thread , v ):
20942179 assert (t == 'i32' )
2095- assert (i < ContextLocalStorage . LENGTH )
2096- thread .task . context . set ( i , v )
2180+ assert (i < Thread . CONTEXT_LENGTH )
2181+ thread .context [ i ] = v
20972182 return []
20982183
20992184### 🔀 `canon backpressure.set`
21002185
21012186def canon_backpressure_set (thread , flat_args ):
2102- trap_if (thread .task .opts .sync )
2187+ # TODO: remove trap_if(thread.task.opts.sync)
21032188 assert (len (flat_args ) == 1 )
21042189 thread .task .inst .backpressure = bool (flat_args [0 ])
21052190 return []
0 commit comments