@@ -362,6 +362,52 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
362362 self .dtor_async = dtor_async
363363 self .dtor_callback = dtor_callback
364364
365+ #### Stack Switching Support
366+
367+ class Continuation :
368+ lock : threading .Lock
369+ handler : Handler
370+
371+ class Handler :
372+ tls = threading .local ()
373+ lock : threading .Lock
374+ result : Optional [tuple [Continuation , any ]]
375+
376+ def cont_new (f : Callable [[], None ]) -> Continuation :
377+ cont = Continuation ()
378+ cont .lock = threading .Lock ()
379+ cont .lock .acquire ()
380+ def wrapper ():
381+ cont .lock .acquire ()
382+ Handler .tls .value = cont .handler
383+ f ()
384+ handler = Handler .tls .value
385+ del Handler .tls .value
386+ handler .result = None
387+ handler .lock .release ()
388+ threading .Thread (target = wrapper ).start ()
389+ return cont
390+
391+ def resume (cont : Continuation ) -> Optional [tuple [Continuation , any ]]:
392+ handler = Handler ()
393+ handler .lock = threading .Lock ()
394+ handler .lock .acquire ()
395+ cont .handler = handler
396+ cont .lock .release ()
397+ handler .lock .acquire ()
398+ return handler .result
399+
400+ def suspend (v : any ):
401+ handler = Handler .tls .value
402+ del Handler .tls .value
403+ cont = Continuation ()
404+ cont .lock = threading .Lock ()
405+ cont .lock .acquire ()
406+ handler .result = (cont , v )
407+ handler .lock .release ()
408+ cont .lock .acquire ()
409+ Handler .tls .value = cont .handler
410+
365411#### Thread State
366412
367413class SuspendResult (IntEnum ):
@@ -370,9 +416,7 @@ class SuspendResult(IntEnum):
370416
371417class Thread :
372418 task : Task
373- fiber : threading .Thread
374- fiber_lock : threading .Lock
375- parent_lock : Optional [threading .Lock ]
419+ cont : Optional [Continuation ]
376420 ready_func : Optional [Callable [[], bool ]]
377421 cancellable : bool
378422 suspend_result : Optional [SuspendResult ]
@@ -383,7 +427,7 @@ class Thread:
383427 CONTEXT_LENGTH = 2
384428
385429 def running (self ):
386- return self .parent_lock is not None
430+ return self .cont is None
387431
388432 def suspended (self ):
389433 return not self .running () and self .ready_func is None
@@ -397,51 +441,53 @@ def ready(self):
397441
398442 def __init__ (self , task , thread_func ):
399443 self .task = task
400- self .fiber_lock = threading .Lock ()
401- self .fiber_lock .acquire ()
402- self .parent_lock = None
444+ self .cont = None
403445 self .ready_func = None
404446 self .cancellable = False
405447 self .suspend_result = None
406448 self .in_event_loop = False
407449 self .index = None
408450 self .context = [0 ] * Thread .CONTEXT_LENGTH
409- def fiber_func ():
410- self .fiber_lock .acquire ()
451+ def wrapper ():
411452 assert (self .running () and self .suspend_result == SuspendResult .NOT_CANCELLED )
412453 self .suspend_result = None
413454 thread_func (self )
414455 assert (self .running ())
415456 self .task .thread_stop (self )
416457 if self .index is not None :
417458 self .task .inst .threads .remove (self .index )
418- self .parent_lock .release ()
419- self .fiber = threading .Thread (target = fiber_func )
420- self .fiber .start ()
459+ self .index = None
460+ self .cont = cont_new (wrapper )
421461 self .task .thread_start (self )
422462 assert (self .suspended ())
423463
424464 def resume (self , suspend_result = SuspendResult .NOT_CANCELLED ):
425- assert (not self .running () and self . suspend_result is None )
465+ assert (not self .running ())
426466 if self .ready_func :
427467 assert (suspend_result == SuspendResult .CANCELLED or self .ready_func ())
428468 self .ready_func = None
429469 self .task .inst .store .pending .remove (self )
430470 assert (self .cancellable or suspend_result == SuspendResult .NOT_CANCELLED )
431- self .suspend_result = suspend_result
432- self .parent_lock = threading .Lock ()
433- self .parent_lock .acquire ()
434- self .fiber_lock .release ()
435- self .parent_lock .acquire ()
436- self .parent_lock = None
437- assert (not self .running ())
471+ thread = self
472+ while True :
473+ assert (not thread .running () and thread .suspend_result is None )
474+ cont = thread .cont
475+ thread .cont = None
476+ thread .suspend_result = suspend_result
477+ if not (resume_result := resume (cont )):
478+ assert (thread .index is None and thread not in self .task .threads )
479+ return
480+ thread .cont , switch_to_thread = resume_result
481+ if switch_to_thread is None :
482+ return
483+ thread = switch_to_thread
484+ suspend_result = SuspendResult .NOT_CANCELLED
438485
439486 def suspend (self , cancellable ) -> SuspendResult :
440487 assert (self .task .may_block ())
441488 assert (self .running () and not self .cancellable and self .suspend_result is None )
442489 self .cancellable = cancellable
443- self .parent_lock .release ()
444- self .fiber_lock .acquire ()
490+ suspend (None )
445491 assert (self .running ())
446492 self .cancellable = False
447493 suspend_result = self .suspend_result
@@ -465,16 +511,9 @@ def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
465511 return self .suspend (cancellable )
466512
467513 def switch_to (self , cancellable , other : Thread ) -> SuspendResult :
468- assert (self .running () and not self .cancellable and self .suspend_result is None )
469- assert (other .suspended () and other .suspend_result is None )
514+ assert (self .running () and not self .cancellable )
470515 self .cancellable = cancellable
471- other .suspend_result = SuspendResult .NOT_CANCELLED
472- assert (self .parent_lock and not other .parent_lock )
473- other .parent_lock = self .parent_lock
474- self .parent_lock = None
475- assert (not self .running () and other .running ())
476- other .fiber_lock .release ()
477- self .fiber_lock .acquire ()
516+ suspend (other )
478517 assert (self .running ())
479518 self .cancellable = False
480519 suspend_result = self .suspend_result
0 commit comments