@@ -199,7 +199,7 @@ def tick(self):
199199 random .shuffle (self .pending )
200200 for thread in self .pending :
201201 if thread .ready ():
202- thread .resume ()
202+ thread .resume (Cancelled . FALSE )
203203 return
204204
205205FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -389,27 +389,74 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
389389 self .dtor_async = dtor_async
390390 self .dtor_callback = dtor_callback
391391
392+ #### Stack Switching Support
393+
394+ class Continuation :
395+ lock : threading .Lock
396+ handler : Handler
397+ cancelled : Cancelled
398+
399+ class Handler :
400+ tls = threading .local ()
401+ lock : threading .Lock
402+ cont : Optional [Continuation ]
403+ switch_to : Optional [Thread ]
404+
405+ def cont_new (f : Callable [[Cancelled ], Optional [Thread ]]) -> Continuation :
406+ cont = Continuation ()
407+ cont .lock = threading .Lock ()
408+ cont .lock .acquire ()
409+ def wrapper ():
410+ cont .lock .acquire ()
411+ Handler .tls .value = cont .handler
412+ switch_to = f (cont .cancelled )
413+ handler = Handler .tls .value
414+ handler .cont = None
415+ handler .switch_to = switch_to
416+ handler .lock .release ()
417+ threading .Thread (target = wrapper ).start ()
418+ return cont
419+
420+ def resume (cont : Continuation , cancelled : Cancelled ) -> tuple [Optional [Continuation ], Optional [Thread ]]:
421+ handler = Handler ()
422+ handler .lock = threading .Lock ()
423+ handler .lock .acquire ()
424+ cont .handler = handler
425+ cont .cancelled = cancelled
426+ cont .lock .release ()
427+ handler .lock .acquire ()
428+ return (handler .cont , handler .switch_to )
429+
430+ def suspend (switch_to : Optional [Thread ]) -> Cancelled :
431+ cont = Continuation ()
432+ cont .lock = threading .Lock ()
433+ cont .lock .acquire ()
434+ handler = Handler .tls .value
435+ handler .cont = cont
436+ handler .switch_to = switch_to
437+ handler .lock .release ()
438+ cont .lock .acquire ()
439+ Handler .tls .value = cont .handler
440+ return cont .cancelled
441+
392442#### Thread State
393443
394- class SuspendResult (IntEnum ):
395- NOT_CANCELLED = 0
396- CANCELLED = 1
444+ class Cancelled (IntEnum ):
445+ FALSE = 0
446+ TRUE = 1
397447
398448class Thread :
399- task : Task
400- fiber : threading .Thread
401- fiber_lock : threading .Lock
402- parent_lock : Optional [threading .Lock ]
449+ cont : Optional [Continuation ]
403450 ready_func : Optional [Callable [[], bool ]]
404- cancellable : bool
405- suspend_result : Optional [SuspendResult ]
451+ task : Task
406452 index : Optional [int ]
407453 context : list [int ]
454+ cancellable : bool
408455
409456 CONTEXT_LENGTH = 2
410457
411458 def running (self ):
412- return self .parent_lock is not None
459+ return self .cont is None
413460
414461 def suspended (self ):
415462 return not self .running () and self .ready_func is None
@@ -422,94 +469,63 @@ def ready(self):
422469 return self .ready_func ()
423470
424471 def __init__ (self , task , thread_func ):
425- self .task = task
426- self .fiber_lock = threading .Lock ()
427- self .fiber_lock .acquire ()
428- self .parent_lock = None
429- self .ready_func = None
430- self .cancellable = False
431- self .suspend_result = None
432- self .index = None
433- self .context = [0 ] * Thread .CONTEXT_LENGTH
434- def fiber_func ():
435- self .fiber_lock .acquire ()
436- assert (self .running () and self .suspend_result == SuspendResult .NOT_CANCELLED )
437- self .suspend_result = None
472+ def wrapper (cancelled ):
473+ assert (self .running () and not cancelled )
438474 thread_func (self )
439- assert (self .running ())
440475 self .task .thread_stop (self )
441476 if self .index is not None :
442477 self .task .inst .threads .remove (self .index )
443- self .parent_lock .release ()
444- self .fiber = threading .Thread (target = fiber_func )
445- self .fiber .start ()
446- self .task .thread_start (self )
478+ self .cont = cont_new (wrapper )
479+ self .ready_func = None
480+ self .task = task
481+ self .index = None
482+ self .context = [0 ] * Thread .CONTEXT_LENGTH
483+ self .cancellable = False
447484 assert (self .suspended ())
485+ self .task .thread_start (self )
448486
449- def resume (self , suspend_result = SuspendResult . NOT_CANCELLED ):
450- assert (not self .running () and self .suspend_result is None )
487+ def resume (self , cancelled ):
488+ assert (not self .running () and ( self .cancellable or not cancelled ) )
451489 if self .ready_func :
452- assert (suspend_result == SuspendResult . CANCELLED or self .ready_func ())
490+ assert (cancelled or self .ready_func ())
453491 self .ready_func = None
454492 self .task .inst .store .pending .remove (self )
455- assert (self .cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
456- self .suspend_result = suspend_result
457- self .parent_lock = threading .Lock ()
458- self .parent_lock .acquire ()
459- self .fiber_lock .release ()
460- self .parent_lock .acquire ()
461- self .parent_lock = None
462- assert (not self .running ())
463-
464- def suspend (self , cancellable ) -> SuspendResult :
465- assert (self .task .may_block ())
466- assert (self .running () and not self .cancellable and self .suspend_result is None )
493+ thread = self
494+ while thread is not None :
495+ cont = thread .cont
496+ thread .cont = None
497+ thread .cont , thread = resume (cont , cancelled )
498+ cancelled = Cancelled .FALSE
499+
500+ def suspend (self , cancellable ) -> Cancelled :
501+ assert (self .running () and self .task .may_block ())
467502 self .cancellable = cancellable
468- self .parent_lock .release ()
469- self .fiber_lock .acquire ()
470- assert (self .running ())
471- self .cancellable = False
472- suspend_result = self .suspend_result
473- self .suspend_result = None
474- assert (suspend_result is not None )
475- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
476- return suspend_result
503+ cancelled = suspend (None )
504+ assert (self .running () and (cancellable or not cancelled ))
505+ return cancelled
477506
478507 def resume_later (self ):
479508 assert (self .suspended ())
480509 self .ready_func = lambda : True
481510 self .task .inst .store .pending .append (self )
482511
483- def suspend_until (self , ready_func , cancellable = False ) -> SuspendResult :
484- assert (self .task .may_block ())
485- assert (self .running ())
512+ def suspend_until (self , ready_func , cancellable = False ) -> Cancelled :
513+ assert (self .running () and self .task .may_block ())
486514 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
487- return SuspendResult . NOT_CANCELLED
515+ return Cancelled . FALSE
488516 self .ready_func = ready_func
489517 self .task .inst .store .pending .append (self )
490518 return self .suspend (cancellable )
491519
492- def switch_to (self , cancellable , other : Thread ) -> SuspendResult :
493- assert (self .running () and not self .cancellable and self .suspend_result is None )
494- assert (other .suspended () and other .suspend_result is None )
520+ def switch_to (self , cancellable , other : Thread ) -> Cancelled :
521+ assert (self .running ())
495522 self .cancellable = cancellable
496- other .suspend_result = SuspendResult .NOT_CANCELLED
497- assert (self .parent_lock and not other .parent_lock )
498- other .parent_lock = self .parent_lock
499- self .parent_lock = None
500- assert (not self .running () and other .running ())
501- other .fiber_lock .release ()
502- self .fiber_lock .acquire ()
523+ cancelled = suspend (other )
524+ assert (self .running () and (cancellable or not cancelled ))
525+ return cancelled
526+
527+ def yield_to (self , cancellable , other : Thread ) -> Cancelled :
503528 assert (self .running ())
504- self .cancellable = False
505- suspend_result = self .suspend_result
506- self .suspend_result = None
507- assert (suspend_result is not None )
508- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
509- return suspend_result
510-
511- def yield_to (self , cancellable , other : Thread ) -> SuspendResult :
512- assert (not self .ready_func )
513529 self .ready_func = lambda : True
514530 self .task .inst .store .pending .append (self )
515531 return self .switch_to (cancellable , other )
@@ -637,7 +653,7 @@ def has_backpressure():
637653 self .inst .num_waiting_to_enter += 1
638654 result = thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
639655 self .inst .num_waiting_to_enter -= 1
640- if result == SuspendResult . CANCELLED :
656+ if result == Cancelled . TRUE :
641657 self .cancel ()
642658 return False
643659 self .state = Task .State .UNRESOLVED
@@ -658,15 +674,15 @@ def request_cancellation(self):
658674 if self .state == Task .State .BACKPRESSURE :
659675 assert (len (self .threads ) == 1 )
660676 self .state = Task .State .CANCEL_DELIVERED
661- self .threads [0 ].resume (SuspendResult . CANCELLED )
677+ self .threads [0 ].resume (Cancelled . TRUE )
662678 return
663679 assert (self .state == Task .State .UNRESOLVED )
664680 if not self .needs_exclusive () or not self .inst .exclusive or self .inst .exclusive is self :
665681 random .shuffle (self .threads )
666682 for thread in self .threads :
667683 if thread .cancellable :
668684 self .state = Task .State .CANCEL_DELIVERED
669- thread .resume (SuspendResult . CANCELLED )
685+ thread .resume (Cancelled . TRUE )
670686 return
671687 self .state = Task .State .PENDING_CANCEL
672688
@@ -676,28 +692,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
676692 return True
677693 return False
678694
679- def suspend (self , thread , cancellable ) -> SuspendResult :
695+ def suspend (self , thread , cancellable ) -> Cancelled :
680696 assert (thread in self .threads and thread .task is self )
681697 if self .deliver_pending_cancel (cancellable ):
682- return SuspendResult . CANCELLED
698+ return Cancelled . TRUE
683699 return thread .suspend (cancellable )
684700
685- def suspend_until (self , ready_func , thread , cancellable ) -> SuspendResult :
701+ def suspend_until (self , ready_func , thread , cancellable ) -> Cancelled :
686702 assert (thread in self .threads and thread .task is self )
687703 if self .deliver_pending_cancel (cancellable ):
688- return SuspendResult . CANCELLED
704+ return Cancelled . TRUE
689705 return thread .suspend_until (ready_func , cancellable )
690706
691- def switch_to (self , thread , cancellable , other_thread ) -> SuspendResult :
707+ def switch_to (self , thread , cancellable , other_thread ) -> Cancelled :
692708 assert (thread in self .threads and thread .task is self )
693709 if self .deliver_pending_cancel (cancellable ):
694- return SuspendResult . CANCELLED
710+ return Cancelled . TRUE
695711 return thread .switch_to (cancellable , other_thread )
696712
697- def yield_to (self , thread , cancellable , other_thread ) -> SuspendResult :
713+ def yield_to (self , thread , cancellable , other_thread ) -> Cancelled :
698714 assert (thread in self .threads and thread .task is self )
699715 if self .deliver_pending_cancel (cancellable ):
700- return SuspendResult . CANCELLED
716+ return Cancelled . TRUE
701717 return thread .yield_to (cancellable , other_thread )
702718
703719 def wait_until (self , ready_func , thread , wset , cancellable ) -> EventTuple :
@@ -706,19 +722,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
706722 def ready_and_has_event ():
707723 return ready_func () and wset .has_pending_event ()
708724 match self .suspend_until (ready_and_has_event , thread , cancellable ):
709- case SuspendResult . CANCELLED :
725+ case Cancelled . TRUE :
710726 event = (EventCode .TASK_CANCELLED , 0 , 0 )
711- case SuspendResult . NOT_CANCELLED :
727+ case Cancelled . FALSE :
712728 event = wset .get_pending_event ()
713729 wset .num_waiting -= 1
714730 return event
715731
716732 def yield_until (self , ready_func , thread , cancellable ) -> EventTuple :
717733 assert (thread in self .threads and thread .task is self )
718734 match self .suspend_until (ready_func , thread , cancellable ):
719- case SuspendResult . CANCELLED :
735+ case Cancelled . TRUE :
720736 return (EventCode .TASK_CANCELLED , 0 , 0 )
721- case SuspendResult . NOT_CANCELLED :
737+ case Cancelled . FALSE :
722738 return (EventCode .NONE , 0 , 0 )
723739
724740 def return_ (self , result ):
@@ -2082,7 +2098,7 @@ def thread_func(thread):
20822098 return
20832099
20842100 thread = Thread (task , thread_func )
2085- thread .resume ()
2101+ thread .resume (Cancelled . FALSE )
20862102 return task
20872103
20882104class CallbackCode (IntEnum ):
@@ -2578,16 +2594,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25782594 trap_if (not thread .task .inst .may_leave )
25792595 other_thread = thread .task .inst .threads .get (i )
25802596 trap_if (not other_thread .suspended ())
2581- suspend_result = thread .task .switch_to (thread , cancellable , other_thread )
2582- return [suspend_result ]
2597+ cancelled = thread .task .switch_to (thread , cancellable , other_thread )
2598+ return [cancelled ]
25832599
25842600### 🧵 `canon thread.suspend`
25852601
25862602def canon_thread_suspend (cancellable , thread ):
25872603 trap_if (not thread .task .inst .may_leave )
25882604 trap_if (not thread .task .may_block ())
2589- suspend_result = thread .task .suspend (thread , cancellable )
2590- return [suspend_result ]
2605+ cancelled = thread .task .suspend (thread , cancellable )
2606+ return [cancelled ]
25912607
25922608### 🧵 `canon thread.resume-later`
25932609
@@ -2604,21 +2620,21 @@ def canon_thread_yield_to(cancellable, thread, i):
26042620 trap_if (not thread .task .inst .may_leave )
26052621 other_thread = thread .task .inst .threads .get (i )
26062622 trap_if (not other_thread .suspended ())
2607- suspend_result = thread .task .yield_to (thread , cancellable , other_thread )
2608- return [suspend_result ]
2623+ cancelled = thread .task .yield_to (thread , cancellable , other_thread )
2624+ return [cancelled ]
26092625
26102626### 🧵 `canon thread.yield`
26112627
26122628def canon_thread_yield (cancellable , thread ):
26132629 trap_if (not thread .task .inst .may_leave )
26142630 if not thread .task .may_block ():
2615- return [SuspendResult . NOT_CANCELLED ]
2631+ return [Cancelled . FALSE ]
26162632 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
26172633 match event_code :
26182634 case EventCode .NONE :
2619- return [SuspendResult . NOT_CANCELLED ]
2635+ return [Cancelled . FALSE ]
26202636 case EventCode .TASK_CANCELLED :
2621- return [SuspendResult . CANCELLED ]
2637+ return [Cancelled . TRUE ]
26222638
26232639### 📝 `canon error-context.new`
26242640
0 commit comments