Skip to content

Commit c85c138

Browse files
committed
CABI: improve and add cooperative thread built-ins
1 parent 5cdd14f commit c85c138

File tree

2 files changed

+245
-51
lines changed

2 files changed

+245
-51
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
class Trap(BaseException): pass
1818
class CoreWebAssemblyException(BaseException): pass
19+
class ThreadExit(BaseException): pass
1920

2021
def trap():
2122
raise Trap()
@@ -303,7 +304,7 @@ class ComponentInstance:
303304
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
304305
threads: Table[Thread]
305306
may_leave: bool
306-
may_block: bool
307+
sync_before_return: bool
307308
backpressure: int
308309
exclusive: Optional[Task]
309310
num_waiting_to_enter: int
@@ -315,11 +316,17 @@ def __init__(self, store, parent = None):
315316
self.handles = Table()
316317
self.threads = Table()
317318
self.may_leave = True
318-
self.may_block = True
319+
self.sync_before_return = False
319320
self.backpressure = 0
320321
self.exclusive = None
321322
self.num_waiting_to_enter = 0
322323

324+
def ready_threads(self) -> list[Thread]:
325+
return [t for t in self.threads.array if t and t.waiting() and t.ready()]
326+
327+
def may_block(self):
328+
return not self.sync_before_return or len(self.ready_threads()) > 0
329+
323330
def reflexive_ancestors(self) -> set[ComponentInstance]:
324331
s = set()
325332
inst = self
@@ -490,10 +497,12 @@ def ready(self):
490497
def __init__(self, task, thread_func):
491498
def wrapper(cancelled):
492499
assert(self.running() and not cancelled)
493-
thread_func(self)
494-
self.task.thread_stop(self)
495-
if self.index is not None:
496-
self.task.inst.threads.remove(self.index)
500+
try:
501+
thread_func(self)
502+
self.exit()
503+
except ThreadExit:
504+
return
505+
assert(False)
497506
self.cont = cont_new(wrapper)
498507
self.ready_func = None
499508
self.task = task
@@ -503,7 +512,14 @@ def wrapper(cancelled):
503512
assert(self.suspended())
504513
self.task.thread_start(self)
505514

506-
def resume_later(self):
515+
def exit(self):
516+
assert(self.task.inst.may_block())
517+
self.task.thread_stop(self)
518+
if self.index is not None:
519+
self.task.inst.threads.remove(self.index)
520+
raise ThreadExit()
521+
522+
def unsuspend(self):
507523
assert(self.suspended())
508524
self.ready_func = lambda: True
509525
self.task.inst.store.waiting.append(self)
@@ -513,17 +529,24 @@ def resume(self, cancelled):
513529
assert(not self.running() and (self.cancellable or not cancelled))
514530
if self.waiting():
515531
assert(cancelled or self.ready())
516-
self.ready_func = None
517-
self.task.inst.store.waiting.remove(self)
532+
self.stop_waiting()
518533
thread = self
519534
while thread is not None:
520535
cont = thread.cont
521536
thread.cont = None
522537
thread.cont, thread = resume(cont, cancelled)
538+
if thread is None and self.task.inst.sync_before_return:
539+
thread = random.choice(self.task.inst.ready_threads())
540+
thread.stop_waiting()
523541
cancelled = Cancelled.FALSE
524542

543+
def stop_waiting(self):
544+
assert(self.waiting())
545+
self.ready_func = None
546+
self.task.inst.store.waiting.remove(self)
547+
525548
def suspend(self, cancellable) -> Cancelled:
526-
assert(self.running() and self.task.inst.may_block)
549+
assert(self.running() and self.task.inst.may_block())
527550
if self.task.deliver_pending_cancel(cancellable):
528551
return Cancelled.TRUE
529552
self.cancellable = cancellable
@@ -532,7 +555,7 @@ def suspend(self, cancellable) -> Cancelled:
532555
return cancelled
533556

534557
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
535-
assert(self.running() and self.task.inst.may_block)
558+
assert(self.running() and self.task.inst.may_block())
536559
if self.task.deliver_pending_cancel(cancellable):
537560
return Cancelled.TRUE
538561
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -543,7 +566,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
543566

544567
def yield_until(self, ready_func, cancellable) -> Cancelled:
545568
assert(self.running())
546-
if self.task.inst.may_block:
569+
if self.task.inst.may_block():
547570
return self.wait_until(ready_func, cancellable)
548571
else:
549572
assert(ready_func())
@@ -552,7 +575,7 @@ def yield_until(self, ready_func, cancellable) -> Cancelled:
552575
def yield_(self, cancellable) -> Cancelled:
553576
return self.yield_until(lambda: True, cancellable)
554577

555-
def switch_to(self, cancellable, other: Thread) -> Cancelled:
578+
def suspend_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
556579
assert(self.running() and other.suspended())
557580
if self.task.deliver_pending_cancel(cancellable):
558581
return Cancelled.TRUE
@@ -561,11 +584,27 @@ def switch_to(self, cancellable, other: Thread) -> Cancelled:
561584
assert(self.running() and (cancellable or not cancelled))
562585
return cancelled
563586

564-
def yield_to(self, cancellable, other: Thread) -> Cancelled:
587+
def yield_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
565588
assert(self.running() and other.suspended())
566589
self.ready_func = lambda: True
567590
self.task.inst.store.waiting.append(self)
568-
return self.switch_to(cancellable, other)
591+
return self.suspend_to_suspended(cancellable, other)
592+
593+
def suspend_then_promote(self, cancellable, other: Thread) -> ResumeArg:
594+
assert(self.running())
595+
if other.waiting() and other.ready():
596+
other.stop_waiting()
597+
return self.suspend_to_suspended(cancellable, other)
598+
else:
599+
return self.suspend(cancellable)
600+
601+
def yield_then_promote(self, cancellable, other: Thread) -> ResumeArg:
602+
assert(self.running())
603+
if other.waiting() and other.ready():
604+
other.stop_waiting()
605+
return self.yield_to_suspended(cancellable, other)
606+
else:
607+
return self.yield_(cancellable)
569608

570609
#### Waitable State
571610

@@ -701,8 +740,8 @@ def needs_exclusive(self):
701740
def enter(self, thread):
702741
assert(thread in self.threads and thread.task is self)
703742
if not self.ft.async_:
704-
assert(self.inst.may_block)
705-
self.inst.may_block = False
743+
assert(not self.inst.sync_before_return)
744+
self.inst.sync_before_return = True
706745
return True
707746
def has_backpressure():
708747
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -754,8 +793,8 @@ def return_(self, result):
754793
trap_if(self.state == Task.State.RESOLVED)
755794
trap_if(self.num_borrows > 0)
756795
if not self.ft.async_:
757-
assert(not self.inst.may_block)
758-
self.inst.may_block = True
796+
assert(self.inst.sync_before_return)
797+
self.inst.sync_before_return = False
759798
assert(result is not None)
760799
self.on_resolve(result)
761800
self.state = Task.State.RESOLVED
@@ -2100,7 +2139,7 @@ def thread_func(thread):
21002139
else:
21012140
event = (EventCode.NONE, 0, 0)
21022141
case CallbackCode.WAIT:
2103-
trap_if(not inst.may_block)
2142+
trap_if(not inst.may_block())
21042143
wset = inst.handles.get(si)
21052144
trap_if(not isinstance(wset, WaitableSet))
21062145
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2143,7 +2182,7 @@ def call_and_trap_on_throw(callee, thread, args):
21432182

21442183
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
21452184
trap_if(not thread.task.inst.may_leave)
2146-
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
2185+
trap_if(not thread.task.inst.may_block() and ft.async_ and not opts.async_)
21472186

21482187
subtask = Subtask()
21492188
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2323,7 +2362,7 @@ def canon_waitable_set_new(thread):
23232362

23242363
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
23252364
trap_if(not thread.task.inst.may_leave)
2326-
trap_if(not thread.task.inst.may_block)
2365+
trap_if(not thread.task.inst.may_block())
23272366
wset = thread.task.inst.handles.get(si)
23282367
trap_if(not isinstance(wset, WaitableSet))
23292368
event = wset.wait(thread, cancellable)
@@ -2374,7 +2413,7 @@ def canon_waitable_join(thread, wi, si):
23742413

23752414
def canon_subtask_cancel(async_, thread, i):
23762415
trap_if(not thread.task.inst.may_leave)
2377-
trap_if(not thread.task.inst.may_block and not async_)
2416+
trap_if(not thread.task.inst.may_block() and not async_)
23782417
subtask = thread.task.inst.handles.get(i)
23792418
trap_if(not isinstance(subtask, Subtask))
23802419
trap_if(subtask.resolve_delivered())
@@ -2431,7 +2470,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24312470

24322471
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
24332472
trap_if(not thread.task.inst.may_leave)
2434-
trap_if(not thread.task.inst.may_block and not opts.async_)
2473+
trap_if(not thread.task.inst.may_block() and not opts.async_)
24352474

24362475
e = thread.task.inst.handles.get(i)
24372476
trap_if(not isinstance(e, EndT))
@@ -2485,7 +2524,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24852524

24862525
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24872526
trap_if(not thread.task.inst.may_leave)
2488-
trap_if(not thread.task.inst.may_block and not opts.async_)
2527+
trap_if(not thread.task.inst.may_block() and not opts.async_)
24892528

24902529
e = thread.task.inst.handles.get(i)
24912530
trap_if(not isinstance(e, EndT))
@@ -2537,7 +2576,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25372576

25382577
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
25392578
trap_if(not thread.task.inst.may_leave)
2540-
trap_if(not thread.task.inst.may_block and not async_)
2579+
trap_if(not thread.task.inst.may_block() and not async_)
25412580
e = thread.task.inst.handles.get(i)
25422581
trap_if(not isinstance(e, EndT))
25432582
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2601,20 +2640,28 @@ def thread_func(thread):
26012640
new_thread.index = thread.task.inst.threads.add(new_thread)
26022641
return [new_thread.index]
26032642

2604-
### 🧵 `canon thread.resume-later`
2643+
### 🧵 `canon thread.unsuspend`
26052644

2606-
def canon_thread_resume_later(thread, i):
2645+
def canon_thread_unsuspend(thread, i):
26072646
trap_if(not thread.task.inst.may_leave)
26082647
other_thread = thread.task.inst.threads.get(i)
26092648
trap_if(not other_thread.suspended())
2610-
other_thread.resume_later()
2649+
other_thread.unsuspend()
26112650
return []
26122651

2652+
### 🧵 `canon thread.exit`
2653+
2654+
def canon_thread_exit(thread):
2655+
trap_if(not thread.task.inst.may_leave)
2656+
trap_if(not thread.task.inst.may_block())
2657+
thread.exit()
2658+
assert(False)
2659+
26132660
### 🧵 `canon thread.suspend`
26142661

26152662
def canon_thread_suspend(cancellable, thread):
26162663
trap_if(not thread.task.inst.may_leave)
2617-
trap_if(not thread.task.inst.may_block)
2664+
trap_if(not thread.task.inst.may_block())
26182665
cancelled = thread.suspend(cancellable)
26192666
return [cancelled]
26202667

@@ -2625,22 +2672,39 @@ def canon_thread_yield(cancellable, thread):
26252672
cancelled = thread.yield_(cancellable)
26262673
return [cancelled]
26272674

2628-
### 🧵 `canon thread.switch-to`
2675+
### 🧵 `canon thread.suspend-to-suspended`
26292676

2630-
def canon_thread_switch_to(cancellable, thread, i):
2677+
def canon_thread_suspend_to_suspended(cancellable, thread, i):
26312678
trap_if(not thread.task.inst.may_leave)
26322679
other_thread = thread.task.inst.threads.get(i)
26332680
trap_if(not other_thread.suspended())
2634-
cancelled = thread.switch_to(cancellable, other_thread)
2681+
cancelled = thread.suspend_to_suspended(cancellable, other_thread)
26352682
return [cancelled]
26362683

2637-
### 🧵 `canon thread.yield-to`
2684+
### 🧵 `canon thread.yield-to-suspended`
26382685

2639-
def canon_thread_yield_to(cancellable, thread, i):
2686+
def canon_thread_yield_to_suspended(cancellable, thread, i):
26402687
trap_if(not thread.task.inst.may_leave)
26412688
other_thread = thread.task.inst.threads.get(i)
26422689
trap_if(not other_thread.suspended())
2643-
cancelled = thread.yield_to(cancellable, other_thread)
2690+
cancelled = thread.yield_to_suspended(cancellable, other_thread)
2691+
return [cancelled]
2692+
2693+
### 🧵 `canon thread.suspend-then-promote`
2694+
2695+
def canon_thread_suspend_then_promote(cancellable, thread, i):
2696+
trap_if(not thread.task.inst.may_leave)
2697+
trap_if(not thread.task.inst.may_block())
2698+
other_thread = thread.task.inst.threads.get(i)
2699+
cancelled = thread.suspend_then_promote(cancellable, other_thread)
2700+
return [cancelled]
2701+
2702+
### 🧵 `canon thread.yield-then-promote`
2703+
2704+
def canon_thread_yield_then_promote(cancellable, thread, i):
2705+
trap_if(not thread.task.inst.may_leave)
2706+
other_thread = thread.task.inst.threads.get(i)
2707+
cancelled = thread.yield_then_promote(cancellable, other_thread)
26442708
return [cancelled]
26452709

26462710
### 📝 `canon error-context.new`

0 commit comments

Comments
 (0)