@@ -199,7 +199,7 @@ def tick(self):
199199 random .shuffle (self .pending )
200200 for thread in self .pending :
201201 if thread .ready ():
202- thread .resume ()
202+ thread .resume (Cancelled . FALSE )
203203 return
204204
205205FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -362,28 +362,74 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
362362 self .dtor_async = dtor_async
363363 self .dtor_callback = dtor_callback
364364
365+ #### Stack Switching Support
366+
367+ class Continuation :
368+ lock : threading .Lock
369+ handler : Handler
370+ arg : any
371+
372+ class Handler :
373+ tls = threading .local ()
374+ lock : threading .Lock
375+ result : Optional [tuple [Continuation , any ]]
376+
377+ def cont_new (f : Callable [[], None ]) -> Continuation :
378+ cont = Continuation ()
379+ cont .lock = threading .Lock ()
380+ cont .lock .acquire ()
381+ def wrapper ():
382+ cont .lock .acquire ()
383+ Handler .tls .value = cont .handler
384+ f (cont .arg )
385+ handler = Handler .tls .value
386+ del Handler .tls .value
387+ handler .result = None
388+ handler .lock .release ()
389+ threading .Thread (target = wrapper ).start ()
390+ return cont
391+
392+ def resume (cont : Continuation , v : any ) -> Optional [tuple [Continuation , any ]]:
393+ handler = Handler ()
394+ handler .lock = threading .Lock ()
395+ handler .lock .acquire ()
396+ cont .handler = handler
397+ cont .arg = v
398+ cont .lock .release ()
399+ handler .lock .acquire ()
400+ return handler .result
401+
402+ def suspend (v : any ) -> any :
403+ handler = Handler .tls .value
404+ del Handler .tls .value
405+ cont = Continuation ()
406+ cont .lock = threading .Lock ()
407+ cont .lock .acquire ()
408+ handler .result = (cont , v )
409+ handler .lock .release ()
410+ cont .lock .acquire ()
411+ Handler .tls .value = cont .handler
412+ return cont .arg
413+
365414#### Thread State
366415
367- class SuspendResult (IntEnum ):
368- NOT_CANCELLED = 0
369- CANCELLED = 1
416+ class Cancelled (IntEnum ):
417+ FALSE = 0
418+ TRUE = 1
370419
371420class Thread :
372- task : Task
373- fiber : threading .Thread
374- fiber_lock : threading .Lock
375- parent_lock : Optional [threading .Lock ]
421+ cont : Optional [Continuation ]
376422 ready_func : Optional [Callable [[], bool ]]
377- cancellable : bool
378- suspend_result : Optional [SuspendResult ]
379- in_event_loop : bool
423+ task : Task
380424 index : Optional [int ]
381425 context : list [int ]
426+ cancellable : bool
427+ in_event_loop : bool
382428
383429 CONTEXT_LENGTH = 2
384430
385431 def running (self ):
386- return self .parent_lock is not None
432+ return self .cont is None
387433
388434 def suspended (self ):
389435 return not self .running () and self .ready_func is None
@@ -396,59 +442,49 @@ def ready(self):
396442 return self .ready_func ()
397443
398444 def __init__ (self , task , thread_func ):
399- self .task = task
400- self .fiber_lock = threading .Lock ()
401- self .fiber_lock .acquire ()
402- self .parent_lock = None
403- self .ready_func = None
404- self .cancellable = False
405- self .suspend_result = None
406- self .in_event_loop = False
407- self .index = None
408- self .context = [0 ] * Thread .CONTEXT_LENGTH
409- def fiber_func ():
410- self .fiber_lock .acquire ()
411- assert (self .running () and self .suspend_result == SuspendResult .NOT_CANCELLED )
412- self .suspend_result = None
445+ def wrapper (cancelled ):
446+ assert (self .running () and not cancelled )
413447 thread_func (self )
414- assert (self .running ())
415448 self .task .thread_stop (self )
416449 if self .index is not None :
417450 self .task .inst .threads .remove (self .index )
418- self .parent_lock .release ()
419- self .fiber = threading .Thread (target = fiber_func )
420- self .fiber .start ()
421- self .task .thread_start (self )
451+ self .cont = cont_new (wrapper )
452+ self .ready_func = None
453+ self .task = task
454+ self .index = None
455+ self .context = [0 ] * Thread .CONTEXT_LENGTH
456+ self .cancellable = False
457+ self .in_event_loop = False
422458 assert (self .suspended ())
459+ self .task .thread_start (self )
423460
424- def resume (self , suspend_result = SuspendResult . NOT_CANCELLED ):
425- assert (not self .running () and self . suspend_result is None )
461+ def resume (self , cancelled ):
462+ assert (not self .running ())
426463 if self .ready_func :
427- assert (suspend_result == SuspendResult . CANCELLED or self .ready_func ())
464+ assert (cancelled or self .ready_func ())
428465 self .ready_func = None
429466 self .task .inst .store .pending .remove (self )
430- assert (self .cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
431- self .suspend_result = suspend_result
432- self .parent_lock = threading .Lock ()
433- self .parent_lock .acquire ()
434- self .fiber_lock .release ()
435- self .parent_lock .acquire ()
436- self .parent_lock = None
437- assert (not self .running ())
467+ assert (self .cancellable or not cancelled )
468+ thread = self
469+ while True :
470+ assert (not thread .running ())
471+ cont = thread .cont
472+ thread .cont = None
473+ if not (resume_result := resume (cont , cancelled )):
474+ return
475+ thread .cont ,switch_to_thread = resume_result
476+ if switch_to_thread is None :
477+ return
478+ thread = switch_to_thread
479+ cancelled = Cancelled .FALSE
438480
439- def suspend (self , cancellable ) -> SuspendResult :
440- assert (self .task .may_block ())
441- assert (self .running () and not self .cancellable and self .suspend_result is None )
481+ def suspend (self , cancellable ) -> Cancelled :
482+ assert (self .running () and self .task .may_block ())
442483 self .cancellable = cancellable
443- self .parent_lock .release ()
444- self .fiber_lock .acquire ()
484+ cancelled = suspend (None )
445485 assert (self .running ())
446- self .cancellable = False
447- suspend_result = self .suspend_result
448- self .suspend_result = None
449- assert (suspend_result is not None )
450- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
451- return suspend_result
486+ assert (cancellable or not cancelled )
487+ return cancelled
452488
453489 def resume_later (self ):
454490 assert (self .suspended ())
@@ -459,29 +495,18 @@ def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
459495 assert (self .task .may_block ())
460496 assert (self .running ())
461497 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
462- return SuspendResult . NOT_CANCELLED
498+ return Cancelled . FALSE
463499 self .ready_func = ready_func
464500 self .task .inst .store .pending .append (self )
465501 return self .suspend (cancellable )
466502
467503 def switch_to (self , cancellable , other : Thread ) -> SuspendResult :
468- assert (self .running () and not self .cancellable and self .suspend_result is None )
469- assert (other .suspended () and other .suspend_result is None )
504+ assert (self .running ())
470505 self .cancellable = cancellable
471- other .suspend_result = SuspendResult .NOT_CANCELLED
472- assert (self .parent_lock and not other .parent_lock )
473- other .parent_lock = self .parent_lock
474- self .parent_lock = None
475- assert (not self .running () and other .running ())
476- other .fiber_lock .release ()
477- self .fiber_lock .acquire ()
506+ cancelled = suspend (other )
478507 assert (self .running ())
479- self .cancellable = False
480- suspend_result = self .suspend_result
481- self .suspend_result = None
482- assert (suspend_result is not None )
483- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
484- return suspend_result
508+ assert (cancellable or not cancelled )
509+ return cancelled
485510
486511 def yield_to (self , cancellable , other : Thread ) -> SuspendResult :
487512 assert (not self .ready_func )
@@ -610,7 +635,7 @@ def has_backpressure():
610635 self .inst .num_waiting_to_enter += 1
611636 result = thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
612637 self .inst .num_waiting_to_enter -= 1
613- if result == SuspendResult . CANCELLED :
638+ if result == Cancelled . TRUE :
614639 self .cancel ()
615640 return False
616641 if self .needs_exclusive ():
@@ -632,7 +657,7 @@ def request_cancellation(self):
632657 for thread in self .threads :
633658 if thread .cancellable and not (thread .in_event_loop and self .inst .exclusive ):
634659 self .state = Task .State .CANCEL_DELIVERED
635- thread .resume (SuspendResult . CANCELLED )
660+ thread .resume (Cancelled . TRUE )
636661 return
637662 self .state = Task .State .PENDING_CANCEL
638663
@@ -645,25 +670,25 @@ def deliver_pending_cancel(self, cancellable) -> bool:
645670 def suspend (self , thread , cancellable ) -> SuspendResult :
646671 assert (thread in self .threads and thread .task is self )
647672 if self .deliver_pending_cancel (cancellable ):
648- return SuspendResult . CANCELLED
673+ return Cancelled . TRUE
649674 return thread .suspend (cancellable )
650675
651676 def suspend_until (self , ready_func , thread , cancellable ) -> SuspendResult :
652677 assert (thread in self .threads and thread .task is self )
653678 if self .deliver_pending_cancel (cancellable ):
654- return SuspendResult . CANCELLED
679+ return Cancelled . TRUE
655680 return thread .suspend_until (ready_func , cancellable )
656681
657682 def switch_to (self , thread , cancellable , other_thread ) -> SuspendResult :
658683 assert (thread in self .threads and thread .task is self )
659684 if self .deliver_pending_cancel (cancellable ):
660- return SuspendResult . CANCELLED
685+ return Cancelled . TRUE
661686 return thread .switch_to (cancellable , other_thread )
662687
663688 def yield_to (self , thread , cancellable , other_thread ) -> SuspendResult :
664689 assert (thread in self .threads and thread .task is self )
665690 if self .deliver_pending_cancel (cancellable ):
666- return SuspendResult . CANCELLED
691+ return Cancelled . TRUE
667692 return thread .yield_to (cancellable , other_thread )
668693
669694 def wait_until (self , ready_func , thread , wset , cancellable ) -> EventTuple :
@@ -672,19 +697,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
672697 def ready_and_has_event ():
673698 return ready_func () and wset .has_pending_event ()
674699 match self .suspend_until (ready_and_has_event , thread , cancellable ):
675- case SuspendResult . CANCELLED :
700+ case Cancelled . TRUE :
676701 event = (EventCode .TASK_CANCELLED , 0 , 0 )
677- case SuspendResult . NOT_CANCELLED :
702+ case Cancelled . FALSE :
678703 event = wset .get_pending_event ()
679704 wset .num_waiting -= 1
680705 return event
681706
682707 def yield_until (self , ready_func , thread , cancellable ) -> EventTuple :
683708 assert (thread in self .threads and thread .task is self )
684709 match self .suspend_until (ready_func , thread , cancellable ):
685- case SuspendResult . CANCELLED :
710+ case Cancelled . TRUE :
686711 return (EventCode .TASK_CANCELLED , 0 , 0 )
687- case SuspendResult . NOT_CANCELLED :
712+ case Cancelled . FALSE :
688713 return (EventCode .NONE , 0 , 0 )
689714
690715 def return_ (self , result ):
@@ -2038,7 +2063,7 @@ def thread_func(thread):
20382063 return
20392064
20402065 thread = Thread (task , thread_func )
2041- thread .resume ()
2066+ thread .resume (Cancelled . FALSE )
20422067 return task
20432068
20442069class CallbackCode (IntEnum ):
@@ -2531,16 +2556,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25312556 trap_if (not thread .task .inst .may_leave )
25322557 other_thread = thread .task .inst .threads .get (i )
25332558 trap_if (not other_thread .suspended ())
2534- suspend_result = thread .task .switch_to (thread , cancellable , other_thread )
2535- return [suspend_result ]
2559+ cancelled = thread .task .switch_to (thread , cancellable , other_thread )
2560+ return [cancelled ]
25362561
25372562### 🧵 `canon thread.suspend`
25382563
25392564def canon_thread_suspend (cancellable , thread ):
25402565 trap_if (not thread .task .inst .may_leave )
25412566 trap_if (not thread .task .may_block ())
2542- suspend_result = thread .task .suspend (thread , cancellable )
2543- return [suspend_result ]
2567+ cancelled = thread .task .suspend (thread , cancellable )
2568+ return [cancelled ]
25442569
25452570### 🧵 `canon thread.resume-later`
25462571
@@ -2557,21 +2582,21 @@ def canon_thread_yield_to(cancellable, thread, i):
25572582 trap_if (not thread .task .inst .may_leave )
25582583 other_thread = thread .task .inst .threads .get (i )
25592584 trap_if (not other_thread .suspended ())
2560- suspend_result = thread .task .yield_to (thread , cancellable , other_thread )
2561- return [suspend_result ]
2585+ cancelled = thread .task .yield_to (thread , cancellable , other_thread )
2586+ return [cancelled ]
25622587
25632588### 🧵 `canon thread.yield`
25642589
25652590def canon_thread_yield (cancellable , thread ):
25662591 trap_if (not thread .task .inst .may_leave )
25672592 if not thread .task .may_block ():
2568- return [SuspendResult . NOT_CANCELLED ]
2593+ return [Cancelled . FALSE ]
25692594 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
25702595 match event_code :
25712596 case EventCode .NONE :
2572- return [SuspendResult . NOT_CANCELLED ]
2597+ return [Cancelled . FALSE ]
25732598 case EventCode .TASK_CANCELLED :
2574- return [SuspendResult . CANCELLED ]
2599+ return [Cancelled . TRUE ]
25752600
25762601### 📝 `canon error-context.new`
25772602
0 commit comments