Skip to content

Commit fa9f0ab

Browse files
committed
Rebase CABI onto explicit stack-switching interface (no behavior change)
1 parent 4068dbc commit fa9f0ab

2 files changed

Lines changed: 132 additions & 116 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 117 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def tick(self):
199199
random.shuffle(self.pending)
200200
for thread in self.pending:
201201
if thread.ready():
202-
thread.resume()
202+
thread.resume(Cancelled.FALSE)
203203
return
204204

205205
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -389,27 +389,74 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
389389
self.dtor_async = dtor_async
390390
self.dtor_callback = dtor_callback
391391

392+
#### Stack Switching Support
393+
394+
class Continuation:
395+
lock: threading.Lock
396+
handler: Handler
397+
cancelled: Cancelled
398+
399+
class Handler:
400+
tls = threading.local()
401+
lock: threading.Lock
402+
cont: Optional[Continuation]
403+
switch_to: Optional[Thread]
404+
405+
def cont_new(f: Callable[[Cancelled], Optional[Thread]]) -> Continuation:
406+
cont = Continuation()
407+
cont.lock = threading.Lock()
408+
cont.lock.acquire()
409+
def wrapper():
410+
cont.lock.acquire()
411+
Handler.tls.value = cont.handler
412+
switch_to = f(cont.cancelled)
413+
handler = Handler.tls.value
414+
handler.cont = None
415+
handler.switch_to = switch_to
416+
handler.lock.release()
417+
threading.Thread(target = wrapper).start()
418+
return cont
419+
420+
def resume(cont: Continuation, cancelled: Cancelled) -> tuple[Optional[Continuation], Optional[Thread]]:
421+
handler = Handler()
422+
handler.lock = threading.Lock()
423+
handler.lock.acquire()
424+
cont.handler = handler
425+
cont.cancelled = cancelled
426+
cont.lock.release()
427+
handler.lock.acquire()
428+
return (handler.cont, handler.switch_to)
429+
430+
def suspend(switch_to: Optional[Thread]) -> Cancelled:
431+
cont = Continuation()
432+
cont.lock = threading.Lock()
433+
cont.lock.acquire()
434+
handler = Handler.tls.value
435+
handler.cont = cont
436+
handler.switch_to = switch_to
437+
handler.lock.release()
438+
cont.lock.acquire()
439+
Handler.tls.value = cont.handler
440+
return cont.cancelled
441+
392442
#### Thread State
393443

394-
class SuspendResult(IntEnum):
395-
NOT_CANCELLED = 0
396-
CANCELLED = 1
444+
class Cancelled(IntEnum):
445+
FALSE = 0
446+
TRUE = 1
397447

398448
class Thread:
399-
task: Task
400-
fiber: threading.Thread
401-
fiber_lock: threading.Lock
402-
parent_lock: Optional[threading.Lock]
449+
cont: Optional[Continuation]
403450
ready_func: Optional[Callable[[], bool]]
404-
cancellable: bool
405-
suspend_result: Optional[SuspendResult]
451+
task: Task
406452
index: Optional[int]
407453
context: list[int]
454+
cancellable: bool
408455

409456
CONTEXT_LENGTH = 2
410457

411458
def running(self):
412-
return self.parent_lock is not None
459+
return self.cont is None
413460

414461
def suspended(self):
415462
return not self.running() and self.ready_func is None
@@ -422,94 +469,63 @@ def ready(self):
422469
return self.ready_func()
423470

424471
def __init__(self, task, thread_func):
425-
self.task = task
426-
self.fiber_lock = threading.Lock()
427-
self.fiber_lock.acquire()
428-
self.parent_lock = None
429-
self.ready_func = None
430-
self.cancellable = False
431-
self.suspend_result = None
432-
self.index = None
433-
self.context = [0] * Thread.CONTEXT_LENGTH
434-
def fiber_func():
435-
self.fiber_lock.acquire()
436-
assert(self.running() and self.suspend_result == SuspendResult.NOT_CANCELLED)
437-
self.suspend_result = None
472+
def wrapper(cancelled):
473+
assert(self.running() and not cancelled)
438474
thread_func(self)
439-
assert(self.running())
440475
self.task.thread_stop(self)
441476
if self.index is not None:
442477
self.task.inst.threads.remove(self.index)
443-
self.parent_lock.release()
444-
self.fiber = threading.Thread(target = fiber_func)
445-
self.fiber.start()
446-
self.task.thread_start(self)
478+
self.cont = cont_new(wrapper)
479+
self.ready_func = None
480+
self.task = task
481+
self.index = None
482+
self.context = [0] * Thread.CONTEXT_LENGTH
483+
self.cancellable = False
447484
assert(self.suspended())
485+
self.task.thread_start(self)
448486

449-
def resume(self, suspend_result = SuspendResult.NOT_CANCELLED):
450-
assert(not self.running() and self.suspend_result is None)
487+
def resume(self, cancelled):
488+
assert(not self.running() and (self.cancellable or not cancelled))
451489
if self.ready_func:
452-
assert(suspend_result == SuspendResult.CANCELLED or self.ready_func())
490+
assert(cancelled or self.ready_func())
453491
self.ready_func = None
454492
self.task.inst.store.pending.remove(self)
455-
assert(self.cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
456-
self.suspend_result = suspend_result
457-
self.parent_lock = threading.Lock()
458-
self.parent_lock.acquire()
459-
self.fiber_lock.release()
460-
self.parent_lock.acquire()
461-
self.parent_lock = None
462-
assert(not self.running())
463-
464-
def suspend(self, cancellable) -> SuspendResult:
465-
assert(self.task.may_block())
466-
assert(self.running() and not self.cancellable and self.suspend_result is None)
493+
thread = self
494+
while thread is not None:
495+
cont = thread.cont
496+
thread.cont = None
497+
thread.cont, thread = resume(cont, cancelled)
498+
cancelled = Cancelled.FALSE
499+
500+
def suspend(self, cancellable) -> Cancelled:
501+
assert(self.running() and self.task.may_block())
467502
self.cancellable = cancellable
468-
self.parent_lock.release()
469-
self.fiber_lock.acquire()
470-
assert(self.running())
471-
self.cancellable = False
472-
suspend_result = self.suspend_result
473-
self.suspend_result = None
474-
assert(suspend_result is not None)
475-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
476-
return suspend_result
503+
cancelled = suspend(None)
504+
assert(self.running() and (cancellable or not cancelled))
505+
return cancelled
477506

478507
def resume_later(self):
479508
assert(self.suspended())
480509
self.ready_func = lambda: True
481510
self.task.inst.store.pending.append(self)
482511

483-
def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
484-
assert(self.task.may_block())
485-
assert(self.running())
512+
def suspend_until(self, ready_func, cancellable = False) -> Cancelled:
513+
assert(self.running() and self.task.may_block())
486514
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
487-
return SuspendResult.NOT_CANCELLED
515+
return Cancelled.FALSE
488516
self.ready_func = ready_func
489517
self.task.inst.store.pending.append(self)
490518
return self.suspend(cancellable)
491519

492-
def switch_to(self, cancellable, other: Thread) -> SuspendResult:
493-
assert(self.running() and not self.cancellable and self.suspend_result is None)
494-
assert(other.suspended() and other.suspend_result is None)
520+
def switch_to(self, cancellable, other: Thread) -> Cancelled:
521+
assert(self.running())
495522
self.cancellable = cancellable
496-
other.suspend_result = SuspendResult.NOT_CANCELLED
497-
assert(self.parent_lock and not other.parent_lock)
498-
other.parent_lock = self.parent_lock
499-
self.parent_lock = None
500-
assert(not self.running() and other.running())
501-
other.fiber_lock.release()
502-
self.fiber_lock.acquire()
523+
cancelled = suspend(other)
524+
assert(self.running() and (cancellable or not cancelled))
525+
return cancelled
526+
527+
def yield_to(self, cancellable, other: Thread) -> Cancelled:
503528
assert(self.running())
504-
self.cancellable = False
505-
suspend_result = self.suspend_result
506-
self.suspend_result = None
507-
assert(suspend_result is not None)
508-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
509-
return suspend_result
510-
511-
def yield_to(self, cancellable, other: Thread) -> SuspendResult:
512-
assert(not self.ready_func)
513529
self.ready_func = lambda: True
514530
self.task.inst.store.pending.append(self)
515531
return self.switch_to(cancellable, other)
@@ -637,7 +653,7 @@ def has_backpressure():
637653
self.inst.num_waiting_to_enter += 1
638654
result = thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
639655
self.inst.num_waiting_to_enter -= 1
640-
if result == SuspendResult.CANCELLED:
656+
if result == Cancelled.TRUE:
641657
self.cancel()
642658
return False
643659
self.state = Task.State.UNRESOLVED
@@ -658,15 +674,15 @@ def request_cancellation(self):
658674
if self.state == Task.State.BACKPRESSURE:
659675
assert(len(self.threads) == 1)
660676
self.state = Task.State.CANCEL_DELIVERED
661-
self.threads[0].resume(SuspendResult.CANCELLED)
677+
self.threads[0].resume(Cancelled.TRUE)
662678
return
663679
assert(self.state == Task.State.UNRESOLVED)
664680
if not self.needs_exclusive() or not self.inst.exclusive or self.inst.exclusive is self:
665681
random.shuffle(self.threads)
666682
for thread in self.threads:
667683
if thread.cancellable:
668684
self.state = Task.State.CANCEL_DELIVERED
669-
thread.resume(SuspendResult.CANCELLED)
685+
thread.resume(Cancelled.TRUE)
670686
return
671687
self.state = Task.State.PENDING_CANCEL
672688

@@ -676,28 +692,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
676692
return True
677693
return False
678694

679-
def suspend(self, thread, cancellable) -> SuspendResult:
695+
def suspend(self, thread, cancellable) -> Cancelled:
680696
assert(thread in self.threads and thread.task is self)
681697
if self.deliver_pending_cancel(cancellable):
682-
return SuspendResult.CANCELLED
698+
return Cancelled.TRUE
683699
return thread.suspend(cancellable)
684700

685-
def suspend_until(self, ready_func, thread, cancellable) -> SuspendResult:
701+
def suspend_until(self, ready_func, thread, cancellable) -> Cancelled:
686702
assert(thread in self.threads and thread.task is self)
687703
if self.deliver_pending_cancel(cancellable):
688-
return SuspendResult.CANCELLED
704+
return Cancelled.TRUE
689705
return thread.suspend_until(ready_func, cancellable)
690706

691-
def switch_to(self, thread, cancellable, other_thread) -> SuspendResult:
707+
def switch_to(self, thread, cancellable, other_thread) -> Cancelled:
692708
assert(thread in self.threads and thread.task is self)
693709
if self.deliver_pending_cancel(cancellable):
694-
return SuspendResult.CANCELLED
710+
return Cancelled.TRUE
695711
return thread.switch_to(cancellable, other_thread)
696712

697-
def yield_to(self, thread, cancellable, other_thread) -> SuspendResult:
713+
def yield_to(self, thread, cancellable, other_thread) -> Cancelled:
698714
assert(thread in self.threads and thread.task is self)
699715
if self.deliver_pending_cancel(cancellable):
700-
return SuspendResult.CANCELLED
716+
return Cancelled.TRUE
701717
return thread.yield_to(cancellable, other_thread)
702718

703719
def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
@@ -706,19 +722,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
706722
def ready_and_has_event():
707723
return ready_func() and wset.has_pending_event()
708724
match self.suspend_until(ready_and_has_event, thread, cancellable):
709-
case SuspendResult.CANCELLED:
725+
case Cancelled.TRUE:
710726
event = (EventCode.TASK_CANCELLED, 0, 0)
711-
case SuspendResult.NOT_CANCELLED:
727+
case Cancelled.FALSE:
712728
event = wset.get_pending_event()
713729
wset.num_waiting -= 1
714730
return event
715731

716732
def yield_until(self, ready_func, thread, cancellable) -> EventTuple:
717733
assert(thread in self.threads and thread.task is self)
718734
match self.suspend_until(ready_func, thread, cancellable):
719-
case SuspendResult.CANCELLED:
735+
case Cancelled.TRUE:
720736
return (EventCode.TASK_CANCELLED, 0, 0)
721-
case SuspendResult.NOT_CANCELLED:
737+
case Cancelled.FALSE:
722738
return (EventCode.NONE, 0, 0)
723739

724740
def return_(self, result):
@@ -2082,7 +2098,7 @@ def thread_func(thread):
20822098
return
20832099

20842100
thread = Thread(task, thread_func)
2085-
thread.resume()
2101+
thread.resume(Cancelled.FALSE)
20862102
return task
20872103

20882104
class CallbackCode(IntEnum):
@@ -2578,16 +2594,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25782594
trap_if(not thread.task.inst.may_leave)
25792595
other_thread = thread.task.inst.threads.get(i)
25802596
trap_if(not other_thread.suspended())
2581-
suspend_result = thread.task.switch_to(thread, cancellable, other_thread)
2582-
return [suspend_result]
2597+
cancelled = thread.task.switch_to(thread, cancellable, other_thread)
2598+
return [cancelled]
25832599

25842600
### 🧵 `canon thread.suspend`
25852601

25862602
def canon_thread_suspend(cancellable, thread):
25872603
trap_if(not thread.task.inst.may_leave)
25882604
trap_if(not thread.task.may_block())
2589-
suspend_result = thread.task.suspend(thread, cancellable)
2590-
return [suspend_result]
2605+
cancelled = thread.task.suspend(thread, cancellable)
2606+
return [cancelled]
25912607

25922608
### 🧵 `canon thread.resume-later`
25932609

@@ -2604,21 +2620,21 @@ def canon_thread_yield_to(cancellable, thread, i):
26042620
trap_if(not thread.task.inst.may_leave)
26052621
other_thread = thread.task.inst.threads.get(i)
26062622
trap_if(not other_thread.suspended())
2607-
suspend_result = thread.task.yield_to(thread, cancellable, other_thread)
2608-
return [suspend_result]
2623+
cancelled = thread.task.yield_to(thread, cancellable, other_thread)
2624+
return [cancelled]
26092625

26102626
### 🧵 `canon thread.yield`
26112627

26122628
def canon_thread_yield(cancellable, thread):
26132629
trap_if(not thread.task.inst.may_leave)
26142630
if not thread.task.may_block():
2615-
return [SuspendResult.NOT_CANCELLED]
2631+
return [Cancelled.FALSE]
26162632
event_code,_,_ = thread.task.yield_until(lambda: True, thread, cancellable)
26172633
match event_code:
26182634
case EventCode.NONE:
2619-
return [SuspendResult.NOT_CANCELLED]
2635+
return [Cancelled.FALSE]
26202636
case EventCode.TASK_CANCELLED:
2621-
return [SuspendResult.CANCELLED]
2637+
return [Cancelled.TRUE]
26222638

26232639
### 📝 `canon error-context.new`
26242640

0 commit comments

Comments
 (0)