@@ -199,7 +199,7 @@ def tick(self):
199199 random .shuffle (self .pending )
200200 for thread in self .pending :
201201 if thread .ready ():
202- thread .resume ()
202+ thread .resume (Cancelled . FALSE )
203203 return
204204
205205FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -389,27 +389,71 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
389389 self .dtor_async = dtor_async
390390 self .dtor_callback = dtor_callback
391391
392+ #### Stack Switching Support
393+
394+ class Continuation :
395+ lock : threading .Lock
396+ handler : Handler
397+ arg : any
398+
399+ class Handler :
400+ tls = threading .local ()
401+ lock : threading .Lock
402+ result : Optional [tuple [Continuation , any ]]
403+
404+ def cont_new (f : Callable [[any ], None ]) -> Continuation :
405+ cont = Continuation ()
406+ cont .lock = threading .Lock ()
407+ cont .lock .acquire ()
408+ def wrapper ():
409+ cont .lock .acquire ()
410+ Handler .tls .value = cont .handler
411+ f (cont .arg )
412+ handler = Handler .tls .value
413+ handler .result = None
414+ handler .lock .release ()
415+ threading .Thread (target = wrapper ).start ()
416+ return cont
417+
418+ def resume (cont : Continuation , v : any ) -> Optional [tuple [Continuation , any ]]:
419+ handler = Handler ()
420+ handler .lock = threading .Lock ()
421+ handler .lock .acquire ()
422+ cont .handler = handler
423+ cont .arg = v
424+ cont .lock .release ()
425+ handler .lock .acquire ()
426+ return handler .result
427+
428+ def suspend (v : any ) -> any :
429+ handler = Handler .tls .value
430+ cont = Continuation ()
431+ cont .lock = threading .Lock ()
432+ cont .lock .acquire ()
433+ handler .result = (cont , v )
434+ handler .lock .release ()
435+ cont .lock .acquire ()
436+ Handler .tls .value = cont .handler
437+ return cont .arg
438+
392439#### Thread State
393440
394- class SuspendResult (IntEnum ):
395- NOT_CANCELLED = 0
396- CANCELLED = 1
441+ class Cancelled (IntEnum ):
442+ FALSE = 0
443+ TRUE = 1
397444
398445class Thread :
399- task : Task
400- fiber : threading .Thread
401- fiber_lock : threading .Lock
402- parent_lock : Optional [threading .Lock ]
446+ cont : Optional [Continuation ]
403447 ready_func : Optional [Callable [[], bool ]]
404- cancellable : bool
405- suspend_result : Optional [SuspendResult ]
448+ task : Task
406449 index : Optional [int ]
407450 context : list [int ]
451+ cancellable : bool
408452
409453 CONTEXT_LENGTH = 2
410454
411455 def running (self ):
412- return self .parent_lock is not None
456+ return self .cont is None
413457
414458 def suspended (self ):
415459 return not self .running () and self .ready_func is None
@@ -422,94 +466,70 @@ def ready(self):
422466 return self .ready_func ()
423467
424468 def __init__ (self , task , thread_func ):
425- self .task = task
426- self .fiber_lock = threading .Lock ()
427- self .fiber_lock .acquire ()
428- self .parent_lock = None
429- self .ready_func = None
430- self .cancellable = False
431- self .suspend_result = None
432- self .index = None
433- self .context = [0 ] * Thread .CONTEXT_LENGTH
434- def fiber_func ():
435- self .fiber_lock .acquire ()
436- assert (self .running () and self .suspend_result == SuspendResult .NOT_CANCELLED )
437- self .suspend_result = None
469+ def wrapper (cancelled ):
470+ assert (self .running () and not cancelled )
438471 thread_func (self )
439- assert (self .running ())
440472 self .task .thread_stop (self )
441473 if self .index is not None :
442474 self .task .inst .threads .remove (self .index )
443- self .parent_lock .release ()
444- self .fiber = threading .Thread (target = fiber_func )
445- self .fiber .start ()
446- self .task .thread_start (self )
475+ self .cont = cont_new (wrapper )
476+ self .ready_func = None
477+ self .task = task
478+ self .index = None
479+ self .context = [0 ] * Thread .CONTEXT_LENGTH
480+ self .cancellable = False
447481 assert (self .suspended ())
482+ self .task .thread_start (self )
448483
449- def resume (self , suspend_result = SuspendResult .NOT_CANCELLED ):
450- assert (not self .running () and self .suspend_result is None )
484+ def resume (self , cancelled ):
485+ assert (self .cancellable or not cancelled )
486+ assert (not self .running ())
451487 if self .ready_func :
452- assert (suspend_result == SuspendResult . CANCELLED or self .ready_func ())
488+ assert (cancelled or self .ready_func ())
453489 self .ready_func = None
454490 self .task .inst .store .pending .remove (self )
455- assert (self .cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
456- self .suspend_result = suspend_result
457- self .parent_lock = threading .Lock ()
458- self .parent_lock .acquire ()
459- self .fiber_lock .release ()
460- self .parent_lock .acquire ()
461- self .parent_lock = None
462- assert (not self .running ())
463-
464- def suspend (self , cancellable ) -> SuspendResult :
465- assert (self .task .may_block ())
466- assert (self .running () and not self .cancellable and self .suspend_result is None )
491+ thread = self
492+ while True :
493+ cont = thread .cont
494+ thread .cont = None
495+ resume_result = resume (cont , cancelled )
496+ if resume_result is None :
497+ break
498+ (thread .cont , switch_to_thread ) = resume_result
499+ if switch_to_thread is None :
500+ break
501+ thread = switch_to_thread
502+ cancelled = Cancelled .FALSE
503+
504+ def suspend (self , cancellable ) -> Cancelled :
505+ assert (self .running () and self .task .may_block ())
467506 self .cancellable = cancellable
468- self .parent_lock .release ()
469- self .fiber_lock .acquire ()
470- assert (self .running ())
471- self .cancellable = False
472- suspend_result = self .suspend_result
473- self .suspend_result = None
474- assert (suspend_result is not None )
475- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
476- return suspend_result
507+ cancelled = suspend (None )
508+ assert (self .running () and (cancellable or not cancelled ))
509+ return cancelled
477510
478511 def resume_later (self ):
479512 assert (self .suspended ())
480513 self .ready_func = lambda : True
481514 self .task .inst .store .pending .append (self )
482515
483- def suspend_until (self , ready_func , cancellable = False ) -> SuspendResult :
484- assert (self .task .may_block ())
485- assert (self .running ())
516+ def suspend_until (self , ready_func , cancellable = False ) -> Cancelled :
517+ assert (self .running () and self .task .may_block ())
486518 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
487- return SuspendResult . NOT_CANCELLED
519+ return Cancelled . FALSE
488520 self .ready_func = ready_func
489521 self .task .inst .store .pending .append (self )
490522 return self .suspend (cancellable )
491523
492- def switch_to (self , cancellable , other : Thread ) -> SuspendResult :
493- assert (self .running () and not self .cancellable and self .suspend_result is None )
494- assert (other .suspended () and other .suspend_result is None )
524+ def switch_to (self , cancellable , other : Thread ) -> Cancelled :
525+ assert (self .running ())
495526 self .cancellable = cancellable
496- other .suspend_result = SuspendResult .NOT_CANCELLED
497- assert (self .parent_lock and not other .parent_lock )
498- other .parent_lock = self .parent_lock
499- self .parent_lock = None
500- assert (not self .running () and other .running ())
501- other .fiber_lock .release ()
502- self .fiber_lock .acquire ()
527+ cancelled = suspend (other )
528+ assert (self .running () and (cancellable or not cancelled ))
529+ return cancelled
530+
531+ def yield_to (self , cancellable , other : Thread ) -> Cancelled :
503532 assert (self .running ())
504- self .cancellable = False
505- suspend_result = self .suspend_result
506- self .suspend_result = None
507- assert (suspend_result is not None )
508- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
509- return suspend_result
510-
511- def yield_to (self , cancellable , other : Thread ) -> SuspendResult :
512- assert (not self .ready_func )
513533 self .ready_func = lambda : True
514534 self .task .inst .store .pending .append (self )
515535 return self .switch_to (cancellable , other )
@@ -637,7 +657,7 @@ def has_backpressure():
637657 self .inst .num_waiting_to_enter += 1
638658 result = thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
639659 self .inst .num_waiting_to_enter -= 1
640- if result == SuspendResult . CANCELLED :
660+ if result == Cancelled . TRUE :
641661 self .cancel ()
642662 return False
643663 self .state = Task .State .UNRESOLVED
@@ -658,15 +678,15 @@ def request_cancellation(self):
658678 if self .state == Task .State .BACKPRESSURE :
659679 assert (len (self .threads ) == 1 )
660680 self .state = Task .State .CANCEL_DELIVERED
661- self .threads [0 ].resume (SuspendResult . CANCELLED )
681+ self .threads [0 ].resume (Cancelled . TRUE )
662682 return
663683 assert (self .state == Task .State .UNRESOLVED )
664684 if not self .needs_exclusive () or not self .inst .exclusive or self .inst .exclusive is self :
665685 random .shuffle (self .threads )
666686 for thread in self .threads :
667687 if thread .cancellable :
668688 self .state = Task .State .CANCEL_DELIVERED
669- thread .resume (SuspendResult . CANCELLED )
689+ thread .resume (Cancelled . TRUE )
670690 return
671691 self .state = Task .State .PENDING_CANCEL
672692
@@ -676,28 +696,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
676696 return True
677697 return False
678698
679- def suspend (self , thread , cancellable ) -> SuspendResult :
699+ def suspend (self , thread , cancellable ) -> Cancelled :
680700 assert (thread in self .threads and thread .task is self )
681701 if self .deliver_pending_cancel (cancellable ):
682- return SuspendResult . CANCELLED
702+ return Cancelled . TRUE
683703 return thread .suspend (cancellable )
684704
685- def suspend_until (self , ready_func , thread , cancellable ) -> SuspendResult :
705+ def suspend_until (self , ready_func , thread , cancellable ) -> Cancelled :
686706 assert (thread in self .threads and thread .task is self )
687707 if self .deliver_pending_cancel (cancellable ):
688- return SuspendResult . CANCELLED
708+ return Cancelled . TRUE
689709 return thread .suspend_until (ready_func , cancellable )
690710
691- def switch_to (self , thread , cancellable , other_thread ) -> SuspendResult :
711+ def switch_to (self , thread , cancellable , other_thread ) -> Cancelled :
692712 assert (thread in self .threads and thread .task is self )
693713 if self .deliver_pending_cancel (cancellable ):
694- return SuspendResult . CANCELLED
714+ return Cancelled . TRUE
695715 return thread .switch_to (cancellable , other_thread )
696716
697- def yield_to (self , thread , cancellable , other_thread ) -> SuspendResult :
717+ def yield_to (self , thread , cancellable , other_thread ) -> Cancelled :
698718 assert (thread in self .threads and thread .task is self )
699719 if self .deliver_pending_cancel (cancellable ):
700- return SuspendResult . CANCELLED
720+ return Cancelled . TRUE
701721 return thread .yield_to (cancellable , other_thread )
702722
703723 def wait_until (self , ready_func , thread , wset , cancellable ) -> EventTuple :
@@ -706,19 +726,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
706726 def ready_and_has_event ():
707727 return ready_func () and wset .has_pending_event ()
708728 match self .suspend_until (ready_and_has_event , thread , cancellable ):
709- case SuspendResult . CANCELLED :
729+ case Cancelled . TRUE :
710730 event = (EventCode .TASK_CANCELLED , 0 , 0 )
711- case SuspendResult . NOT_CANCELLED :
731+ case Cancelled . FALSE :
712732 event = wset .get_pending_event ()
713733 wset .num_waiting -= 1
714734 return event
715735
716736 def yield_until (self , ready_func , thread , cancellable ) -> EventTuple :
717737 assert (thread in self .threads and thread .task is self )
718738 match self .suspend_until (ready_func , thread , cancellable ):
719- case SuspendResult . CANCELLED :
739+ case Cancelled . TRUE :
720740 return (EventCode .TASK_CANCELLED , 0 , 0 )
721- case SuspendResult . NOT_CANCELLED :
741+ case Cancelled . FALSE :
722742 return (EventCode .NONE , 0 , 0 )
723743
724744 def return_ (self , result ):
@@ -2082,7 +2102,7 @@ def thread_func(thread):
20822102 return
20832103
20842104 thread = Thread (task , thread_func )
2085- thread .resume ()
2105+ thread .resume (Cancelled . FALSE )
20862106 return task
20872107
20882108class CallbackCode (IntEnum ):
@@ -2578,16 +2598,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25782598 trap_if (not thread .task .inst .may_leave )
25792599 other_thread = thread .task .inst .threads .get (i )
25802600 trap_if (not other_thread .suspended ())
2581- suspend_result = thread .task .switch_to (thread , cancellable , other_thread )
2582- return [suspend_result ]
2601+ cancelled = thread .task .switch_to (thread , cancellable , other_thread )
2602+ return [cancelled ]
25832603
25842604### 🧵 `canon thread.suspend`
25852605
25862606def canon_thread_suspend (cancellable , thread ):
25872607 trap_if (not thread .task .inst .may_leave )
25882608 trap_if (not thread .task .may_block ())
2589- suspend_result = thread .task .suspend (thread , cancellable )
2590- return [suspend_result ]
2609+ cancelled = thread .task .suspend (thread , cancellable )
2610+ return [cancelled ]
25912611
25922612### 🧵 `canon thread.resume-later`
25932613
@@ -2604,21 +2624,21 @@ def canon_thread_yield_to(cancellable, thread, i):
26042624 trap_if (not thread .task .inst .may_leave )
26052625 other_thread = thread .task .inst .threads .get (i )
26062626 trap_if (not other_thread .suspended ())
2607- suspend_result = thread .task .yield_to (thread , cancellable , other_thread )
2608- return [suspend_result ]
2627+ cancelled = thread .task .yield_to (thread , cancellable , other_thread )
2628+ return [cancelled ]
26092629
26102630### 🧵 `canon thread.yield`
26112631
26122632def canon_thread_yield (cancellable , thread ):
26132633 trap_if (not thread .task .inst .may_leave )
26142634 if not thread .task .may_block ():
2615- return [SuspendResult . NOT_CANCELLED ]
2635+ return [Cancelled . FALSE ]
26162636 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
26172637 match event_code :
26182638 case EventCode .NONE :
2619- return [SuspendResult . NOT_CANCELLED ]
2639+ return [Cancelled . FALSE ]
26202640 case EventCode .TASK_CANCELLED :
2621- return [SuspendResult . CANCELLED ]
2641+ return [Cancelled . TRUE ]
26222642
26232643### 📝 `canon error-context.new`
26242644
0 commit comments