@@ -213,39 +213,57 @@ def tick(self):
213213
214214class Thread :
215215 task : Task
216+ index : Optional [int ]
217+ context : list [int ]
216218 ready_func : Optional [Callable [[], bool ]]
217219 run_lock : threading .Lock
218220 resume_lock : Optional [threading .Lock ]
219221 stack : threading .Thread
220222 cancellable : bool
221223 cancelled : bool
224+ waiting_for_callback : bool
225+
226+ CONTEXT_LENGTH = 2
222227
223228 def __init__ (self , task , thread_func ):
224229 self .task = task
230+ self .index = None
231+ self .context = [0 ] * Thread .CONTEXT_LENGTH
225232 self .ready_func = None
226233 self .run_lock = threading .Lock ()
227234 self .run_lock .acquire ()
228235 self .resume_lock = None
229236 self .cancellable = False
230237 self .cancelled = False
238+ self .waiting_for_callback = False
231239 def thread_stack_base ():
232240 self .run_lock .acquire ()
233241 thread_func (self )
234242 self .task .thread_stop (self )
243+ if self .index is not None :
244+ self .task .inst .table .remove (self .index )
235245 self .resume_lock .release ()
236246 self .stack = threading .Thread (target = thread_stack_base )
237247 self .stack .start ()
238248 self .task .thread_start (self )
249+ assert (self .suspended ())
250+
251+ def suspended (self ):
252+ return self .ready_func is None
253+
254+ def pending (self ):
255+ return self .ready_func is not None
239256
240257 def ready (self ):
258+ assert (self .pending ())
241259 return self .ready_func ()
242260
243261 def resume (self , cancel = False ):
244262 if cancel :
245263 assert (self .cancellable and not self .cancelled )
246264 self .cancelled = True
247- if self .ready_func :
248- assert (cancel or self .ready_func ())
265+ if self .pending () :
266+ assert (cancel or self .ready ())
249267 self .ready_func = None
250268 self .task .inst .store .pending .remove (self )
251269 assert (not self .resume_lock )
@@ -255,22 +273,51 @@ def resume(self, cancel = False):
255273 self .resume_lock .acquire ()
256274 self .resume_lock = None
257275
276+ def suspend (self , cancellable ) -> bool :
277+ assert (not self .cancellable and not self .cancelled )
278+ self .cancellable = cancellable
279+ self .resume_lock .release ()
280+ self .run_lock .acquire ()
281+ assert (self .cancellable or not self .cancelled )
282+ self .cancellable = False
283+ completed = not self .cancelled
284+ self .cancelled = False
285+ return completed
286+
258287 def suspend_until (self , ready_func , cancellable = False ) -> bool :
259- assert (not self .ready_func )
288+ assert (not self .pending () )
260289 if not DETERMINISTIC_PROFILE and ready_func ():
261290 return True
262291 self .ready_func = ready_func
263292 self .task .inst .store .pending .append (self )
264- assert (not self .cancellable and not self .cancelled )
293+ return self .suspend (cancellable )
294+
295+ def switch_to (self , cancellable , other : Thread ) -> bool :
296+ assert (other .suspended ())
297+ assert (not self .cancellable )
265298 self .cancellable = cancellable
266- self .resume_lock .release ()
299+ assert (self .resume_lock and not other .resume_lock )
300+ other .resume_lock = self .resume_lock
301+ self .resume_lock = None
302+ other .run_lock .release ()
267303 self .run_lock .acquire ()
268- assert (self .cancellable or not self .cancelled )
269304 self .cancellable = False
270305 completed = not self .cancelled
271306 self .cancelled = False
272307 return completed
273308
309+ def yield_to (self , cancellable , other : Thread ) -> bool :
310+ assert (other .suspended ())
311+ assert (not self .ready_func )
312+ self .ready_func = lambda : True
313+ self .task .inst .store .pending .append (self )
314+ return self .switch_to (cancellable , other )
315+
316+ def resume_later (self , other : Thread ):
317+ assert (other .suspended ())
318+ other .ready_func = lambda : True
319+ other .task .inst .store .pending .append (other )
320+
274321
275322### Lifting and Lowering Context
276323
@@ -447,22 +494,6 @@ def write(self, vs):
447494 assert (all (v == () for v in vs ))
448495 self .progress += len (vs )
449496
450- #### Context-Local Storage
451-
452- class ContextLocalStorage :
453- LENGTH = 1
454- array : list [int ]
455-
456- def __init__ (self ):
457- self .array = [0 ] * ContextLocalStorage .LENGTH
458-
459- def set (self , i , v ):
460- assert (types_match_values (['i32' ], [v ]))
461- self .array [i ] = v
462-
463- def get (self , i ):
464- return self .array [i ]
465-
466497#### Waitable State
467498
468499class EventCode (IntEnum ):
@@ -545,10 +576,8 @@ class State(Enum):
545576 ft : FuncType
546577 supertask : Optional [Task ]
547578 on_resolve : OnResolve
548- thread : Optional [Thread ]
579+ threads : list [Thread ]
549580 num_borrows : int
550- waiting_for_callback : bool
551- context : ContextLocalStorage
552581
553582 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
554583 self .state = Task .State .INITIAL
@@ -557,10 +586,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
557586 self .ft = ft
558587 self .supertask = supertask
559588 self .on_resolve = on_resolve
560- self .thread = None
589+ self .threads = []
561590 self .num_borrows = 0
562- self .waiting_for_callback = False
563- self .context = ContextLocalStorage ()
564591
565592 def trap_if_on_the_stack (self , inst ):
566593 c = self .supertask
@@ -572,7 +599,7 @@ def needs_exclusive(self):
572599 return self .opts .sync or self .opts .callback
573600
574601 def enter (self , thread ):
575- assert (thread is self .thread and thread .task is self )
602+ assert (thread in self .threads and thread .task is self )
576603 def has_backpressure ():
577604 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
578605 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -589,28 +616,31 @@ def has_backpressure():
589616
590617 def request_cancellation (self ):
591618 assert (self .state == Task .State .INITIAL )
592- if self .thread .cancellable and not (self .waiting_for_callback and self .inst .exclusive ):
593- self .state = Task .State .CANCEL_DELIVERED
594- self .thread .resume (cancel = True )
595- else :
596- self .state = Task .State .PENDING_CANCEL
619+ if not DETERMINISTIC_PROFILE :
620+ random .shuffle (self .threads )
621+ for thread in self .threads :
622+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
623+ self .state = Task .State .CANCEL_DELIVERED
624+ thread .resume (cancel = True )
625+ return
626+ self .state = Task .State .PENDING_CANCEL
597627
598628 def wait_until (self , ready_func , thread , cancellable , for_callback ) -> bool :
599- assert (thread is self .thread and thread .task is self )
629+ assert (thread in self .threads and thread .task is self )
600630 if cancellable and self .state == Task .State .PENDING_CANCEL :
601631 self .state = Task .State .CANCEL_DELIVERED
602632 return False
603633 if for_callback :
604634 assert (self .inst .exclusive )
605635 self .inst .exclusive = False
606- self .waiting_for_callback = True
636+ thread .waiting_for_callback = True
607637 def ready_and_uncontended ():
608638 return ready_func () and not (for_callback and self .inst .exclusive )
609639 completed = thread .suspend_until (ready_and_uncontended , cancellable )
610640 if for_callback :
611641 assert (not self .inst .exclusive )
612642 self .inst .exclusive = True
613- self .waiting_for_callback = False
643+ thread .waiting_for_callback = False
614644 return completed
615645
616646 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
@@ -653,20 +683,21 @@ def cancel(self):
653683 self .state = Task .State .RESOLVED
654684
655685 def exit (self ):
656- assert (self .thread is not None )
686+ assert (len ( self .threads ) > 0 )
657687 if self .needs_exclusive ():
658688 assert (self .inst .exclusive )
659689 self .inst .exclusive = False
660690
661691 def thread_start (self , thread ):
662- assert (self . thread is None and thread .task is self )
663- self .thread = thread
692+ assert (thread not in self . threads and thread .task is self )
693+ self .threads . append ( thread )
664694
665695 def thread_stop (self , thread ):
666- assert (thread is self .thread and thread .task is self )
667- self .thread = None
668- trap_if (self .state != Task .State .RESOLVED )
669- assert (self .num_borrows == 0 )
696+ assert (thread in self .threads and thread .task is self )
697+ self .threads .remove (thread )
698+ if len (self .threads ) == 0 :
699+ trap_if (self .state != Task .State .RESOLVED )
700+ assert (self .num_borrows == 0 )
670701
671702#### Subtask State
672703
@@ -1902,6 +1933,9 @@ def thread_func(thread):
19021933 if not task .enter (thread ):
19031934 return
19041935
1936+ assert (thread .index is None )
1937+ thread .index = thread .task .inst .table .add (thread )
1938+
19051939 cx = LiftLowerContext (opts , inst , task )
19061940 args = on_start ()
19071941 flat_args = lower_flat_values (cx , MAX_FLAT_PARAMS , args , ft .param_types ())
@@ -2082,25 +2116,91 @@ def canon_resource_rep(rt, thread, i):
20822116 trap_if (h .rt is not rt )
20832117 return [h .rep ]
20842118
2119+ ### 🧵 `canon thread.index`
2120+
2121+ def canon_thread_index (shared , thread ):
2122+ assert (not shared )
2123+ assert (thread .index is not None )
2124+ return [thread .index ]
2125+
2126+ ### 🧵 `canon thread.new`
2127+
2128+ def canon_thread_new (ft , ftbl , thread , i , c ):
2129+ task = thread .task
2130+ trap_if (not task .inst .may_leave )
2131+ f = task .inst .ftbl .get (i )
2132+ trap_if (f .type != ft )
2133+ thread_func = partial (f , c )
2134+ new_thread = Thread (task , thread_func )
2135+ assert (new_thread .suspended ())
2136+ new_thread .index = task .inst .table .add (thread )
2137+ return [new_thread .index ]
2138+
2139+ ### 🧵 `canon thread.resume-later`
2140+
2141+ def canon_thread_resume_later (thread , i ):
2142+ trap_if (not thread .task .inst .may_leave )
2143+ other_thread = thread .task .inst .table .get (i )
2144+ trap_if (not isinstance (other_thread , Thread ))
2145+ trap_if (not other_thread .suspended ())
2146+ thread .resume_later (other_thread )
2147+ return []
2148+
2149+ ### 🧵 `canon thread.switch-to`
2150+
2151+ def canon_thread_switch_to (thread , cancellable , i ):
2152+ trap_if (not thread .task .inst .may_leave )
2153+ other_thread = thread .task .inst .table .get (i )
2154+ trap_if (not isinstance (other_thread , Thread ))
2155+ trap_if (not other_thread .suspended ())
2156+ if not thread .switch_to (cancellable , other_thread ):
2157+ assert (cancellable )
2158+ return [0 ]
2159+ else :
2160+ return [1 ]
2161+
2162+ ### 🧵 `canon thread.yield-to`
2163+
2164+ def canon_thread_yield_to (thread , cancellable , i ):
2165+ trap_if (not thread .task .inst .may_leave )
2166+ other_thread = thread .task .inst .table .get (i )
2167+ trap_if (not isinstance (other_thread , Thread ))
2168+ trap_if (not other_thread .suspended ())
2169+ if not other_thread .yield_to (cancellable , other_thread ):
2170+ assert (cancellable )
2171+ return [0 ]
2172+ else :
2173+ return [1 ]
2174+
2175+ ### 🧵 `canon thread.suspend`
2176+
2177+ def canon_thread_suspend (thread , cancellable ):
2178+ trap_if (not thread .task .inst .may_leave )
2179+ if not thread .suspend (cancellable ):
2180+ assert (cancellable )
2181+ return [0 ]
2182+ else :
2183+ return [1 ]
2184+
20852185### 🔀 `canon context.get`
20862186
20872187def canon_context_get (t , i , thread ):
20882188 assert (t == 'i32' )
2089- assert (i < ContextLocalStorage . LENGTH )
2090- return [thread .task . context . get ( i ) ]
2189+ assert (i < Thread . CONTEXT_LENGTH )
2190+ return [thread .context [ i ] ]
20912191
20922192### 🔀 `canon context.set`
20932193
20942194def canon_context_set (t , i , thread , v ):
20952195 assert (t == 'i32' )
2096- assert (i < ContextLocalStorage . LENGTH )
2097- thread .task . context . set ( i , v )
2196+ assert (i < Thread . CONTEXT_LENGTH )
2197+ thread .context [ i ] = v
20982198 return []
20992199
21002200### 🔀 `canon backpressure.set`
21012201
21022202def canon_backpressure_set (thread , flat_args ):
2103- trap_if (thread .task .opts .sync )
2203+ # TODO: remove trap_if(thread.task.opts.sync)
21042204 assert (len (flat_args ) == 1 )
21052205 thread .task .inst .backpressure = bool (flat_args [0 ])
21062206 return []
0 commit comments