Skip to content

Commit e28d546

Browse files
committed
Add cooperative threads
1 parent d5fd17c commit e28d546

3 files changed

Lines changed: 192 additions & 94 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ is always clear: it's the one passed to the current function as a parameter.
239239

240240
### Context-Local Storage
241241

242+
TODO: update (also there are 2 now)
243+
242244
Each task contains a distinct mutable **context-local storage** array. The
243245
current task's context-local storage can be read and written from core wasm
244246
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 131 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -384,22 +384,6 @@ def write(self, vs):
384384
assert(all(v == () for v in vs))
385385
self.progress += len(vs)
386386

387-
#### Context-Local Storage
388-
389-
class ContextLocalStorage:
390-
LENGTH = 1
391-
array: list[int]
392-
393-
def __init__(self):
394-
self.array = [0] * ContextLocalStorage.LENGTH
395-
396-
def set(self, i, v):
397-
assert(types_match_values(['i32'], [v]))
398-
self.array[i] = v
399-
400-
def get(self, i):
401-
return self.array[i]
402-
403387
#### Thread State
404388

405389
class Thread:
@@ -411,6 +395,10 @@ class Thread:
411395
cancellable: bool
412396
cancelled: bool
413397
waiting_for_callback: bool
398+
index: Optional[int]
399+
context: list[int]
400+
401+
CONTEXT_LENGTH = 2
414402

415403
def running(self):
416404
return self.parent_lock is not None
@@ -434,12 +422,17 @@ def __init__(self, task, thread_func):
434422
self.cancellable = False
435423
self.cancelled = False
436424
self.waiting_for_callback = False
425+
self.index = None
426+
self.context = [0] * Thread.CONTEXT_LENGTH
427+
437428
def fiber_func():
438429
self.fiber_lock.acquire()
439430
assert(self.running())
440431
thread_func(self)
441432
assert(self.running())
442433
self.task.thread_stop(self)
434+
if self.index is not None:
435+
self.task.inst.table.remove(self.index)
443436
self.parent_lock.release()
444437
self.fiber = threading.Thread(target = fiber_func)
445438
self.fiber.start()
@@ -482,6 +475,33 @@ def suspend_until(self, ready_func, cancellable = False) -> bool:
482475
self.task.inst.store.pending.append(self)
483476
return self.suspend(cancellable)
484477

478+
def switch_to(self, cancellable, other: Thread) -> bool:
479+
assert(self.running() and other.suspended())
480+
assert(not self.cancellable)
481+
self.cancellable = cancellable
482+
assert(self.parent_lock and not other.parent_lock)
483+
other.parent_lock = self.parent_lock
484+
self.parent_lock = None
485+
assert(self.suspended() and other.running())
486+
other.fiber_lock.release()
487+
self.fiber_lock.acquire()
488+
assert(self.running())
489+
self.cancellable = False
490+
completed = not self.cancelled
491+
self.cancelled = False
492+
return completed
493+
494+
def yield_to(self, cancellable, other: Thread) -> bool:
495+
assert(not self.ready_func)
496+
self.ready_func = lambda: True
497+
self.task.inst.store.pending.append(self)
498+
return self.switch_to(cancellable, other)
499+
500+
def resume_later(self, other: Thread):
501+
assert(self.running() and other.suspended())
502+
other.ready_func = lambda: True
503+
other.task.inst.store.pending.append(other)
504+
485505
#### Waitable State
486506

487507
class EventCode(IntEnum):
@@ -564,8 +584,7 @@ class State(Enum):
564584
supertask: Optional[Task]
565585
on_resolve: OnResolve
566586
num_borrows: int
567-
thread: Optional[Thread]
568-
context: ContextLocalStorage
587+
threads: list[Thread]
569588

570589
def __init__(self, opts, inst, ft, supertask, on_resolve):
571590
self.state = Task.State.INITIAL
@@ -575,8 +594,7 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
575594
self.supertask = supertask
576595
self.on_resolve = on_resolve
577596
self.num_borrows = 0
578-
self.thread = None
579-
self.context = ContextLocalStorage()
597+
self.threads = []
580598

581599
def trap_if_on_the_stack(self, inst):
582600
c = self.supertask
@@ -588,7 +606,7 @@ def needs_exclusive(self):
588606
return self.opts.sync or self.opts.callback
589607

590608
def enter(self, thread):
591-
assert(thread is self.thread and thread.task is self)
609+
assert(thread in self.threads and thread.task is self)
592610
def has_backpressure():
593611
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
594612
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -605,28 +623,30 @@ def has_backpressure():
605623

606624
def request_cancellation(self):
607625
assert(self.state == Task.State.INITIAL)
608-
if self.thread.cancellable and not (self.thread.waiting_for_callback and self.inst.exclusive):
609-
self.state = Task.State.CANCEL_DELIVERED
610-
self.thread.resume(cancel = True)
611-
else:
612-
self.state = Task.State.PENDING_CANCEL
626+
random.shuffle(self.threads)
627+
for thread in self.threads:
628+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
629+
self.state = Task.State.CANCEL_DELIVERED
630+
thread.resume(cancel = True)
631+
return
632+
self.state = Task.State.PENDING_CANCEL
613633

614634
def wait_until(self, ready_func, thread, cancellable, for_callback) -> bool:
615-
assert(thread is self.thread and thread.task is self)
635+
assert(thread in self.threads and thread.task is self)
616636
if cancellable and self.state == Task.State.PENDING_CANCEL:
617637
self.state = Task.State.CANCEL_DELIVERED
618638
return False
619639
if for_callback:
620640
assert(self.inst.exclusive)
621641
self.inst.exclusive = False
622-
self.thread.waiting_for_callback = True
642+
thread.waiting_for_callback = True
623643
def ready_and_uncontended():
624644
return ready_func() and not (for_callback and self.inst.exclusive)
625645
completed = thread.suspend_until(ready_and_uncontended, cancellable)
626646
if for_callback:
627647
assert(not self.inst.exclusive)
628648
self.inst.exclusive = True
629-
self.thread.waiting_for_callback = False
649+
thread.waiting_for_callback = False
630650
return completed
631651

632652
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
@@ -669,20 +689,21 @@ def cancel(self):
669689
self.state = Task.State.RESOLVED
670690

671691
def exit(self):
672-
assert(self.thread is not None)
692+
assert(len(self.threads) > 0)
673693
if self.needs_exclusive():
674694
assert(self.inst.exclusive)
675695
self.inst.exclusive = False
676696

677697
def thread_start(self, thread):
678-
assert(self.thread is None and thread.task is self)
679-
self.thread = thread
698+
assert(thread not in self.threads and thread.task is self)
699+
self.threads.append(thread)
680700

681701
def thread_stop(self, thread):
682-
assert(thread is self.thread and thread.task is self)
683-
self.thread = None
684-
trap_if(self.state != Task.State.RESOLVED)
685-
assert(self.num_borrows == 0)
702+
assert(thread in self.threads and thread.task is self)
703+
self.threads.remove(thread)
704+
if len(self.threads) == 0:
705+
trap_if(self.state != Task.State.RESOLVED)
706+
assert(self.num_borrows == 0)
686707

687708
#### Subtask State
688709

@@ -1918,6 +1939,9 @@ def thread_func(thread):
19181939
if not task.enter(thread):
19191940
return
19201941

1942+
assert(thread.index is None)
1943+
thread.index = thread.task.inst.table.add(thread)
1944+
19211945
cx = LiftLowerContext(opts, inst, task)
19221946
args = on_start()
19231947
flat_args = lower_flat_values(cx, MAX_FLAT_PARAMS, args, ft.param_types())
@@ -2098,25 +2122,91 @@ def canon_resource_rep(rt, thread, i):
20982122
trap_if(h.rt is not rt)
20992123
return [h.rep]
21002124

2125+
### 🧵 `canon thread.index`
2126+
2127+
def canon_thread_index(shared, thread):
2128+
assert(not shared)
2129+
assert(thread.index is not None)
2130+
return [thread.index]
2131+
2132+
### 🧵 `canon thread.new`
2133+
2134+
def canon_thread_new(ft, ftbl, thread, i, c):
2135+
task = thread.task
2136+
trap_if(not task.inst.may_leave)
2137+
f = task.inst.ftbl.get(i)
2138+
trap_if(f.type != ft)
2139+
thread_func = partial(f, c)
2140+
new_thread = Thread(task, thread_func)
2141+
assert(new_thread.suspended())
2142+
new_thread.index = task.inst.table.add(thread)
2143+
return [new_thread.index]
2144+
2145+
### 🧵 `canon thread.resume-later`
2146+
2147+
def canon_thread_resume_later(thread, i):
2148+
trap_if(not thread.task.inst.may_leave)
2149+
other_thread = thread.task.inst.table.get(i)
2150+
trap_if(not isinstance(other_thread, Thread))
2151+
trap_if(not other_thread.suspended())
2152+
thread.resume_later(other_thread)
2153+
return []
2154+
2155+
### 🧵 `canon thread.switch-to`
2156+
2157+
def canon_thread_switch_to(thread, cancellable, i):
2158+
trap_if(not thread.task.inst.may_leave)
2159+
other_thread = thread.task.inst.table.get(i)
2160+
trap_if(not isinstance(other_thread, Thread))
2161+
trap_if(not other_thread.suspended())
2162+
if not thread.switch_to(cancellable, other_thread):
2163+
assert(cancellable)
2164+
return [0]
2165+
else:
2166+
return [1]
2167+
2168+
### 🧵 `canon thread.yield-to`
2169+
2170+
def canon_thread_yield_to(thread, cancellable, i):
2171+
trap_if(not thread.task.inst.may_leave)
2172+
other_thread = thread.task.inst.table.get(i)
2173+
trap_if(not isinstance(other_thread, Thread))
2174+
trap_if(not other_thread.suspended())
2175+
if not other_thread.yield_to(cancellable, other_thread):
2176+
assert(cancellable)
2177+
return [0]
2178+
else:
2179+
return [1]
2180+
2181+
### 🧵 `canon thread.suspend`
2182+
2183+
def canon_thread_suspend(thread, cancellable):
2184+
trap_if(not thread.task.inst.may_leave)
2185+
if not thread.suspend(cancellable):
2186+
assert(cancellable)
2187+
return [0]
2188+
else:
2189+
return [1]
2190+
21012191
### 🔀 `canon context.get`
21022192

21032193
def canon_context_get(t, i, thread):
21042194
assert(t == 'i32')
2105-
assert(i < ContextLocalStorage.LENGTH)
2106-
return [thread.task.context.get(i)]
2195+
assert(i < Thread.CONTEXT_LENGTH)
2196+
return [thread.context[i]]
21072197

21082198
### 🔀 `canon context.set`
21092199

21102200
def canon_context_set(t, i, thread, v):
21112201
assert(t == 'i32')
2112-
assert(i < ContextLocalStorage.LENGTH)
2113-
thread.task.context.set(i, v)
2202+
assert(i < Thread.CONTEXT_LENGTH)
2203+
thread.context[i] = v
21142204
return []
21152205

21162206
### 🔀 `canon backpressure.set`
21172207

21182208
def canon_backpressure_set(thread, flat_args):
2119-
trap_if(thread.task.opts.sync)
2209+
# TODO: remove trap_if(thread.task.opts.sync)
21202210
assert(len(flat_args) == 1)
21212211
thread.task.inst.backpressure = bool(flat_args[0])
21222212
return []

0 commit comments

Comments
 (0)