Skip to content

Commit 9452116

Browse files
committed
Rebase CABI onto explicit stack-switching interface (no behavior change)
1 parent 099bc80 commit 9452116

File tree

2 files changed

+129
-104
lines changed

2 files changed

+129
-104
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 114 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def tick(self):
199199
random.shuffle(self.pending)
200200
for thread in self.pending:
201201
if thread.ready():
202-
thread.resume()
202+
thread.resume(Cancelled.FALSE)
203203
return
204204

205205
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -362,28 +362,74 @@ def __init__(self, impl, dtor = None, dtor_async = False, dtor_callback = None):
362362
self.dtor_async = dtor_async
363363
self.dtor_callback = dtor_callback
364364

365+
#### Stack Switching Support
366+
367+
class Continuation:
368+
lock: threading.Lock
369+
handler: Handler
370+
arg: any
371+
372+
class Handler:
373+
tls = threading.local()
374+
lock: threading.Lock
375+
result: Optional[tuple[Continuation, any]]
376+
377+
def cont_new(f: Callable[[], None]) -> Continuation:
378+
cont = Continuation()
379+
cont.lock = threading.Lock()
380+
cont.lock.acquire()
381+
def wrapper():
382+
cont.lock.acquire()
383+
Handler.tls.value = cont.handler
384+
f(cont.arg)
385+
handler = Handler.tls.value
386+
del Handler.tls.value
387+
handler.result = None
388+
handler.lock.release()
389+
threading.Thread(target = wrapper).start()
390+
return cont
391+
392+
def resume(cont: Continuation, v: any) -> Optional[tuple[Continuation, any]]:
393+
handler = Handler()
394+
handler.lock = threading.Lock()
395+
handler.lock.acquire()
396+
cont.handler = handler
397+
cont.arg = v
398+
cont.lock.release()
399+
handler.lock.acquire()
400+
return handler.result
401+
402+
def suspend(v: any) -> any:
403+
handler = Handler.tls.value
404+
del Handler.tls.value
405+
cont = Continuation()
406+
cont.lock = threading.Lock()
407+
cont.lock.acquire()
408+
handler.result = (cont, v)
409+
handler.lock.release()
410+
cont.lock.acquire()
411+
Handler.tls.value = cont.handler
412+
return cont.arg
413+
365414
#### Thread State
366415

367-
class SuspendResult(IntEnum):
368-
NOT_CANCELLED = 0
369-
CANCELLED = 1
416+
class Cancelled(IntEnum):
417+
FALSE = 0
418+
TRUE = 1
370419

371420
class Thread:
372-
task: Task
373-
fiber: threading.Thread
374-
fiber_lock: threading.Lock
375-
parent_lock: Optional[threading.Lock]
421+
cont: Optional[Continuation]
376422
ready_func: Optional[Callable[[], bool]]
377-
cancellable: bool
378-
suspend_result: Optional[SuspendResult]
379-
in_event_loop: bool
423+
task: Task
380424
index: Optional[int]
381425
context: list[int]
426+
cancellable: bool
427+
in_event_loop: bool
382428

383429
CONTEXT_LENGTH = 2
384430

385431
def running(self):
386-
return self.parent_lock is not None
432+
return self.cont is None
387433

388434
def suspended(self):
389435
return not self.running() and self.ready_func is None
@@ -396,59 +442,49 @@ def ready(self):
396442
return self.ready_func()
397443

398444
def __init__(self, task, thread_func):
399-
self.task = task
400-
self.fiber_lock = threading.Lock()
401-
self.fiber_lock.acquire()
402-
self.parent_lock = None
403-
self.ready_func = None
404-
self.cancellable = False
405-
self.suspend_result = None
406-
self.in_event_loop = False
407-
self.index = None
408-
self.context = [0] * Thread.CONTEXT_LENGTH
409-
def fiber_func():
410-
self.fiber_lock.acquire()
411-
assert(self.running() and self.suspend_result == SuspendResult.NOT_CANCELLED)
412-
self.suspend_result = None
445+
def wrapper(cancelled):
446+
assert(self.running() and not cancelled)
413447
thread_func(self)
414-
assert(self.running())
415448
self.task.thread_stop(self)
416449
if self.index is not None:
417450
self.task.inst.threads.remove(self.index)
418-
self.parent_lock.release()
419-
self.fiber = threading.Thread(target = fiber_func)
420-
self.fiber.start()
421-
self.task.thread_start(self)
451+
self.cont = cont_new(wrapper)
452+
self.ready_func = None
453+
self.task = task
454+
self.index = None
455+
self.context = [0] * Thread.CONTEXT_LENGTH
456+
self.cancellable = False
457+
self.in_event_loop = False
422458
assert(self.suspended())
459+
self.task.thread_start(self)
423460

424-
def resume(self, suspend_result = SuspendResult.NOT_CANCELLED):
425-
assert(not self.running() and self.suspend_result is None)
461+
def resume(self, cancelled):
462+
assert(not self.running())
426463
if self.ready_func:
427-
assert(suspend_result == SuspendResult.CANCELLED or self.ready_func())
464+
assert(cancelled or self.ready_func())
428465
self.ready_func = None
429466
self.task.inst.store.pending.remove(self)
430-
assert(self.cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
431-
self.suspend_result = suspend_result
432-
self.parent_lock = threading.Lock()
433-
self.parent_lock.acquire()
434-
self.fiber_lock.release()
435-
self.parent_lock.acquire()
436-
self.parent_lock = None
437-
assert(not self.running())
467+
assert(self.cancellable or not cancelled)
468+
thread = self
469+
while True:
470+
assert(not thread.running())
471+
cont = thread.cont
472+
thread.cont = None
473+
if not (resume_result := resume(cont, cancelled)):
474+
return
475+
thread.cont,switch_to_thread = resume_result
476+
if switch_to_thread is None:
477+
return
478+
thread = switch_to_thread
479+
cancelled = Cancelled.FALSE
438480

439-
def suspend(self, cancellable) -> SuspendResult:
440-
assert(self.task.may_block())
441-
assert(self.running() and not self.cancellable and self.suspend_result is None)
481+
def suspend(self, cancellable) -> Cancelled:
482+
assert(self.running() and self.task.may_block())
442483
self.cancellable = cancellable
443-
self.parent_lock.release()
444-
self.fiber_lock.acquire()
484+
cancelled = suspend(None)
445485
assert(self.running())
446-
self.cancellable = False
447-
suspend_result = self.suspend_result
448-
self.suspend_result = None
449-
assert(suspend_result is not None)
450-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
451-
return suspend_result
486+
assert(cancellable or not cancelled)
487+
return cancelled
452488

453489
def resume_later(self):
454490
assert(self.suspended())
@@ -459,29 +495,18 @@ def suspend_until(self, ready_func, cancellable = False) -> SuspendResult:
459495
assert(self.task.may_block())
460496
assert(self.running())
461497
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
462-
return SuspendResult.NOT_CANCELLED
498+
return Cancelled.FALSE
463499
self.ready_func = ready_func
464500
self.task.inst.store.pending.append(self)
465501
return self.suspend(cancellable)
466502

467503
def switch_to(self, cancellable, other: Thread) -> SuspendResult:
468-
assert(self.running() and not self.cancellable and self.suspend_result is None)
469-
assert(other.suspended() and other.suspend_result is None)
504+
assert(self.running())
470505
self.cancellable = cancellable
471-
other.suspend_result = SuspendResult.NOT_CANCELLED
472-
assert(self.parent_lock and not other.parent_lock)
473-
other.parent_lock = self.parent_lock
474-
self.parent_lock = None
475-
assert(not self.running() and other.running())
476-
other.fiber_lock.release()
477-
self.fiber_lock.acquire()
506+
cancelled = suspend(other)
478507
assert(self.running())
479-
self.cancellable = False
480-
suspend_result = self.suspend_result
481-
self.suspend_result = None
482-
assert(suspend_result is not None)
483-
assert(cancellable or suspend_result == SuspendResult.NOT_CANCELLED)
484-
return suspend_result
508+
assert(cancellable or not cancelled)
509+
return cancelled
485510

486511
def yield_to(self, cancellable, other: Thread) -> SuspendResult:
487512
assert(not self.ready_func)
@@ -610,7 +635,7 @@ def has_backpressure():
610635
self.inst.num_waiting_to_enter += 1
611636
result = thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
612637
self.inst.num_waiting_to_enter -= 1
613-
if result == SuspendResult.CANCELLED:
638+
if result == Cancelled.TRUE:
614639
self.cancel()
615640
return False
616641
if self.needs_exclusive():
@@ -632,7 +657,7 @@ def request_cancellation(self):
632657
for thread in self.threads:
633658
if thread.cancellable and not (thread.in_event_loop and self.inst.exclusive):
634659
self.state = Task.State.CANCEL_DELIVERED
635-
thread.resume(SuspendResult.CANCELLED)
660+
thread.resume(Cancelled.TRUE)
636661
return
637662
self.state = Task.State.PENDING_CANCEL
638663

@@ -645,25 +670,25 @@ def deliver_pending_cancel(self, cancellable) -> bool:
645670
def suspend(self, thread, cancellable) -> SuspendResult:
646671
assert(thread in self.threads and thread.task is self)
647672
if self.deliver_pending_cancel(cancellable):
648-
return SuspendResult.CANCELLED
673+
return Cancelled.TRUE
649674
return thread.suspend(cancellable)
650675

651676
def suspend_until(self, ready_func, thread, cancellable) -> SuspendResult:
652677
assert(thread in self.threads and thread.task is self)
653678
if self.deliver_pending_cancel(cancellable):
654-
return SuspendResult.CANCELLED
679+
return Cancelled.TRUE
655680
return thread.suspend_until(ready_func, cancellable)
656681

657682
def switch_to(self, thread, cancellable, other_thread) -> SuspendResult:
658683
assert(thread in self.threads and thread.task is self)
659684
if self.deliver_pending_cancel(cancellable):
660-
return SuspendResult.CANCELLED
685+
return Cancelled.TRUE
661686
return thread.switch_to(cancellable, other_thread)
662687

663688
def yield_to(self, thread, cancellable, other_thread) -> SuspendResult:
664689
assert(thread in self.threads and thread.task is self)
665690
if self.deliver_pending_cancel(cancellable):
666-
return SuspendResult.CANCELLED
691+
return Cancelled.TRUE
667692
return thread.yield_to(cancellable, other_thread)
668693

669694
def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
@@ -672,19 +697,19 @@ def wait_until(self, ready_func, thread, wset, cancellable) -> EventTuple:
672697
def ready_and_has_event():
673698
return ready_func() and wset.has_pending_event()
674699
match self.suspend_until(ready_and_has_event, thread, cancellable):
675-
case SuspendResult.CANCELLED:
700+
case Cancelled.TRUE:
676701
event = (EventCode.TASK_CANCELLED, 0, 0)
677-
case SuspendResult.NOT_CANCELLED:
702+
case Cancelled.FALSE:
678703
event = wset.get_pending_event()
679704
wset.num_waiting -= 1
680705
return event
681706

682707
def yield_until(self, ready_func, thread, cancellable) -> EventTuple:
683708
assert(thread in self.threads and thread.task is self)
684709
match self.suspend_until(ready_func, thread, cancellable):
685-
case SuspendResult.CANCELLED:
710+
case Cancelled.TRUE:
686711
return (EventCode.TASK_CANCELLED, 0, 0)
687-
case SuspendResult.NOT_CANCELLED:
712+
case Cancelled.FALSE:
688713
return (EventCode.NONE, 0, 0)
689714

690715
def return_(self, result):
@@ -2038,7 +2063,7 @@ def thread_func(thread):
20382063
return
20392064

20402065
thread = Thread(task, thread_func)
2041-
thread.resume()
2066+
thread.resume(Cancelled.FALSE)
20422067
return task
20432068

20442069
class CallbackCode(IntEnum):
@@ -2531,16 +2556,16 @@ def canon_thread_switch_to(cancellable, thread, i):
25312556
trap_if(not thread.task.inst.may_leave)
25322557
other_thread = thread.task.inst.threads.get(i)
25332558
trap_if(not other_thread.suspended())
2534-
suspend_result = thread.task.switch_to(thread, cancellable, other_thread)
2535-
return [suspend_result]
2559+
cancelled = thread.task.switch_to(thread, cancellable, other_thread)
2560+
return [cancelled]
25362561

25372562
### 🧵 `canon thread.suspend`
25382563

25392564
def canon_thread_suspend(cancellable, thread):
25402565
trap_if(not thread.task.inst.may_leave)
25412566
trap_if(not thread.task.may_block())
2542-
suspend_result = thread.task.suspend(thread, cancellable)
2543-
return [suspend_result]
2567+
cancelled = thread.task.suspend(thread, cancellable)
2568+
return [cancelled]
25442569

25452570
### 🧵 `canon thread.resume-later`
25462571

@@ -2557,21 +2582,21 @@ def canon_thread_yield_to(cancellable, thread, i):
25572582
trap_if(not thread.task.inst.may_leave)
25582583
other_thread = thread.task.inst.threads.get(i)
25592584
trap_if(not other_thread.suspended())
2560-
suspend_result = thread.task.yield_to(thread, cancellable, other_thread)
2561-
return [suspend_result]
2585+
cancelled = thread.task.yield_to(thread, cancellable, other_thread)
2586+
return [cancelled]
25622587

25632588
### 🧵 `canon thread.yield`
25642589

25652590
def canon_thread_yield(cancellable, thread):
25662591
trap_if(not thread.task.inst.may_leave)
25672592
if not thread.task.may_block():
2568-
return [SuspendResult.NOT_CANCELLED]
2593+
return [Cancelled.FALSE]
25692594
event_code,_,_ = thread.task.yield_until(lambda: True, thread, cancellable)
25702595
match event_code:
25712596
case EventCode.NONE:
2572-
return [SuspendResult.NOT_CANCELLED]
2597+
return [Cancelled.FALSE]
25732598
case EventCode.TASK_CANCELLED:
2574-
return [SuspendResult.CANCELLED]
2599+
return [Cancelled.TRUE]
25752600

25762601
### 📝 `canon error-context.new`
25772602

0 commit comments

Comments
 (0)