Skip to content

Commit 4714524

Browse files
committed
Add cooperative threads
1 parent e3e5b93 commit 4714524

3 files changed

Lines changed: 193 additions & 94 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ is always clear: it's the one passed to the current function as a parameter.
239239

240240
### Context-Local Storage
241241

242+
TODO: update (also there are 2 now)
243+
242244
Each task contains a distinct mutable **context-local storage** array. The
243245
current task's context-local storage can be read and written from core wasm
244246
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 132 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -384,22 +384,6 @@ def write(self, vs):
384384
assert(all(v == () for v in vs))
385385
self.progress += len(vs)
386386

387-
#### Context-Local Storage
388-
389-
class ContextLocalStorage:
390-
LENGTH = 1
391-
array: list[int]
392-
393-
def __init__(self):
394-
self.array = [0] * ContextLocalStorage.LENGTH
395-
396-
def set(self, i, v):
397-
assert(types_match_values(['i32'], [v]))
398-
self.array[i] = v
399-
400-
def get(self, i):
401-
return self.array[i]
402-
403387
#### Waitable State
404388

405389
class EventCode(IntEnum):
@@ -477,6 +461,10 @@ class Thread:
477461
cancellable: bool
478462
cancelled: bool
479463
waiting_for_callback: bool
464+
index: Optional[int]
465+
context: list[int]
466+
467+
CONTEXT_LENGTH = 2
480468

481469
def running(self):
482470
return self.parent_lock is not None
@@ -500,12 +488,17 @@ def __init__(self, task, thread_func):
500488
self.cancellable = False
501489
self.cancelled = False
502490
self.waiting_for_callback = False
491+
self.index = None
492+
self.context = [0] * Thread.CONTEXT_LENGTH
493+
503494
def fiber_func():
504495
self.fiber_lock.acquire()
505496
assert(self.running())
506497
thread_func(self)
507498
assert(self.running())
508499
self.task.thread_stop(self)
500+
if self.index is not None:
501+
self.task.inst.table.remove(self.index)
509502
self.parent_lock.release()
510503
self.fiber = threading.Thread(target = fiber_func)
511504
self.fiber.start()
@@ -548,6 +541,34 @@ def suspend_until(self, ready_func, cancellable = False) -> bool:
548541
self.task.inst.store.pending.append(self)
549542
return self.suspend(cancellable)
550543

544+
def switch_to(self, cancellable, other: Thread) -> bool:
545+
assert(self.running() and other.suspended())
546+
assert(not self.cancellable)
547+
self.cancellable = cancellable
548+
assert(self.parent_lock and not other.parent_lock)
549+
other.parent_lock = self.parent_lock
550+
self.parent_lock = None
551+
assert(self.suspended() and other.running())
552+
other.fiber_lock.release()
553+
self.fiber_lock.acquire()
554+
assert(self.running())
555+
self.cancellable = False
556+
completed = not self.cancelled
557+
self.cancelled = False
558+
return completed
559+
560+
def yield_to(self, cancellable, other: Thread) -> bool:
561+
assert(not self.ready_func)
562+
self.ready_func = lambda: True
563+
self.task.inst.store.pending.append(self)
564+
return self.switch_to(cancellable, other)
565+
566+
def resume_later(self, other: Thread):
567+
assert(self.running() and other.suspended())
568+
other.ready_func = lambda: True
569+
other.task.inst.store.pending.append(other)
570+
571+
551572
#### Task State
552573

553574
class Task(Call, Supertask):
@@ -564,8 +585,7 @@ class State(Enum):
564585
supertask: Optional[Task]
565586
on_resolve: OnResolve
566587
num_borrows: int
567-
thread: Optional[Thread]
568-
context: ContextLocalStorage
588+
threads: list[Thread]
569589

570590
def __init__(self, opts, inst, ft, supertask, on_resolve):
571591
self.state = Task.State.INITIAL
@@ -575,8 +595,7 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
575595
self.supertask = supertask
576596
self.on_resolve = on_resolve
577597
self.num_borrows = 0
578-
self.thread = None
579-
self.context = ContextLocalStorage()
598+
self.threads = []
580599

581600
def trap_if_on_the_stack(self, inst):
582601
c = self.supertask
@@ -588,7 +607,7 @@ def needs_exclusive(self):
588607
return self.opts.sync or self.opts.callback
589608

590609
def enter(self, thread):
591-
assert(thread is self.thread and thread.task is self)
610+
assert(thread in self.threads and thread.task is self)
592611
def has_backpressure():
593612
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
594613
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -605,28 +624,30 @@ def has_backpressure():
605624

606625
def request_cancellation(self):
607626
assert(self.state == Task.State.INITIAL)
608-
if self.thread.cancellable and not (self.thread.waiting_for_callback and self.inst.exclusive):
609-
self.state = Task.State.CANCEL_DELIVERED
610-
self.thread.resume(cancel = True)
611-
else:
612-
self.state = Task.State.PENDING_CANCEL
627+
random.shuffle(self.threads)
628+
for thread in self.threads:
629+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
630+
self.state = Task.State.CANCEL_DELIVERED
631+
thread.resume(cancel = True)
632+
return
633+
self.state = Task.State.PENDING_CANCEL
613634

614635
def wait_until(self, ready_func, thread, cancellable, for_callback) -> bool:
615-
assert(thread is self.thread and thread.task is self)
636+
assert(thread in self.threads and thread.task is self)
616637
if cancellable and self.state == Task.State.PENDING_CANCEL:
617638
self.state = Task.State.CANCEL_DELIVERED
618639
return False
619640
if for_callback:
620641
assert(self.inst.exclusive)
621642
self.inst.exclusive = False
622-
self.thread.waiting_for_callback = True
643+
thread.waiting_for_callback = True
623644
def ready_and_uncontended():
624645
return ready_func() and not (for_callback and self.inst.exclusive)
625646
completed = thread.suspend_until(ready_and_uncontended, cancellable)
626647
if for_callback:
627648
assert(not self.inst.exclusive)
628649
self.inst.exclusive = True
629-
self.thread.waiting_for_callback = False
650+
thread.waiting_for_callback = False
630651
return completed
631652

632653
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
@@ -669,20 +690,21 @@ def cancel(self):
669690
self.state = Task.State.RESOLVED
670691

671692
def exit(self):
672-
assert(self.thread is not None)
693+
assert(len(self.threads) > 0)
673694
if self.needs_exclusive():
674695
assert(self.inst.exclusive)
675696
self.inst.exclusive = False
676697

677698
def thread_start(self, thread):
678-
assert(self.thread is None and thread.task is self)
679-
self.thread = thread
699+
assert(thread not in self.threads and thread.task is self)
700+
self.threads.append(thread)
680701

681702
def thread_stop(self, thread):
682-
assert(thread is self.thread and thread.task is self)
683-
self.thread = None
684-
trap_if(self.state != Task.State.RESOLVED)
685-
assert(self.num_borrows == 0)
703+
assert(thread in self.threads and thread.task is self)
704+
self.threads.remove(thread)
705+
if len(self.threads) == 0:
706+
trap_if(self.state != Task.State.RESOLVED)
707+
assert(self.num_borrows == 0)
686708

687709
#### Subtask State
688710

@@ -1918,6 +1940,9 @@ def thread_func(thread):
19181940
if not task.enter(thread):
19191941
return
19201942

1943+
assert(thread.index is None)
1944+
thread.index = thread.task.inst.table.add(thread)
1945+
19211946
cx = LiftLowerContext(opts, inst, task)
19221947
args = on_start()
19231948
flat_args = lower_flat_values(cx, MAX_FLAT_PARAMS, args, ft.param_types())
@@ -2098,25 +2123,91 @@ def canon_resource_rep(rt, thread, i):
20982123
trap_if(h.rt is not rt)
20992124
return [h.rep]
21002125

2126+
### 🧵 `canon thread.index`
2127+
2128+
def canon_thread_index(shared, thread):
2129+
assert(not shared)
2130+
assert(thread.index is not None)
2131+
return [thread.index]
2132+
2133+
### 🧵 `canon thread.new`
2134+
2135+
def canon_thread_new(ft, ftbl, thread, i, c):
2136+
task = thread.task
2137+
trap_if(not task.inst.may_leave)
2138+
f = task.inst.ftbl.get(i)
2139+
trap_if(f.type != ft)
2140+
thread_func = partial(f, c)
2141+
new_thread = Thread(task, thread_func)
2142+
assert(new_thread.suspended())
2143+
new_thread.index = task.inst.table.add(thread)
2144+
return [new_thread.index]
2145+
2146+
### 🧵 `canon thread.resume-later`
2147+
2148+
def canon_thread_resume_later(thread, i):
2149+
trap_if(not thread.task.inst.may_leave)
2150+
other_thread = thread.task.inst.table.get(i)
2151+
trap_if(not isinstance(other_thread, Thread))
2152+
trap_if(not other_thread.suspended())
2153+
thread.resume_later(other_thread)
2154+
return []
2155+
2156+
### 🧵 `canon thread.switch-to`
2157+
2158+
def canon_thread_switch_to(thread, cancellable, i):
2159+
trap_if(not thread.task.inst.may_leave)
2160+
other_thread = thread.task.inst.table.get(i)
2161+
trap_if(not isinstance(other_thread, Thread))
2162+
trap_if(not other_thread.suspended())
2163+
if not thread.switch_to(cancellable, other_thread):
2164+
assert(cancellable)
2165+
return [0]
2166+
else:
2167+
return [1]
2168+
2169+
### 🧵 `canon thread.yield-to`
2170+
2171+
def canon_thread_yield_to(thread, cancellable, i):
2172+
trap_if(not thread.task.inst.may_leave)
2173+
other_thread = thread.task.inst.table.get(i)
2174+
trap_if(not isinstance(other_thread, Thread))
2175+
trap_if(not other_thread.suspended())
2176+
if not other_thread.yield_to(cancellable, other_thread):
2177+
assert(cancellable)
2178+
return [0]
2179+
else:
2180+
return [1]
2181+
2182+
### 🧵 `canon thread.suspend`
2183+
2184+
def canon_thread_suspend(thread, cancellable):
2185+
trap_if(not thread.task.inst.may_leave)
2186+
if not thread.suspend(cancellable):
2187+
assert(cancellable)
2188+
return [0]
2189+
else:
2190+
return [1]
2191+
21012192
### 🔀 `canon context.get`
21022193

21032194
def canon_context_get(t, i, thread):
21042195
assert(t == 'i32')
2105-
assert(i < ContextLocalStorage.LENGTH)
2106-
return [thread.task.context.get(i)]
2196+
assert(i < Thread.CONTEXT_LENGTH)
2197+
return [thread.context[i]]
21072198

21082199
### 🔀 `canon context.set`
21092200

21102201
def canon_context_set(t, i, thread, v):
21112202
assert(t == 'i32')
2112-
assert(i < ContextLocalStorage.LENGTH)
2113-
thread.task.context.set(i, v)
2203+
assert(i < Thread.CONTEXT_LENGTH)
2204+
thread.context[i] = v
21142205
return []
21152206

21162207
### 🔀 `canon backpressure.set`
21172208

21182209
def canon_backpressure_set(thread, flat_args):
2119-
trap_if(thread.task.opts.sync)
2210+
# TODO: remove trap_if(thread.task.opts.sync)
21202211
assert(len(flat_args) == 1)
21212212
thread.task.inst.backpressure = bool(flat_args[0])
21222213
return []

0 commit comments

Comments
 (0)