@@ -213,39 +213,58 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : Optional [int ]
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : Optional [threading .Lock ]
219221 stack : threading .Thread
220222 cancellable : bool
221223 cancelled : bool
224+ waiting_for_callback : bool
225+
226+ CONTEXT_LENGTH = 2
222227
223228 def __init__ (self , task , thread_func ):
224229 self .task = task
230+ self .index = None
231+ self .context = [0 ] * Thread .CONTEXT_LENGTH
225232 self .ready_func = None
226233 self .run_lock = threading .Lock ()
227234 self .run_lock .acquire ()
228235 self .resume_lock = None
229236 self .cancellable = False
230237 self .cancelled = False
238+ self .waiting_for_callback = False
231239 def thread_stack_base ():
232240 self .run_lock .acquire ()
233241 thread_func (self )
234242 self .task .thread_stop (self )
243+ if self .index is not None :
244+ self .task .inst .table .remove (self .index )
235245 self .resume_lock .release ()
236246 self .stack = threading .Thread (target = thread_stack_base )
237247 self .stack .start ()
238248 self .task .thread_start (self )
249+ assert (self .suspended ())
250+
251+ def suspended (self ):
252+ return self .ready_func is None and self .resume_lock is None
253+
254+ def pending (self ):
255+ return self .ready_func is not None and self .resume_lock is None
239256
240257 def ready (self ):
258+ assert (self .pending ())
241259 return self .ready_func ()
242260
243261 def resume (self , cancel = False ):
262+ assert (self .suspended () or self .pending ())
244263 if cancel :
245264 assert (self .cancellable and not self .cancelled )
246265 self .cancelled = True
247- if self .ready_func :
248- assert (cancel or self .ready_func ())
266+ if self .pending () :
267+ assert (cancel or self .ready ())
249268 self .ready_func = None
250269 self .task .inst .store .pending .remove (self )
251270 assert (not self .resume_lock )
@@ -255,22 +274,52 @@ def resume(self, cancel = False):
255274 self .resume_lock .acquire ()
256275 self .resume_lock = None
257276
277+ def suspend (self , cancellable ) -> bool :
278+ assert (not self .cancellable and not self .cancelled )
279+ self .cancellable = cancellable
280+ self .resume_lock .release ()
281+ self .run_lock .acquire ()
282+ assert (self .cancellable or not self .cancelled )
283+ self .cancellable = False
284+ completed = not self .cancelled
285+ self .cancelled = False
286+ return completed
287+
258288 def suspend_until (self , ready_func , cancellable = False ) -> bool :
259- assert (not self .ready_func )
289+ assert (not self .pending () )
260290 if not DETERMINISTIC_PROFILE and ready_func ():
261291 return True
262292 self .ready_func = ready_func
263293 self .task .inst .store .pending .append (self )
264- assert (not self .cancellable and not self .cancelled )
294+ return self .suspend (cancellable )
295+
296+ def switch_to (self , cancellable , other : Thread ) -> bool :
297+ assert (other .suspended ())
298+ assert (not self .cancellable )
265299 self .cancellable = cancellable
266- self .resume_lock .release ()
300+ assert (self .resume_lock and not other .resume_lock )
301+ other .resume_lock = self .resume_lock
302+ self .resume_lock = None
303+ assert (self .suspended ())
304+ other .run_lock .release ()
267305 self .run_lock .acquire ()
268- assert (self .cancellable or not self .cancelled )
269306 self .cancellable = False
270307 completed = not self .cancelled
271308 self .cancelled = False
272309 return completed
273310
311+ def yield_to (self , cancellable , other : Thread ) -> bool :
312+ assert (other .suspended ())
313+ assert (not self .ready_func )
314+ self .ready_func = lambda : True
315+ self .task .inst .store .pending .append (self )
316+ return self .switch_to (cancellable , other )
317+
318+ def resume_later (self , other : Thread ):
319+ assert (other .suspended ())
320+ other .ready_func = lambda : True
321+ other .task .inst .store .pending .append (other )
322+
274323
275324### Lifting and Lowering Context
276325
@@ -447,22 +496,6 @@ def write(self, vs):
447496 assert (all (v == () for v in vs ))
448497 self .progress += len (vs )
449498
450- #### Context-Local Storage
451-
452- class ContextLocalStorage :
453- LENGTH = 1
454- array : list [int ]
455-
456- def __init__ (self ):
457- self .array = [0 ] * ContextLocalStorage .LENGTH
458-
459- def set (self , i , v ):
460- assert (types_match_values (['i32' ], [v ]))
461- self .array [i ] = v
462-
463- def get (self , i ):
464- return self .array [i ]
465-
466499#### Waitable State
467500
468501class EventCode (IntEnum ):
@@ -545,10 +578,8 @@ class State(Enum):
545578 ft : FuncType
546579 supertask : Optional [Task ]
547580 on_resolve : OnResolve
548- thread : Optional [Thread ]
581+ threads : list [Thread ]
549582 num_borrows : int
550- waiting_for_callback : bool
551- context : ContextLocalStorage
552583
553584 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
554585 self .state = Task .State .INITIAL
@@ -557,10 +588,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
557588 self .ft = ft
558589 self .supertask = supertask
559590 self .on_resolve = on_resolve
560- self .thread = None
591+ self .threads = []
561592 self .num_borrows = 0
562- self .waiting_for_callback = False
563- self .context = ContextLocalStorage ()
564593
565594 def trap_if_on_the_stack (self , inst ):
566595 c = self .supertask
@@ -572,7 +601,7 @@ def needs_exclusive(self):
572601 return self .opts .sync or self .opts .callback
573602
574603 def enter (self , thread ):
575- assert (thread is self .thread and thread .task is self )
604+ assert (thread in self .threads and thread .task is self )
576605 def has_backpressure ():
577606 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
578607 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -589,28 +618,31 @@ def has_backpressure():
589618
590619 def request_cancellation (self ):
591620 assert (self .state == Task .State .INITIAL )
592- if self .thread .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
593- self .state = Task .State .CANCEL_DELIVERED
594- self .thread .resume (cancel = True )
595- else :
596- self .state = Task .State .PENDING_CANCEL
621+ if not DETERMINISTIC_PROFILE :
622+ random .shuffle (self .threads )
623+ for thread in self .threads :
624+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
625+ self .state = Task .State .CANCEL_DELIVERED
626+ thread .resume (cancel = True )
627+ return
628+ self .state = Task .State .PENDING_CANCEL
597629
598630 def wait_until (self , ready_func , thread , cancellable , for_callback ) -> bool :
599- assert (thread is self .thread and thread .task is self )
631+ assert (thread in self .threads and thread .task is self )
600632 if cancellable and self .state == Task .State .PENDING_CANCEL :
601633 self .state = Task .State .CANCEL_DELIVERED
602634 return False
603635 if for_callback :
604636 assert (self .inst .exclusive )
605637 self .inst .exclusive = False
606- self .waiting_for_callback = True
638+ thread .waiting_for_callback = True
607639 def ready_and_uncontended ():
608640 return ready_func () and not (for_callback and self .inst .exclusive )
609641 completed = thread .suspend_until (ready_and_uncontended , cancellable )
610642 if for_callback :
611643 assert (not self .inst .exclusive )
612644 self .inst .exclusive = True
613- self .waiting_for_callback = False
645+ thread .waiting_for_callback = False
614646 return completed
615647
616648 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
@@ -653,20 +685,21 @@ def cancel(self):
653685 self .state = Task .State .RESOLVED
654686
655687 def exit (self ):
656- assert (self .thread is not None )
688+ assert (len ( self .threads ) > 0 )
657689 if self .needs_exclusive ():
658690 assert (self .inst .exclusive )
659691 self .inst .exclusive = False
660692
661693 def thread_start (self , thread ):
662- assert (self . thread is None and thread .task is self )
663- self .thread = thread
694+ assert (thread not in self . threads and thread .task is self )
695+ self .threads . append ( thread )
664696
665697 def thread_stop (self , thread ):
666- assert (thread is self .thread and thread .task is self )
667- self .thread = None
668- trap_if (self .state != Task .State .RESOLVED )
669- assert (self .num_borrows == 0 )
698+ assert (thread in self .threads and thread .task is self )
699+ self .threads .remove (thread )
700+ if len (self .threads ) == 0 :
701+ trap_if (self .state != Task .State .RESOLVED )
702+ assert (self .num_borrows == 0 )
670703
671704#### Subtask State
672705
@@ -1902,6 +1935,9 @@ def thread_func(thread):
19021935 if not task .enter (thread ):
19031936 return
19041937
1938+ assert (thread .index is None )
1939+ thread .index = thread .task .inst .table .add (thread )
1940+
19051941 cx = LiftLowerContext (opts , inst , task )
19061942 args = on_start ()
19071943 flat_args = lower_flat_values (cx , MAX_FLAT_PARAMS , args , ft .param_types ())
@@ -2082,25 +2118,91 @@ def canon_resource_rep(rt, thread, i):
20822118 trap_if (h .rt is not rt )
20832119 return [h .rep ]
20842120
2121+ ### 🧵 `canon thread.index`
2122+
2123+ def canon_thread_index (shared , thread ):
2124+ assert (not shared )
2125+ assert (thread .index is not None )
2126+ return [thread .index ]
2127+
2128+ ### 🧵 `canon thread.new`
2129+
2130+ def canon_thread_new (ft , ftbl , thread , i , c ):
2131+ task = thread .task
2132+ trap_if (not task .inst .may_leave )
2133+ f = task .inst .ftbl .get (i )
2134+ trap_if (f .type != ft )
2135+ thread_func = partial (f , c )
2136+ new_thread = Thread (task , thread_func )
2137+ assert (new_thread .suspended ())
2138+ new_thread .index = task .inst .table .add (thread )
2139+ return [new_thread .index ]
2140+
2141+ ### 🧵 `canon thread.resume-later`
2142+
2143+ def canon_thread_resume_later (thread , i ):
2144+ trap_if (not thread .task .inst .may_leave )
2145+ other_thread = thread .task .inst .table .get (i )
2146+ trap_if (not isinstance (other_thread , Thread ))
2147+ trap_if (not other_thread .suspended ())
2148+ thread .resume_later (other_thread )
2149+ return []
2150+
2151+ ### 🧵 `canon thread.switch-to`
2152+
2153+ def canon_thread_switch_to (thread , cancellable , i ):
2154+ trap_if (not thread .task .inst .may_leave )
2155+ other_thread = thread .task .inst .table .get (i )
2156+ trap_if (not isinstance (other_thread , Thread ))
2157+ trap_if (not other_thread .suspended ())
2158+ if not thread .switch_to (cancellable , other_thread ):
2159+ assert (cancellable )
2160+ return [0 ]
2161+ else :
2162+ return [1 ]
2163+
2164+ ### 🧵 `canon thread.yield-to`
2165+
2166+ def canon_thread_yield_to (thread , cancellable , i ):
2167+ trap_if (not thread .task .inst .may_leave )
2168+ other_thread = thread .task .inst .table .get (i )
2169+ trap_if (not isinstance (other_thread , Thread ))
2170+ trap_if (not other_thread .suspended ())
2171+ if not other_thread .yield_to (cancellable , other_thread ):
2172+ assert (cancellable )
2173+ return [0 ]
2174+ else :
2175+ return [1 ]
2176+
2177+ ### 🧵 `canon thread.suspend`
2178+
2179+ def canon_thread_suspend (thread , cancellable ):
2180+ trap_if (not thread .task .inst .may_leave )
2181+ if not thread .suspend (cancellable ):
2182+ assert (cancellable )
2183+ return [0 ]
2184+ else :
2185+ return [1 ]
2186+
20852187### 🔀 `canon context.get`
20862188
20872189def canon_context_get (t , i , thread ):
20882190 assert (t == 'i32' )
2089- assert (i < ContextLocalStorage . LENGTH )
2090- return [thread .task . context . get ( i ) ]
2191+ assert (i < Thread . CONTEXT_LENGTH )
2192+ return [thread .context [ i ] ]
20912193
20922194### 🔀 `canon context.set`
20932195
20942196def canon_context_set (t , i , thread , v ):
20952197 assert (t == 'i32' )
2096- assert (i < ContextLocalStorage . LENGTH )
2097- thread .task . context . set ( i , v )
2198+ assert (i < Thread . CONTEXT_LENGTH )
2199+ thread .context [ i ] = v
20982200 return []
20992201
21002202### 🔀 `canon backpressure.set`
21012203
21022204def canon_backpressure_set (thread , flat_args ):
2103- trap_if (thread .task .opts .sync )
2205+ # TODO: remove trap_if(thread.task.opts.sync)
21042206 assert (len (flat_args ) == 1 )
21052207 thread .task .inst .backpressure = bool (flat_args [0 ])
21062208 return []
0 commit comments