Skip to content

Commit f1630c8

Browse files
committed
Rebase CABI onto explicit stack-switching interface (no behavior change)
1 parent 71c2e05 commit f1630c8

File tree

2 files changed

+138
-116
lines changed

2 files changed

+138
-116
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 123 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def tick(self):
199199
random.shuffle(self.pending)
200200
for thread in self.pending:
201201
if thread.ready():
202-
thread.resume()
202+
thread.resume(Cancelled.FALSE)
203203
return
204204

205205
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -362,27 +362,73 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
362362
self.dtor_async = dtor_async
363363
self.dtor_callback = dtor_callback
364364

365+
#### Stack Switching Support
366+
367+
class Continuation:
368+
lock: threading.Lock
369+
handler: Handler
370+
arg: any
371+
372+
class Handler:
373+
tls = threading.local()
374+
lock: threading.Lock
375+
result: Optional[tuple[Continuation, any]]
376+
377+
def cont_new(f: Callable[[any], None]) -> Continuation:
378+
cont = Continuation()
379+
cont.lock = threading.Lock()
380+
cont.lock.acquire()
381+
def wrapper():
382+
cont.lock.acquire()
383+
Handler.tls.value = cont.handler
384+
f(cont.arg)
385+
handler = Handler.tls.value
386+
del Handler.tls.value
387+
handler.result = None
388+
handler.lock.release()
389+
threading.Thread(target = wrapper).start()
390+
return cont
391+
392+
def resume(cont: Continuation, v: any) -> Optional[tuple[Continuation, any]]:
393+
handler = Handler()
394+
handler.lock = threading.Lock()
395+
handler.lock.acquire()
396+
cont.handler = handler
397+
cont.arg = v
398+
cont.lock.release()
399+
handler.lock.acquire()
400+
return handler.result
401+
402+
def suspend(v: any) -> any:
403+
handler = Handler.tls.value
404+
del Handler.tls.value
405+
cont = Continuation()
406+
cont.lock = threading.Lock()
407+
cont.lock.acquire()
408+
handler.result = (cont, v)
409+
handler.lock.release()
410+
cont.lock.acquire()
411+
Handler.tls.value = cont.handler
412+
return cont.arg
413+
365414
#### Thread State
366415

367-
class SuspendResult(IntEnum):
368-
NOT_CANCELLED = 0
369-
CANCELLED = 1
416+
class Cancelled(IntEnum):
417+
FALSE = 0
418+
TRUE = 1
370419

371420
class Thread:
372-
task: Task
373-
fiber: threading.Thread
374-
fiber_lock: threading.Lock
375-
parent_lock: Optional[threading.Lock]
421+
cont: Optional[Continuation]
376422
ready_func: Optional[Callable[[], bool]]
377-
cancellable: bool
378-
suspend_result: Optional[SuspendResult]
423+
task: Task
379424
index: Optional[int]
380425
context: list[int]
426+
cancellable: bool
381427

382428
CONTEXT_LENGTH = 2
383429

384430
def running(self):
385-
return self.parent_lock is not None
431+
return self.cont is None
386432

387433
def suspended(self):
388434
return not self.running() and self.ready_func is None
@@ -395,94 +441,70 @@ def ready(self):
395441
return self.ready_func()
396442

397443
def __init__(self, task, thread_func):
398-
self.task = task
399-
self.fiber_lock = threading.Lock()
400-
self.fiber_lock.acquire()
401-
self.parent_lock = None
402-
self.ready_func = None
403-
self.cancellable = False
404-
self.suspend_result = None
405-
self.index = None
406-
self.context = [0] * Thread.CONTEXT_LENGTH
407-
def fiber_func():
408-
self.fiber_lock.acquire()
409-
assert(self.running() and self.suspend_result == SuspendResult.NOT_CANCELLED)
410-
self.suspend_result = None
444+
def wrapper(cancelled):
445+
assert(self.running() and not cancelled)
411446
thread_func(self)
412-
assert(self.running())
413447
self.task.thread_stop(self)
414448
if self.index is not None:
415449
self.task.inst.threads.remove(self.index)
416-
self.parent_lock.release()
417-
self.fiber = threading.Thread(target = fiber_func)
418-
self.fiber.start()
419-
self.task.thread_start(self)
450+
self.cont = cont_new(wrapper)
451+
self.ready_func = None
452+
self.task = task
453+
self.index = None
454+
self.context = [0] * Thread.CONTEXT_LENGTH
455+
self.cancellable = False
420456
assert(self.suspended())
457+
self.task.thread_start(self)
421458

422-
def resume(self, suspend_result = SuspendResult.NOT_CANCELLED):
423-
assert(not self.running() and self.suspend_result is None)
459+
def resume(self, cancelled):
460+
assert(not self.running())
424461
if self.ready_func:
425-
assert(suspend_result == SuspendResult.CANCELLED or self.ready_func())
462+
assert(cancelled or self.ready_func())
426463
self.ready_func = None
427464
self.task.inst.store.pending.remove(self)
428-
assert(self.cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
429-
self.suspend_result = suspend_result
430-
self.parent_lock = threading.Lock()
431-
self.parent_lock.acquire()
432-
self.fiber_lock.release()
433-
self.parent_lock.acquire()
434-
self.parent_lock = None
435-
assert(not self.running())
436-
437-
def suspend(self, cancellable) -> SuspendResult:
438-
assert(self.task.may_block())
439-
assert(self.running() and not self.cancellable and self.suspend_result is None)
465+
assert(self.cancellable or not cancelled)
466+
thread = self
467+
while True:
468+
cont = thread.cont
469+
thread.cont = None
470+
resume_result = resume(cont, cancelled)
471+
if resume_result is None:
472+
break
473+
(thread.cont, switch_to_thread) = resume_result
474+
if switch_to_thread is None:
475+
break
476+
thread = switch_to_thread
477+
cancelled = Cancelled.FALSE
478+
479+
def suspend(self, cancellable) -> Cancelled:
480+
assert(self.running() and self.task.may_block())
440481
self.cancellable = cancellable
441-
self.parent_lock.release()
442-
self.fiber_lock.acquire()
443-
assert(self.running())
444-
self.cancellable = False
445-
suspend_result = self.suspend_result
446-
self.suspend_result = None
447-
assert(suspend_result is not None)
448-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
449-
return suspend_result
482+
cancelled = suspend(None)
483+
assert(self.running() and (cancellable or not cancelled))
484+
return cancelled
450485

451486
def resume_later(self):
452487
assert(self.suspended())
453488
self.ready_func = lambda: True
454489
self.task.inst.store.pending.append(self)
455490

456-
def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
457-
assert(self.task.may_block())
458-
assert(self.running())
491+
def suspend_until(self, ready_func, cancellable = False) -> Cancelled:
492+
assert(self.running() and self.task.may_block())
459493
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
460-
return SuspendResult.NOT_CANCELLED
494+
return Cancelled.FALSE
461495
self.ready_func = ready_func
462496
self.task.inst.store.pending.append(self)
463497
return self.suspend(cancellable)
464498

465-
def switch_to(self, cancellable, other: Thread) -> SuspendResult:
466-
assert(self.running() and not self.cancellable and self.suspend_result is None)
467-
assert(other.suspended() and other.suspend_result is None)
499+
def switch_to(self, cancellable, other: Thread) -> Cancelled:
500+
assert(self.running())
468501
self.cancellable = cancellable
469-
other.suspend_result = SuspendResult.NOT_CANCELLED
470-
assert(self.parent_lock and not other.parent_lock)
471-
other.parent_lock = self.parent_lock
472-
self.parent_lock = None
473-
assert(not self.running() and other.running())
474-
other.fiber_lock.release()
475-
self.fiber_lock.acquire()
502+
cancelled = suspend(other)
503+
assert(self.running() and (cancellable or not cancelled))
504+
return cancelled
505+
506+
def yield_to(self, cancellable, other: Thread) -> Cancelled:
476507
assert(self.running())
477-
self.cancellable = False
478-
suspend_result = self.suspend_result
479-
self.suspend_result = None
480-
assert(suspend_result is not None)
481-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
482-
return suspend_result
483-
484-
def yield_to(self, cancellable, other: Thread) -> SuspendResult:
485-
assert(not self.ready_func)
486508
self.ready_func = lambda: True
487509
self.task.inst.store.pending.append(self)
488510
return self.switch_to(cancellable, other)
@@ -610,7 +632,7 @@ def has_backpressure():
610632
self.inst.num_waiting_to_enter += 1
611633
result = thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
612634
self.inst.num_waiting_to_enter -= 1
613-
if result == SuspendResult.CANCELLED:
635+
if result == Cancelled.TRUE:
614636
self.cancel()
615637
return False
616638
self.state = Task.State.UNRESOLVED
@@ -631,15 +653,15 @@ def request_cancellation(self):
631653
if self.state == Task.State.BACKPRESSURE:
632654
assert(len(self.threads) == 1)
633655
self.state = Task.State.CANCEL_DELIVERED
634-
self.threads[0].resume(SuspendResult.CANCELLED)
656+
self.threads[0].resume(Cancelled.TRUE)
635657
return
636658
assert(self.state == Task.State.UNRESOLVED)
637659
if not self.needs_exclusive() or not self.inst.exclusive or self.inst.exclusive is self:
638660
random.shuffle(self.threads)
639661
for thread in self.threads:
640662
if thread.cancellable:
641663
self.state = Task.State.CANCEL_DELIVERED
642-
thread.resume(SuspendResult.CANCELLED)
664+
thread.resume(Cancelled.TRUE)
643665
return
644666
self.state = Task.State.PENDING_CANCEL
645667

@@ -649,28 +671,28 @@ def deliver_pending_cancel(self, cancellable) -> bool:
649671
return True
650672
return False
651673

652-
def suspend(self, thread, cancellable) -> SuspendResult:
674+
def suspend(self, thread, cancellable) -> Cancelled:
653675
assert(thread in self.threads and thread.task is self)
654676
if self.deliver_pending_cancel(cancellable):
655-
return SuspendResult.CANCELLED
677+
return Cancelled.TRUE
656678
return thread.suspend(cancellable)
657679

658-
def suspend_until(self, ready_func, thread, cancellable) -> SuspendResult:
680+
def suspend_until(self, ready_func, thread, cancellable) -> Cancelled:
659681
assert(thread in self.threads and thread.task is self)
660682
if self.deliver_pending_cancel(cancellable):
661-
return SuspendResult.CANCELLED
683+
return Cancelled.TRUE
662684
return thread.suspend_until(ready_func, cancellable)
663685

664-
def switch_to(self, thread, cancellable, other_thread) -> SuspendResult:
686+
def switch_to(self, thread, cancellable, other_thread) -> Cancelled:
665687
assert(thread in self.threads and thread.task is self)
666688
if self.deliver_pending_cancel(cancellable):
667-
return SuspendResult.CANCELLED
689+
return Cancelled.TRUE
668690
return thread.switch_to(cancellable, other_thread)
669691

670-
def yield_to(self, thread, cancellable, other_thread) -> SuspendResult:
692+
def yield_to(self, thread, cancellable, other_thread) -> Cancelled:
671693
assert(thread in self.threads and thread.task is self)
672694
if self.deliver_pending_cancel(cancellable):
673-
return SuspendResult.CANCELLED
695+
return Cancelled.TRUE
674696
return thread.yield_to(cancellable, other_thread)
675697

676698
def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
@@ -679,19 +701,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
679701
def ready_and_has_event():
680702
return ready_func() and wset.has_pending_event()
681703
match self.suspend_until(ready_and_has_event, thread, cancellable):
682-
case SuspendResult.CANCELLED:
704+
case Cancelled.TRUE:
683705
event = (EventCode.TASK_CANCELLED, 0, 0)
684-
case SuspendResult.NOT_CANCELLED:
706+
case Cancelled.FALSE:
685707
event = wset.get_pending_event()
686708
wset.num_waiting -= 1
687709
return event
688710

689711
def yield_until(self, ready_func, thread, cancellable) -> EventTuple:
690712
assert(thread in self.threads and thread.task is self)
691713
match self.suspend_until(ready_func, thread, cancellable):
692-
case SuspendResult.CANCELLED:
714+
case Cancelled.TRUE:
693715
return (EventCode.TASK_CANCELLED, 0, 0)
694-
case SuspendResult.NOT_CANCELLED:
716+
case Cancelled.FALSE:
695717
return (EventCode.NONE, 0, 0)
696718

697719
def return_(self, result):
@@ -2045,7 +2067,7 @@ def thread_func(thread):
20452067
return
20462068

20472069
thread = Thread(task, thread_func)
2048-
thread.resume()
2070+
thread.resume(Cancelled.FALSE)
20492071
return task
20502072

20512073
class CallbackCode(IntEnum):
@@ -2536,16 +2558,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25362558
trap_if(not thread.task.inst.may_leave)
25372559
other_thread = thread.task.inst.threads.get(i)
25382560
trap_if(not other_thread.suspended())
2539-
suspend_result = thread.task.switch_to(thread, cancellable, other_thread)
2540-
return [suspend_result]
2561+
cancelled = thread.task.switch_to(thread, cancellable, other_thread)
2562+
return [cancelled]
25412563

25422564
### 🧵 `canon thread.suspend`
25432565

25442566
def canon_thread_suspend(cancellable, thread):
25452567
trap_if(not thread.task.inst.may_leave)
25462568
trap_if(not thread.task.may_block())
2547-
suspend_result = thread.task.suspend(thread, cancellable)
2548-
return [suspend_result]
2569+
cancelled = thread.task.suspend(thread, cancellable)
2570+
return [cancelled]
25492571

25502572
### 🧵 `canon thread.resume-later`
25512573

@@ -2562,21 +2584,21 @@ def canon_thread_yield_to(cancellable, thread, i):
25622584
trap_if(not thread.task.inst.may_leave)
25632585
other_thread = thread.task.inst.threads.get(i)
25642586
trap_if(not other_thread.suspended())
2565-
suspend_result = thread.task.yield_to(thread, cancellable, other_thread)
2566-
return [suspend_result]
2587+
cancelled = thread.task.yield_to(thread, cancellable, other_thread)
2588+
return [cancelled]
25672589

25682590
### 🧵 `canon thread.yield`
25692591

25702592
def canon_thread_yield(cancellable, thread):
25712593
trap_if(not thread.task.inst.may_leave)
25722594
if not thread.task.may_block():
2573-
return [SuspendResult.NOT_CANCELLED]
2595+
return [Cancelled.FALSE]
25742596
event_code,_,_ = thread.task.yield_until(lambda: True, thread, cancellable)
25752597
match event_code:
25762598
case EventCode.NONE:
2577-
return [SuspendResult.NOT_CANCELLED]
2599+
return [Cancelled.FALSE]
25782600
case EventCode.TASK_CANCELLED:
2579-
return [SuspendResult.CANCELLED]
2601+
return [Cancelled.TRUE]
25802602

25812603
### 📝 `canon error-context.new`
25822604

0 commit comments

Comments
 (0)