@@ -384,22 +384,6 @@ def write(self, vs):
384384 assert (all (v == () for v in vs ))
385385 self .progress += len (vs )
386386
387- #### Context-Local Storage
388-
389- class ContextLocalStorage :
390- LENGTH = 1
391- array : list [int ]
392-
393- def __init__ (self ):
394- self .array = [0 ] * ContextLocalStorage .LENGTH
395-
396- def set (self , i , v ):
397- assert (types_match_values (['i32' ], [v ]))
398- self .array [i ] = v
399-
400- def get (self , i ):
401- return self .array [i ]
402-
403387#### Thread State
404388
405389class Thread :
@@ -411,6 +395,10 @@ class Thread:
411395 cancellable : bool
412396 cancelled : bool
413397 waiting_for_callback : bool
398+ index : Optional [int ]
399+ context : list [int ]
400+
401+ CONTEXT_LENGTH = 2
414402
415403 def running (self ):
416404 return self .parent_lock is not None
@@ -434,12 +422,17 @@ def __init__(self, task, thread_func):
434422 self .cancellable = False
435423 self .cancelled = False
436424 self .waiting_for_callback = False
425+ self .index = None
426+ self .context = [0 ] * Thread .CONTEXT_LENGTH
427+
437428 def fiber_func ():
438429 self .fiber_lock .acquire ()
439430 assert (self .running ())
440431 thread_func (self )
441432 assert (self .running ())
442433 self .task .thread_stop (self )
434+ if self .index is not None :
435+ self .task .inst .table .remove (self .index )
443436 self .parent_lock .release ()
444437 self .fiber = threading .Thread (target = fiber_func )
445438 self .fiber .start ()
@@ -482,6 +475,33 @@ def suspend_until(self, ready_func, cancellable = False) -> bool:
482475 self .task .inst .store .pending .append (self )
483476 return self .suspend (cancellable )
484477
478+ def switch_to (self , cancellable , other : Thread ) -> bool :
479+ assert (self .running () and other .suspended ())
480+ assert (not self .cancellable )
481+ self .cancellable = cancellable
482+ assert (self .parent_lock and not other .parent_lock )
483+ other .parent_lock = self .parent_lock
484+ self .parent_lock = None
485+ assert (self .suspended () and other .running ())
486+ other .fiber_lock .release ()
487+ self .fiber_lock .acquire ()
488+ assert (self .running ())
489+ self .cancellable = False
490+ completed = not self .cancelled
491+ self .cancelled = False
492+ return completed
493+
494+ def yield_to (self , cancellable , other : Thread ) -> bool :
495+ assert (not self .ready_func )
496+ self .ready_func = lambda : True
497+ self .task .inst .store .pending .append (self )
498+ return self .switch_to (cancellable , other )
499+
500+ def resume_later (self , other : Thread ):
501+ assert (self .running () and other .suspended ())
502+ other .ready_func = lambda : True
503+ other .task .inst .store .pending .append (other )
504+
485505#### Waitable State
486506
487507class EventCode (IntEnum ):
@@ -564,8 +584,7 @@ class State(Enum):
564584 supertask : Optional [Task ]
565585 on_resolve : OnResolve
566586 num_borrows : int
567- thread : Optional [Thread ]
568- context : ContextLocalStorage
587+ threads : list [Thread ]
569588
570589 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
571590 self .state = Task .State .INITIAL
@@ -575,8 +594,7 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
575594 self .supertask = supertask
576595 self .on_resolve = on_resolve
577596 self .num_borrows = 0
578- self .thread = None
579- self .context = ContextLocalStorage ()
597+ self .threads = []
580598
581599 def trap_if_on_the_stack (self , inst ):
582600 c = self .supertask
@@ -588,7 +606,7 @@ def needs_exclusive(self):
588606 return self .opts .sync or self .opts .callback
589607
590608 def enter (self , thread ):
591- assert (thread is self .thread and thread .task is self )
609+ assert (thread in self .threads and thread .task is self )
592610 def has_backpressure ():
593611 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
594612 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -605,28 +623,30 @@ def has_backpressure():
605623
606624 def request_cancellation (self ):
607625 assert (self .state == Task .State .INITIAL )
608- if self .thread .cancellable and not (self .thread .waiting_for_callback and self .inst .exclusive ):
609- self .state = Task .State .CANCEL_DELIVERED
610- self .thread .resume (cancel = True )
611- else :
612- self .state = Task .State .PENDING_CANCEL
626+ random .shuffle (self .threads )
627+ for thread in self .threads :
628+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
629+ self .state = Task .State .CANCEL_DELIVERED
630+ thread .resume (cancel = True )
631+ return
632+ self .state = Task .State .PENDING_CANCEL
613633
614634 def wait_until (self , ready_func , thread , cancellable , for_callback ) -> bool :
615- assert (thread is self .thread and thread .task is self )
635+ assert (thread in self .threads and thread .task is self )
616636 if cancellable and self .state == Task .State .PENDING_CANCEL :
617637 self .state = Task .State .CANCEL_DELIVERED
618638 return False
619639 if for_callback :
620640 assert (self .inst .exclusive )
621641 self .inst .exclusive = False
622- self . thread .waiting_for_callback = True
642+ thread .waiting_for_callback = True
623643 def ready_and_uncontended ():
624644 return ready_func () and not (for_callback and self .inst .exclusive )
625645 completed = thread .suspend_until (ready_and_uncontended , cancellable )
626646 if for_callback :
627647 assert (not self .inst .exclusive )
628648 self .inst .exclusive = True
629- self . thread .waiting_for_callback = False
649+ thread .waiting_for_callback = False
630650 return completed
631651
632652 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
@@ -669,20 +689,21 @@ def cancel(self):
669689 self .state = Task .State .RESOLVED
670690
671691 def exit (self ):
672- assert (self .thread is not None )
692+ assert (len ( self .threads ) > 0 )
673693 if self .needs_exclusive ():
674694 assert (self .inst .exclusive )
675695 self .inst .exclusive = False
676696
677697 def thread_start (self , thread ):
678- assert (self . thread is None and thread .task is self )
679- self .thread = thread
698+ assert (thread not in self . threads and thread .task is self )
699+ self .threads . append ( thread )
680700
681701 def thread_stop (self , thread ):
682- assert (thread is self .thread and thread .task is self )
683- self .thread = None
684- trap_if (self .state != Task .State .RESOLVED )
685- assert (self .num_borrows == 0 )
702+ assert (thread in self .threads and thread .task is self )
703+ self .threads .remove (thread )
704+ if len (self .threads ) == 0 :
705+ trap_if (self .state != Task .State .RESOLVED )
706+ assert (self .num_borrows == 0 )
686707
687708#### Subtask State
688709
@@ -1918,6 +1939,9 @@ def thread_func(thread):
19181939 if not task .enter (thread ):
19191940 return
19201941
1942+ assert (thread .index is None )
1943+ thread .index = thread .task .inst .table .add (thread )
1944+
19211945 cx = LiftLowerContext (opts , inst , task )
19221946 args = on_start ()
19231947 flat_args = lower_flat_values (cx , MAX_FLAT_PARAMS , args , ft .param_types ())
@@ -2098,25 +2122,91 @@ def canon_resource_rep(rt, thread, i):
20982122 trap_if (h .rt is not rt )
20992123 return [h .rep ]
21002124
2125+ ### 🧵 `canon thread.index`
2126+
2127+ def canon_thread_index (shared , thread ):
2128+ assert (not shared )
2129+ assert (thread .index is not None )
2130+ return [thread .index ]
2131+
2132+ ### 🧵 `canon thread.new`
2133+
2134+ def canon_thread_new (ft , ftbl , thread , i , c ):
2135+ task = thread .task
2136+ trap_if (not task .inst .may_leave )
2137+ f = task .inst .ftbl .get (i )
2138+ trap_if (f .type != ft )
2139+ thread_func = partial (f , c )
2140+ new_thread = Thread (task , thread_func )
2141+ assert (new_thread .suspended ())
2142+ new_thread .index = task .inst .table .add (thread )
2143+ return [new_thread .index ]
2144+
2145+ ### 🧵 `canon thread.resume-later`
2146+
2147+ def canon_thread_resume_later (thread , i ):
2148+ trap_if (not thread .task .inst .may_leave )
2149+ other_thread = thread .task .inst .table .get (i )
2150+ trap_if (not isinstance (other_thread , Thread ))
2151+ trap_if (not other_thread .suspended ())
2152+ thread .resume_later (other_thread )
2153+ return []
2154+
2155+ ### 🧵 `canon thread.switch-to`
2156+
2157+ def canon_thread_switch_to (thread , cancellable , i ):
2158+ trap_if (not thread .task .inst .may_leave )
2159+ other_thread = thread .task .inst .table .get (i )
2160+ trap_if (not isinstance (other_thread , Thread ))
2161+ trap_if (not other_thread .suspended ())
2162+ if not thread .switch_to (cancellable , other_thread ):
2163+ assert (cancellable )
2164+ return [0 ]
2165+ else :
2166+ return [1 ]
2167+
2168+ ### 🧵 `canon thread.yield-to`
2169+
2170+ def canon_thread_yield_to (thread , cancellable , i ):
2171+ trap_if (not thread .task .inst .may_leave )
2172+ other_thread = thread .task .inst .table .get (i )
2173+ trap_if (not isinstance (other_thread , Thread ))
2174+ trap_if (not other_thread .suspended ())
2175+ if not other_thread .yield_to (cancellable , other_thread ):
2176+ assert (cancellable )
2177+ return [0 ]
2178+ else :
2179+ return [1 ]
2180+
2181+ ### 🧵 `canon thread.suspend`
2182+
2183+ def canon_thread_suspend (thread , cancellable ):
2184+ trap_if (not thread .task .inst .may_leave )
2185+ if not thread .suspend (cancellable ):
2186+ assert (cancellable )
2187+ return [0 ]
2188+ else :
2189+ return [1 ]
2190+
21012191### 🔀 `canon context.get`
21022192
21032193def canon_context_get (t , i , thread ):
21042194 assert (t == 'i32' )
2105- assert (i < ContextLocalStorage . LENGTH )
2106- return [thread .task . context . get ( i ) ]
2195+ assert (i < Thread . CONTEXT_LENGTH )
2196+ return [thread .context [ i ] ]
21072197
21082198### 🔀 `canon context.set`
21092199
21102200def canon_context_set (t , i , thread , v ):
21112201 assert (t == 'i32' )
2112- assert (i < ContextLocalStorage . LENGTH )
2113- thread .task . context . set ( i , v )
2202+ assert (i < Thread . CONTEXT_LENGTH )
2203+ thread .context [ i ] = v
21142204 return []
21152205
21162206### 🔀 `canon backpressure.set`
21172207
21182208def canon_backpressure_set (thread , flat_args ):
2119- trap_if (thread .task .opts .sync )
2209+ # TODO: remove trap_if(thread.task.opts.sync)
21202210 assert (len (flat_args ) == 1 )
21212211 thread .task .inst .backpressure = bool (flat_args [0 ])
21222212 return []
0 commit comments