@@ -199,7 +199,7 @@ def tick(self):
199199 random .shuffle (self .pending )
200200 for thread in self .pending :
201201 if thread .ready ():
202- thread .resume ()
202+ thread .resume (Cancelled . FALSE )
203203 return
204204
205205FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
@@ -362,27 +362,73 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
362362 self .dtor_async = dtor_async
363363 self .dtor_callback = dtor_callback
364364
365+ #### Stack Switching Support
366+
367+ class Continuation :
368+ lock : threading .Lock
369+ handler : Handler
370+ arg : any
371+
372+ class Handler :
373+ tls = threading .local ()
374+ lock : threading .Lock
375+ result : Optional [tuple [Continuation , any ]]
376+
377+ def cont_new (f : Callable [[any ], None ]) -> Continuation :
378+ cont = Continuation ()
379+ cont .lock = threading .Lock ()
380+ cont .lock .acquire ()
381+ def wrapper ():
382+ cont .lock .acquire ()
383+ Handler .tls .value = cont .handler
384+ f (cont .arg )
385+ handler = Handler .tls .value
386+ del Handler .tls .value
387+ handler .result = None
388+ handler .lock .release ()
389+ threading .Thread (target = wrapper ).start ()
390+ return cont
391+
392+ def resume (cont : Continuation , v : any ) -> Optional [tuple [Continuation , any ]]:
393+ handler = Handler ()
394+ handler .lock = threading .Lock ()
395+ handler .lock .acquire ()
396+ cont .handler = handler
397+ cont .arg = v
398+ cont .lock .release ()
399+ handler .lock .acquire ()
400+ return handler .result
401+
402+ def suspend (v : any ) -> any :
403+ handler = Handler .tls .value
404+ del Handler .tls .value
405+ cont = Continuation ()
406+ cont .lock = threading .Lock ()
407+ cont .lock .acquire ()
408+ handler .result = (cont , v )
409+ handler .lock .release ()
410+ cont .lock .acquire ()
411+ Handler .tls .value = cont .handler
412+ return cont .arg
413+
365414#### Thread State
366415
367- class SuspendResult (IntEnum ):
368- NOT_CANCELLED = 0
369- CANCELLED = 1
416+ class Cancelled (IntEnum ):
417+ FALSE = 0
418+ TRUE = 1
370419
371420class Thread :
372- task : Task
373- fiber : threading .Thread
374- fiber_lock : threading .Lock
375- parent_lock : Optional [threading .Lock ]
421+ cont : Optional [Continuation ]
376422 ready_func : Optional [Callable [[], bool ]]
377- cancellable : bool
378- suspend_result : Optional [SuspendResult ]
423+ task : Task
379424 index : Optional [int ]
380425 context : list [int ]
426+ cancellable : bool
381427
382428 CONTEXT_LENGTH = 2
383429
384430 def running (self ):
385- return self .parent_lock is not None
431+ return self .cont is None
386432
387433 def suspended (self ):
388434 return not self .running () and self .ready_func is None
@@ -395,94 +441,70 @@ def ready(self):
395441 return self .ready_func ()
396442
397443 def __init__ (self , task , thread_func ):
398- self .task = task
399- self .fiber_lock = threading .Lock ()
400- self .fiber_lock .acquire ()
401- self .parent_lock = None
402- self .ready_func = None
403- self .cancellable = False
404- self .suspend_result = None
405- self .index = None
406- self .context = [0 ] * Thread .CONTEXT_LENGTH
407- def fiber_func ():
408- self .fiber_lock .acquire ()
409- assert (self .running () and self .suspend_result == SuspendResult .NOT_CANCELLED )
410- self .suspend_result = None
444+ def wrapper (cancelled ):
445+ assert (self .running () and not cancelled )
411446 thread_func (self )
412- assert (self .running ())
413447 self .task .thread_stop (self )
414448 if self .index is not None :
415449 self .task .inst .threads .remove (self .index )
416- self .parent_lock .release ()
417- self .fiber = threading .Thread (target = fiber_func )
418- self .fiber .start ()
419- self .task .thread_start (self )
450+ self .cont = cont_new (wrapper )
451+ self .ready_func = None
452+ self .task = task
453+ self .index = None
454+ self .context = [0 ] * Thread .CONTEXT_LENGTH
455+ self .cancellable = False
420456 assert (self .suspended ())
457+ self .task .thread_start (self )
421458
422- def resume (self , suspend_result = SuspendResult . NOT_CANCELLED ):
423- assert (not self .running () and self . suspend_result is None )
459+ def resume (self , cancelled ):
460+ assert (not self .running ())
424461 if self .ready_func :
425- assert (suspend_result == SuspendResult . CANCELLED or self .ready_func ())
462+ assert (cancelled or self .ready_func ())
426463 self .ready_func = None
427464 self .task .inst .store .pending .remove (self )
428- assert (self .cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
429- self .suspend_result = suspend_result
430- self .parent_lock = threading .Lock ()
431- self .parent_lock .acquire ()
432- self .fiber_lock .release ()
433- self .parent_lock .acquire ()
434- self .parent_lock = None
435- assert (not self .running ())
436-
437- def suspend (self , cancellable ) -> SuspendResult :
438- assert (self .task .may_block ())
439- assert (self .running () and not self .cancellable and self .suspend_result is None )
465+ assert (self .cancellable or not cancelled )
466+ thread = self
467+ while True :
468+ cont = thread .cont
469+ thread .cont = None
470+ resume_result = resume (cont , cancelled )
471+ if resume_result is None :
472+ break
473+ (thread .cont , switch_to_thread ) = resume_result
474+ if switch_to_thread is None :
475+ break
476+ thread = switch_to_thread
477+ cancelled = Cancelled .FALSE
478+
479+ def suspend (self , cancellable ) -> Cancelled :
480+ assert (self .running () and self .task .may_block ())
440481 self .cancellable = cancellable
441- self .parent_lock .release ()
442- self .fiber_lock .acquire ()
443- assert (self .running ())
444- self .cancellable = False
445- suspend_result = self .suspend_result
446- self .suspend_result = None
447- assert (suspend_result is not None )
448- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
449- return suspend_result
482+ cancelled = suspend (None )
483+ assert (self .running () and (cancellable or not cancelled ))
484+ return cancelled
450485
451486 def resume_later (self ):
452487 assert (self .suspended ())
453488 self .ready_func = lambda : True
454489 self .task .inst .store .pending .append (self )
455490
456- def suspend_until (self , ready_func , cancellable = False ) -> SuspendResult :
457- assert (self .task .may_block ())
458- assert (self .running ())
491+ def suspend_until (self , ready_func , cancellable = False ) -> Cancelled :
492+ assert (self .running () and self .task .may_block ())
459493 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
460- return SuspendResult . NOT_CANCELLED
494+ return Cancelled . FALSE
461495 self .ready_func = ready_func
462496 self .task .inst .store .pending .append (self )
463497 return self .suspend (cancellable )
464498
465- def switch_to (self , cancellable , other : Thread ) -> SuspendResult :
466- assert (self .running () and not self .cancellable and self .suspend_result is None )
467- assert (other .suspended () and other .suspend_result is None )
499+ def switch_to (self , cancellable , other : Thread ) -> Cancelled :
500+ assert (self .running ())
468501 self .cancellable = cancellable
469- other .suspend_result = SuspendResult .NOT_CANCELLED
470- assert (self .parent_lock and not other .parent_lock )
471- other .parent_lock = self .parent_lock
472- self .parent_lock = None
473- assert (not self .running () and other .running ())
474- other .fiber_lock .release ()
475- self .fiber_lock .acquire ()
502+ cancelled = suspend (other )
503+ assert (self .running () and (cancellable or not cancelled ))
504+ return cancelled
505+
506+ def yield_to (self , cancellable , other : Thread ) -> Cancelled :
476507 assert (self .running ())
477- self .cancellable = False
478- suspend_result = self .suspend_result
479- self .suspend_result = None
480- assert (suspend_result is not None )
481- assert (cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
482- return suspend_result
483-
484- def yield_to (self , cancellable , other : Thread ) -> SuspendResult :
485- assert (not self .ready_func )
486508 self .ready_func = lambda : True
487509 self .task .inst .store .pending .append (self )
488510 return self .switch_to (cancellable , other )
@@ -610,7 +632,7 @@ def has_backpressure():
610632 self .inst .num_waiting_to_enter += 1
611633 result = thread .suspend_until (lambda : not has_backpressure (), cancellable = True )
612634 self .inst .num_waiting_to_enter -= 1
613- if result == SuspendResult . CANCELLED :
635+ if result == Cancelled . TRUE :
614636 self .cancel ()
615637 return False
616638 self .state = Task .State .UNRESOLVED
@@ -631,15 +653,15 @@ def request_cancellation(self):
631653 if self .state == Task .State .BACKPRESSURE :
632654 assert (len (self .threads ) == 1 )
633655 self .state = Task .State .CANCEL_DELIVERED
634- self .threads [0 ].resume (SuspendResult . CANCELLED )
656+ self .threads [0 ].resume (Cancelled . TRUE )
635657 return
636658 assert (self .state == Task .State .UNRESOLVED )
637659 if not self .needs_exclusive () or not self .inst .exclusive or self .inst .exclusive is self :
638660 random .shuffle (self .threads )
639661 for thread in self .threads :
640662 if thread .cancellable :
641663 self .state = Task .State .CANCEL_DELIVERED
642- thread .resume (SuspendResult . CANCELLED )
664+ thread .resume (Cancelled . TRUE )
643665 return
644666 self .state = Task .State .PENDING_CANCEL
645667
@@ -649,28 +671,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
649671 return True
650672 return False
651673
652- def suspend (self , thread , cancellable ) -> SuspendResult :
674+ def suspend (self , thread , cancellable ) -> Cancelled :
653675 assert (thread in self .threads and thread .task is self )
654676 if self .deliver_pending_cancel (cancellable ):
655- return SuspendResult . CANCELLED
677+ return Cancelled . TRUE
656678 return thread .suspend (cancellable )
657679
658- def suspend_until (self , ready_func , thread , cancellable ) -> SuspendResult :
680+ def suspend_until (self , ready_func , thread , cancellable ) -> Cancelled :
659681 assert (thread in self .threads and thread .task is self )
660682 if self .deliver_pending_cancel (cancellable ):
661- return SuspendResult . CANCELLED
683+ return Cancelled . TRUE
662684 return thread .suspend_until (ready_func , cancellable )
663685
664- def switch_to (self , thread , cancellable , other_thread ) -> SuspendResult :
686+ def switch_to (self , thread , cancellable , other_thread ) -> Cancelled :
665687 assert (thread in self .threads and thread .task is self )
666688 if self .deliver_pending_cancel (cancellable ):
667- return SuspendResult . CANCELLED
689+ return Cancelled . TRUE
668690 return thread .switch_to (cancellable , other_thread )
669691
670- def yield_to (self , thread , cancellable , other_thread ) -> SuspendResult :
692+ def yield_to (self , thread , cancellable , other_thread ) -> Cancelled :
671693 assert (thread in self .threads and thread .task is self )
672694 if self .deliver_pending_cancel (cancellable ):
673- return SuspendResult . CANCELLED
695+ return Cancelled . TRUE
674696 return thread .yield_to (cancellable , other_thread )
675697
676698 def wait_until (self , ready_func , thread , wset , cancellable ) -> EventTuple :
@@ -679,19 +701,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
679701 def ready_and_has_event ():
680702 return ready_func () and wset .has_pending_event ()
681703 match self .suspend_until (ready_and_has_event , thread , cancellable ):
682- case SuspendResult . CANCELLED :
704+ case Cancelled . TRUE :
683705 event = (EventCode .TASK_CANCELLED , 0 , 0 )
684- case SuspendResult . NOT_CANCELLED :
706+ case Cancelled . FALSE :
685707 event = wset .get_pending_event ()
686708 wset .num_waiting -= 1
687709 return event
688710
689711 def yield_until (self , ready_func , thread , cancellable ) -> EventTuple :
690712 assert (thread in self .threads and thread .task is self )
691713 match self .suspend_until (ready_func , thread , cancellable ):
692- case SuspendResult . CANCELLED :
714+ case Cancelled . TRUE :
693715 return (EventCode .TASK_CANCELLED , 0 , 0 )
694- case SuspendResult . NOT_CANCELLED :
716+ case Cancelled . FALSE :
695717 return (EventCode .NONE , 0 , 0 )
696718
697719 def return_ (self , result ):
@@ -2045,7 +2067,7 @@ def thread_func(thread):
20452067 return
20462068
20472069 thread = Thread (task , thread_func )
2048- thread .resume ()
2070+ thread .resume (Cancelled . FALSE )
20492071 return task
20502072
20512073class CallbackCode (IntEnum ):
@@ -2536,16 +2558,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25362558 trap_if (not thread .task .inst .may_leave )
25372559 other_thread = thread .task .inst .threads .get (i )
25382560 trap_if (not other_thread .suspended ())
2539- suspend_result = thread .task .switch_to (thread , cancellable , other_thread )
2540- return [suspend_result ]
2561+ cancelled = thread .task .switch_to (thread , cancellable , other_thread )
2562+ return [cancelled ]
25412563
25422564### 🧵 `canon thread.suspend`
25432565
25442566def canon_thread_suspend (cancellable , thread ):
25452567 trap_if (not thread .task .inst .may_leave )
25462568 trap_if (not thread .task .may_block ())
2547- suspend_result = thread .task .suspend (thread , cancellable )
2548- return [suspend_result ]
2569+ cancelled = thread .task .suspend (thread , cancellable )
2570+ return [cancelled ]
25492571
25502572### 🧵 `canon thread.resume-later`
25512573
@@ -2562,21 +2584,21 @@ def canon_thread_yield_to(cancellable, thread, i):
25622584 trap_if (not thread .task .inst .may_leave )
25632585 other_thread = thread .task .inst .threads .get (i )
25642586 trap_if (not other_thread .suspended ())
2565- suspend_result = thread .task .yield_to (thread , cancellable , other_thread )
2566- return [suspend_result ]
2587+ cancelled = thread .task .yield_to (thread , cancellable , other_thread )
2588+ return [cancelled ]
25672589
25682590### 🧵 `canon thread.yield`
25692591
25702592def canon_thread_yield (cancellable , thread ):
25712593 trap_if (not thread .task .inst .may_leave )
25722594 if not thread .task .may_block ():
2573- return [SuspendResult . NOT_CANCELLED ]
2595+ return [Cancelled . FALSE ]
25742596 event_code ,_ ,_ = thread .task .yield_until (lambda : True , thread , cancellable )
25752597 match event_code :
25762598 case EventCode .NONE :
2577- return [SuspendResult . NOT_CANCELLED ]
2599+ return [Cancelled . FALSE ]
25782600 case EventCode .TASK_CANCELLED :
2579- return [SuspendResult . CANCELLED ]
2601+ return [Cancelled . TRUE ]
25802602
25812603### 📝 `canon error-context.new`
25822604
0 commit comments