Skip to content

Commit a89d27a

Browse files
committed
Rebase CABI onto explicit stack-switching interface (no behavior change)
1 parent 4068dbc commit a89d27a

File tree

2 files changed

+136
-116
lines changed

2 files changed

+136
-116
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 121 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def tick(self):
199199
random.shuffle(self.pending)
200200
for thread in self.pending:
201201
if thread.ready():
202-
thread.resume()
202+
thread.resume(Cancelled.FALSE)
203203
return
204204

205205
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -389,27 +389,71 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
389389
self.dtor_async = dtor_async
390390
self.dtor_callback = dtor_callback
391391

392+
#### Stack Switching Support
393+
394+
class Continuation:
395+
lock: threading.Lock
396+
handler: Handler
397+
arg: any
398+
399+
class Handler:
400+
tls = threading.local()
401+
lock: threading.Lock
402+
result: Optional[tuple[Continuation, any]]
403+
404+
def cont_new(f: Callable[[any], None]) -> Continuation:
405+
cont = Continuation()
406+
cont.lock = threading.Lock()
407+
cont.lock.acquire()
408+
def wrapper():
409+
cont.lock.acquire()
410+
Handler.tls.value = cont.handler
411+
f(cont.arg)
412+
handler = Handler.tls.value
413+
handler.result = None
414+
handler.lock.release()
415+
threading.Thread(target = wrapper).start()
416+
return cont
417+
418+
def resume(cont: Continuation, v: any) -> Optional[tuple[Continuation, any]]:
419+
handler = Handler()
420+
handler.lock = threading.Lock()
421+
handler.lock.acquire()
422+
cont.handler = handler
423+
cont.arg = v
424+
cont.lock.release()
425+
handler.lock.acquire()
426+
return handler.result
427+
428+
def suspend(v: any) -> any:
429+
handler = Handler.tls.value
430+
cont = Continuation()
431+
cont.lock = threading.Lock()
432+
cont.lock.acquire()
433+
handler.result = (cont, v)
434+
handler.lock.release()
435+
cont.lock.acquire()
436+
Handler.tls.value = cont.handler
437+
return cont.arg
438+
392439
#### Thread State
393440

394-
class SuspendResult(IntEnum):
395-
NOT_CANCELLED = 0
396-
CANCELLED = 1
441+
class Cancelled(IntEnum):
442+
FALSE = 0
443+
TRUE = 1
397444

398445
class Thread:
399-
task: Task
400-
fiber: threading.Thread
401-
fiber_lock: threading.Lock
402-
parent_lock: Optional[threading.Lock]
446+
cont: Optional[Continuation]
403447
ready_func: Optional[Callable[[], bool]]
404-
cancellable: bool
405-
suspend_result: Optional[SuspendResult]
448+
task: Task
406449
index: Optional[int]
407450
context: list[int]
451+
cancellable: bool
408452

409453
CONTEXT_LENGTH = 2
410454

411455
def running(self):
412-
return self.parent_lock is not None
456+
return self.cont is None
413457

414458
def suspended(self):
415459
return not self.running() and self.ready_func is None
@@ -422,94 +466,70 @@ def ready(self):
422466
return self.ready_func()
423467

424468
def __init__(self, task, thread_func):
425-
self.task = task
426-
self.fiber_lock = threading.Lock()
427-
self.fiber_lock.acquire()
428-
self.parent_lock = None
429-
self.ready_func = None
430-
self.cancellable = False
431-
self.suspend_result = None
432-
self.index = None
433-
self.context = [0] * Thread.CONTEXT_LENGTH
434-
def fiber_func():
435-
self.fiber_lock.acquire()
436-
assert(self.running() and self.suspend_result == SuspendResult.NOT_CANCELLED)
437-
self.suspend_result = None
469+
def wrapper(cancelled):
470+
assert(self.running() and not cancelled)
438471
thread_func(self)
439-
assert(self.running())
440472
self.task.thread_stop(self)
441473
if self.index is not None:
442474
self.task.inst.threads.remove(self.index)
443-
self.parent_lock.release()
444-
self.fiber = threading.Thread(target = fiber_func)
445-
self.fiber.start()
446-
self.task.thread_start(self)
475+
self.cont = cont_new(wrapper)
476+
self.ready_func = None
477+
self.task = task
478+
self.index = None
479+
self.context = [0] * Thread.CONTEXT_LENGTH
480+
self.cancellable = False
447481
assert(self.suspended())
482+
self.task.thread_start(self)
448483

449-
def resume(self, suspend_result = SuspendResult.NOT_CANCELLED):
450-
assert(not self.running() and self.suspend_result is None)
484+
def resume(self, cancelled):
485+
assert(self.cancellable or not cancelled)
486+
assert(not self.running())
451487
if self.ready_func:
452-
assert(suspend_result == SuspendResult.CANCELLED or self.ready_func())
488+
assert(cancelled or self.ready_func())
453489
self.ready_func = None
454490
self.task.inst.store.pending.remove(self)
455-
assert(self.cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
456-
self.suspend_result = suspend_result
457-
self.parent_lock = threading.Lock()
458-
self.parent_lock.acquire()
459-
self.fiber_lock.release()
460-
self.parent_lock.acquire()
461-
self.parent_lock = None
462-
assert(not self.running())
463-
464-
def suspend(self, cancellable) -> SuspendResult:
465-
assert(self.task.may_block())
466-
assert(self.running() and not self.cancellable and self.suspend_result is None)
491+
thread = self
492+
while True:
493+
cont = thread.cont
494+
thread.cont = None
495+
resume_result = resume(cont, cancelled)
496+
if resume_result is None:
497+
break
498+
(thread.cont, switch_to_thread) = resume_result
499+
if switch_to_thread is None:
500+
break
501+
thread = switch_to_thread
502+
cancelled = Cancelled.FALSE
503+
504+
def suspend(self, cancellable) -> Cancelled:
505+
assert(self.running() and self.task.may_block())
467506
self.cancellable = cancellable
468-
self.parent_lock.release()
469-
self.fiber_lock.acquire()
470-
assert(self.running())
471-
self.cancellable = False
472-
suspend_result = self.suspend_result
473-
self.suspend_result = None
474-
assert(suspend_result is not None)
475-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
476-
return suspend_result
507+
cancelled = suspend(None)
508+
assert(self.running() and (cancellable or not cancelled))
509+
return cancelled
477510

478511
def resume_later(self):
479512
assert(self.suspended())
480513
self.ready_func = lambda: True
481514
self.task.inst.store.pending.append(self)
482515

483-
def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
484-
assert(self.task.may_block())
485-
assert(self.running())
516+
def suspend_until(self, ready_func, cancellable = False) -> Cancelled:
517+
assert(self.running() and self.task.may_block())
486518
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
487-
return SuspendResult.NOT_CANCELLED
519+
return Cancelled.FALSE
488520
self.ready_func = ready_func
489521
self.task.inst.store.pending.append(self)
490522
return self.suspend(cancellable)
491523

492-
def switch_to(self, cancellable, other: Thread) -> SuspendResult:
493-
assert(self.running() and not self.cancellable and self.suspend_result is None)
494-
assert(other.suspended() and other.suspend_result is None)
524+
def switch_to(self, cancellable, other: Thread) -> Cancelled:
525+
assert(self.running())
495526
self.cancellable = cancellable
496-
other.suspend_result = SuspendResult.NOT_CANCELLED
497-
assert(self.parent_lock and not other.parent_lock)
498-
other.parent_lock = self.parent_lock
499-
self.parent_lock = None
500-
assert(not self.running() and other.running())
501-
other.fiber_lock.release()
502-
self.fiber_lock.acquire()
527+
cancelled = suspend(other)
528+
assert(self.running() and (cancellable or not cancelled))
529+
return cancelled
530+
531+
def yield_to(self, cancellable, other: Thread) -> Cancelled:
503532
assert(self.running())
504-
self.cancellable = False
505-
suspend_result = self.suspend_result
506-
self.suspend_result = None
507-
assert(suspend_result is not None)
508-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
509-
return suspend_result
510-
511-
def yield_to(self, cancellable, other: Thread) -> SuspendResult:
512-
assert(not self.ready_func)
513533
self.ready_func = lambda: True
514534
self.task.inst.store.pending.append(self)
515535
return self.switch_to(cancellable, other)
@@ -637,7 +657,7 @@ def has_backpressure():
637657
self.inst.num_waiting_to_enter += 1
638658
result = thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
639659
self.inst.num_waiting_to_enter -= 1
640-
if result == SuspendResult.CANCELLED:
660+
if result == Cancelled.TRUE:
641661
self.cancel()
642662
return False
643663
self.state = Task.State.UNRESOLVED
@@ -658,15 +678,15 @@ def request_cancellation(self):
658678
if self.state == Task.State.BACKPRESSURE:
659679
assert(len(self.threads) == 1)
660680
self.state = Task.State.CANCEL_DELIVERED
661-
self.threads[0].resume(SuspendResult.CANCELLED)
681+
self.threads[0].resume(Cancelled.TRUE)
662682
return
663683
assert(self.state == Task.State.UNRESOLVED)
664684
if not self.needs_exclusive() or not self.inst.exclusive or self.inst.exclusive is self:
665685
random.shuffle(self.threads)
666686
for thread in self.threads:
667687
if thread.cancellable:
668688
self.state = Task.State.CANCEL_DELIVERED
669-
thread.resume(SuspendResult.CANCELLED)
689+
thread.resume(Cancelled.TRUE)
670690
return
671691
self.state = Task.State.PENDING_CANCEL
672692

@@ -676,28 +696,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
676696
return True
677697
return False
678698

679-
def suspend(self, thread, cancellable) -> SuspendResult:
699+
def suspend(self, thread, cancellable) -> Cancelled:
680700
assert(thread in self.threads and thread.task is self)
681701
if self.deliver_pending_cancel(cancellable):
682-
return SuspendResult.CANCELLED
702+
return Cancelled.TRUE
683703
return thread.suspend(cancellable)
684704

685-
def suspend_until(self, ready_func, thread, cancellable) -> SuspendResult:
705+
def suspend_until(self, ready_func, thread, cancellable) -> Cancelled:
686706
assert(thread in self.threads and thread.task is self)
687707
if self.deliver_pending_cancel(cancellable):
688-
return SuspendResult.CANCELLED
708+
return Cancelled.TRUE
689709
return thread.suspend_until(ready_func, cancellable)
690710

691-
def switch_to(self, thread, cancellable, other_thread) -> SuspendResult:
711+
def switch_to(self, thread, cancellable, other_thread) -> Cancelled:
692712
assert(thread in self.threads and thread.task is self)
693713
if self.deliver_pending_cancel(cancellable):
694-
return SuspendResult.CANCELLED
714+
return Cancelled.TRUE
695715
return thread.switch_to(cancellable, other_thread)
696716

697-
def yield_to(self, thread, cancellable, other_thread) -> SuspendResult:
717+
def yield_to(self, thread, cancellable, other_thread) -> Cancelled:
698718
assert(thread in self.threads and thread.task is self)
699719
if self.deliver_pending_cancel(cancellable):
700-
return SuspendResult.CANCELLED
720+
return Cancelled.TRUE
701721
return thread.yield_to(cancellable, other_thread)
702722

703723
def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
@@ -706,19 +726,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
706726
def ready_and_has_event():
707727
return ready_func() and wset.has_pending_event()
708728
match self.suspend_until(ready_and_has_event, thread, cancellable):
709-
case SuspendResult.CANCELLED:
729+
case Cancelled.TRUE:
710730
event = (EventCode.TASK_CANCELLED, 0, 0)
711-
case SuspendResult.NOT_CANCELLED:
731+
case Cancelled.FALSE:
712732
event = wset.get_pending_event()
713733
wset.num_waiting -= 1
714734
return event
715735

716736
def yield_until(self, ready_func, thread, cancellable) -> EventTuple:
717737
assert(thread in self.threads and thread.task is self)
718738
match self.suspend_until(ready_func, thread, cancellable):
719-
case SuspendResult.CANCELLED:
739+
case Cancelled.TRUE:
720740
return (EventCode.TASK_CANCELLED, 0, 0)
721-
case SuspendResult.NOT_CANCELLED:
741+
case Cancelled.FALSE:
722742
return (EventCode.NONE, 0, 0)
723743

724744
def return_(self, result):
@@ -2082,7 +2102,7 @@ def thread_func(thread):
20822102
return
20832103

20842104
thread = Thread(task, thread_func)
2085-
thread.resume()
2105+
thread.resume(Cancelled.FALSE)
20862106
return task
20872107

20882108
class CallbackCode(IntEnum):
@@ -2578,16 +2598,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25782598
trap_if(not thread.task.inst.may_leave)
25792599
other_thread = thread.task.inst.threads.get(i)
25802600
trap_if(not other_thread.suspended())
2581-
suspend_result = thread.task.switch_to(thread, cancellable, other_thread)
2582-
return [suspend_result]
2601+
cancelled = thread.task.switch_to(thread, cancellable, other_thread)
2602+
return [cancelled]
25832603

25842604
### 🧵 `canon thread.suspend`
25852605

25862606
def canon_thread_suspend(cancellable, thread):
25872607
trap_if(not thread.task.inst.may_leave)
25882608
trap_if(not thread.task.may_block())
2589-
suspend_result = thread.task.suspend(thread, cancellable)
2590-
return [suspend_result]
2609+
cancelled = thread.task.suspend(thread, cancellable)
2610+
return [cancelled]
25912611

25922612
### 🧵 `canon thread.resume-later`
25932613

@@ -2604,21 +2624,21 @@ def canon_thread_yield_to(cancellable, thread, i):
26042624
trap_if(not thread.task.inst.may_leave)
26052625
other_thread = thread.task.inst.threads.get(i)
26062626
trap_if(not other_thread.suspended())
2607-
suspend_result = thread.task.yield_to(thread, cancellable, other_thread)
2608-
return [suspend_result]
2627+
cancelled = thread.task.yield_to(thread, cancellable, other_thread)
2628+
return [cancelled]
26092629

26102630
### 🧵 `canon thread.yield`
26112631

26122632
def canon_thread_yield(cancellable, thread):
26132633
trap_if(not thread.task.inst.may_leave)
26142634
if not thread.task.may_block():
2615-
return [SuspendResult.NOT_CANCELLED]
2635+
return [Cancelled.FALSE]
26162636
event_code,_,_ = thread.task.yield_until(lambda: True, thread, cancellable)
26172637
match event_code:
26182638
case EventCode.NONE:
2619-
return [SuspendResult.NOT_CANCELLED]
2639+
return [Cancelled.FALSE]
26202640
case EventCode.TASK_CANCELLED:
2621-
return [SuspendResult.CANCELLED]
2641+
return [Cancelled.TRUE]
26222642

26232643
### 📝 `canon error-context.new`
26242644

0 commit comments

Comments
 (0)