@@ -384,22 +384,6 @@ def write(self, vs):
384384 assert (all (v == () for v in vs ))
385385 self .progress += len (vs )
386386
387- #### Context-Local Storage
388-
389- class ContextLocalStorage :
390- LENGTH = 1
391- array : list [int ]
392-
393- def __init__ (self ):
394- self .array = [0 ] * ContextLocalStorage .LENGTH
395-
396- def set (self , i , v ):
397- assert (types_match_values (['i32' ], [v ]))
398- self .array [i ] = v
399-
400- def get (self , i ):
401- return self .array [i ]
402-
403387#### Waitable State
404388
405389class EventCode (IntEnum ):
@@ -477,6 +461,10 @@ class Thread:
477461 cancellable : bool
478462 cancelled : bool
479463 waiting_for_callback : bool
464+ index : Optional [int ]
465+ context : list [int ]
466+
467+ CONTEXT_LENGTH = 2
480468
481469 def running (self ):
482470 return self .parent_lock is not None
@@ -500,12 +488,17 @@ def __init__(self, task, thread_func):
500488 self .cancellable = False
501489 self .cancelled = False
502490 self .waiting_for_callback = False
491+ self .index = None
492+ self .context = [0 ] * Thread .CONTEXT_LENGTH
493+
503494 def fiber_func ():
504495 self .fiber_lock .acquire ()
505496 assert (self .running ())
506497 thread_func (self )
507498 assert (self .running ())
508499 self .task .thread_stop (self )
500+ if self .index is not None :
501+ self .task .inst .table .remove (self .index )
509502 self .parent_lock .release ()
510503 self .fiber = threading .Thread (target = fiber_func )
511504 self .fiber .start ()
@@ -548,6 +541,34 @@ def suspend_until(self, ready_func, cancellable = False) -> bool:
548541 self .task .inst .store .pending .append (self )
549542 return self .suspend (cancellable )
550543
544+ def switch_to (self , cancellable , other : Thread ) -> bool :
545+ assert (self .running () and other .suspended ())
546+ assert (not self .cancellable )
547+ self .cancellable = cancellable
548+ assert (self .parent_lock and not other .parent_lock )
549+ other .parent_lock = self .parent_lock
550+ self .parent_lock = None
551+ assert (self .suspended () and other .running ())
552+ other .fiber_lock .release ()
553+ self .fiber_lock .acquire ()
554+ assert (self .running ())
555+ self .cancellable = False
556+ completed = not self .cancelled
557+ self .cancelled = False
558+ return completed
559+
560+ def yield_to (self , cancellable , other : Thread ) -> bool :
561+ assert (not self .ready_func )
562+ self .ready_func = lambda : True
563+ self .task .inst .store .pending .append (self )
564+ return self .switch_to (cancellable , other )
565+
566+ def resume_later (self , other : Thread ):
567+ assert (self .running () and other .suspended ())
568+ other .ready_func = lambda : True
569+ other .task .inst .store .pending .append (other )
570+
571+
551572#### Task State
552573
553574class Task (Call , Supertask ):
@@ -564,8 +585,7 @@ class State(Enum):
564585 supertask : Optional [Task ]
565586 on_resolve : OnResolve
566587 num_borrows : int
567- thread : Optional [Thread ]
568- context : ContextLocalStorage
588+ threads : list [Thread ]
569589
570590 def __init__ (self , opts , inst , ft , supertask , on_resolve ):
571591 self .state = Task .State .INITIAL
@@ -575,8 +595,7 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
575595 self .supertask = supertask
576596 self .on_resolve = on_resolve
577597 self .num_borrows = 0
578- self .thread = None
579- self .context = ContextLocalStorage ()
598+ self .threads = []
580599
581600 def trap_if_on_the_stack (self , inst ):
582601 c = self .supertask
@@ -588,7 +607,7 @@ def needs_exclusive(self):
588607 return self .opts .sync or self .opts .callback
589608
590609 def enter (self , thread ):
591- assert (thread is self .thread and thread .task is self )
610+ assert (thread in self .threads and thread .task is self )
592611 def has_backpressure ():
593612 return self .inst .backpressure or (self .needs_exclusive () and self .inst .exclusive )
594613 if has_backpressure () or self .inst .pending_tasks > 0 :
@@ -605,28 +624,30 @@ def has_backpressure():
605624
606625 def request_cancellation (self ):
607626 assert (self .state == Task .State .INITIAL )
608- if self .thread .cancellable and not (self .thread .waiting_for_callback and self .inst .exclusive ):
609- self .state = Task .State .CANCEL_DELIVERED
610- self .thread .resume (cancel = True )
611- else :
612- self .state = Task .State .PENDING_CANCEL
627+ random .shuffle (self .threads )
628+ for thread in self .threads :
629+ if thread .cancellable and not (thread .waiting_for_callback and self .inst .exclusive ):
630+ self .state = Task .State .CANCEL_DELIVERED
631+ thread .resume (cancel = True )
632+ return
633+ self .state = Task .State .PENDING_CANCEL
613634
614635 def wait_until (self , ready_func , thread , cancellable , for_callback ) -> bool :
615- assert (thread is self .thread and thread .task is self )
636+ assert (thread in self .threads and thread .task is self )
616637 if cancellable and self .state == Task .State .PENDING_CANCEL :
617638 self .state = Task .State .CANCEL_DELIVERED
618639 return False
619640 if for_callback :
620641 assert (self .inst .exclusive )
621642 self .inst .exclusive = False
622- self . thread .waiting_for_callback = True
643+ thread .waiting_for_callback = True
623644 def ready_and_uncontended ():
624645 return ready_func () and not (for_callback and self .inst .exclusive )
625646 completed = thread .suspend_until (ready_and_uncontended , cancellable )
626647 if for_callback :
627648 assert (not self .inst .exclusive )
628649 self .inst .exclusive = True
629- self . thread .waiting_for_callback = False
650+ thread .waiting_for_callback = False
630651 return completed
631652
632653 def yield_ (self , thread , cancellable , for_callback ) -> EventTuple :
@@ -669,20 +690,21 @@ def cancel(self):
669690 self .state = Task .State .RESOLVED
670691
671692 def exit (self ):
672- assert (self .thread is not None )
693+ assert (len ( self .threads ) > 0 )
673694 if self .needs_exclusive ():
674695 assert (self .inst .exclusive )
675696 self .inst .exclusive = False
676697
677698 def thread_start (self , thread ):
678- assert (self . thread is None and thread .task is self )
679- self .thread = thread
699+ assert (thread not in self . threads and thread .task is self )
700+ self .threads . append ( thread )
680701
681702 def thread_stop (self , thread ):
682- assert (thread is self .thread and thread .task is self )
683- self .thread = None
684- trap_if (self .state != Task .State .RESOLVED )
685- assert (self .num_borrows == 0 )
703+ assert (thread in self .threads and thread .task is self )
704+ self .threads .remove (thread )
705+ if len (self .threads ) == 0 :
706+ trap_if (self .state != Task .State .RESOLVED )
707+ assert (self .num_borrows == 0 )
686708
687709#### Subtask State
688710
@@ -1918,6 +1940,9 @@ def thread_func(thread):
19181940 if not task .enter (thread ):
19191941 return
19201942
1943+ assert (thread .index is None )
1944+ thread .index = thread .task .inst .table .add (thread )
1945+
19211946 cx = LiftLowerContext (opts , inst , task )
19221947 args = on_start ()
19231948 flat_args = lower_flat_values (cx , MAX_FLAT_PARAMS , args , ft .param_types ())
@@ -2098,25 +2123,91 @@ def canon_resource_rep(rt, thread, i):
20982123 trap_if (h .rt is not rt )
20992124 return [h .rep ]
21002125
2126+ ### 🧵 `canon thread.index`
2127+
2128+ def canon_thread_index (shared , thread ):
2129+ assert (not shared )
2130+ assert (thread .index is not None )
2131+ return [thread .index ]
2132+
2133+ ### 🧵 `canon thread.new`
2134+
2135+ def canon_thread_new (ft , ftbl , thread , i , c ):
2136+ task = thread .task
2137+ trap_if (not task .inst .may_leave )
2138+ f = task .inst .ftbl .get (i )
2139+ trap_if (f .type != ft )
2140+ thread_func = partial (f , c )
2141+ new_thread = Thread (task , thread_func )
2142+ assert (new_thread .suspended ())
2143+ new_thread .index = task .inst .table .add (thread )
2144+ return [new_thread .index ]
2145+
2146+ ### 🧵 `canon thread.resume-later`
2147+
2148+ def canon_thread_resume_later (thread , i ):
2149+ trap_if (not thread .task .inst .may_leave )
2150+ other_thread = thread .task .inst .table .get (i )
2151+ trap_if (not isinstance (other_thread , Thread ))
2152+ trap_if (not other_thread .suspended ())
2153+ thread .resume_later (other_thread )
2154+ return []
2155+
2156+ ### 🧵 `canon thread.switch-to`
2157+
2158+ def canon_thread_switch_to (thread , cancellable , i ):
2159+ trap_if (not thread .task .inst .may_leave )
2160+ other_thread = thread .task .inst .table .get (i )
2161+ trap_if (not isinstance (other_thread , Thread ))
2162+ trap_if (not other_thread .suspended ())
2163+ if not thread .switch_to (cancellable , other_thread ):
2164+ assert (cancellable )
2165+ return [0 ]
2166+ else :
2167+ return [1 ]
2168+
2169+ ### 🧵 `canon thread.yield-to`
2170+
2171+ def canon_thread_yield_to (thread , cancellable , i ):
2172+ trap_if (not thread .task .inst .may_leave )
2173+ other_thread = thread .task .inst .table .get (i )
2174+ trap_if (not isinstance (other_thread , Thread ))
2175+ trap_if (not other_thread .suspended ())
2176+ if not other_thread .yield_to (cancellable , other_thread ):
2177+ assert (cancellable )
2178+ return [0 ]
2179+ else :
2180+ return [1 ]
2181+
2182+ ### 🧵 `canon thread.suspend`
2183+
2184+ def canon_thread_suspend (thread , cancellable ):
2185+ trap_if (not thread .task .inst .may_leave )
2186+ if not thread .suspend (cancellable ):
2187+ assert (cancellable )
2188+ return [0 ]
2189+ else :
2190+ return [1 ]
2191+
21012192### 🔀 `canon context.get`
21022193
21032194def canon_context_get (t , i , thread ):
21042195 assert (t == 'i32' )
2105- assert (i < ContextLocalStorage . LENGTH )
2106- return [thread .task . context . get ( i ) ]
2196+ assert (i < Thread . CONTEXT_LENGTH )
2197+ return [thread .context [ i ] ]
21072198
21082199### 🔀 `canon context.set`
21092200
21102201def canon_context_set (t , i , thread , v ):
21112202 assert (t == 'i32' )
2112- assert (i < ContextLocalStorage . LENGTH )
2113- thread .task . context . set ( i , v )
2203+ assert (i < Thread . CONTEXT_LENGTH )
2204+ thread .context [ i ] = v
21142205 return []
21152206
21162207### 🔀 `canon backpressure.set`
21172208
21182209def canon_backpressure_set (thread , flat_args ):
2119- trap_if (thread .task .opts .sync )
2210+ # TODO: remove trap_if(thread.task.opts.sync)
21202211 assert (len (flat_args ) == 1 )
21212212 thread .task .inst .backpressure = bool (flat_args [0 ])
21222213 return []
0 commit comments