Skip to content

Commit c7e4043

Browse files
committed
Add cooperative threads
1 parent 7ea8ea8 commit c7e4043

3 files changed

Lines changed: 191 additions & 104 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 130 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,31 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: int
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: Optional[threading.Lock]
219221
stack: threading.Thread
222+
cancellable: bool
223+
waiting_for_callback: bool
224+
225+
CONTEXT_LENGTH = 2
220226

221227
def __init__(self, task, thread_func):
222228
self.task = task
229+
self.index = task.inst.table.add(self)
230+
self.context = [0] * Thread.CONTEXT_LENGTH
223231
self.ready_func = None
224232
self.run_lock = threading.Lock()
225233
self.resume_lock = None
234+
self.cancellable = False
235+
self.waiting_for_callback = False
226236
def thread_stack_base():
227237
self.run_lock.acquire()
228238
thread_func(self)
229239
self.task.thread_stop(self)
240+
self.task.inst.table.remove(self.index)
230241
self.resume_lock.release()
231242
self.stack = threading.Thread(target = thread_stack_base)
232243
self.run_lock.acquire()
@@ -247,14 +258,45 @@ def resume(self):
247258
self.resume_lock.acquire()
248259
self.resume_lock = None
249260

250-
def suspend_until(self, ready_func):
261+
def block(self, cancellable):
262+
assert(not self.cancellable)
263+
self.cancellable = cancellable
264+
self.resume_lock.release()
265+
self.run_lock.acquire()
266+
self.cancellable = False
267+
268+
def suspend_until(self, ready_func, cancellable = False):
251269
assert(not self.ready_func)
252270
if not DETERMINISTIC_PROFILE and ready_func():
253271
return
254272
self.ready_func = ready_func
255273
self.task.inst.store.waiting.append(self)
256-
self.resume_lock.release()
274+
self.block(cancellable)
275+
276+
async def switch_to(self, cancellable, other: Thread):
277+
assert(self.task.inst is other.task.inst)
278+
if other.ready_func:
279+
other.ready_func = None
280+
other.task.inst.store.waiting.remove(other)
281+
assert(not self.cancellable)
282+
self.cancellable = cancellable
283+
assert(self.resume_lock and not other.resume_lock)
284+
other.resume_lock = self.resume_lock
285+
self.resume_lock = None
286+
other.run_lock.release()
257287
self.run_lock.acquire()
288+
self.cancellable = False
289+
290+
def yield_to(self, cancellable, other: Thread):
291+
assert(not self.ready_func)
292+
self.ready_func = lambda: True
293+
self.task.inst.store.waiting.append(self)
294+
self.switch_to(cancellable, other)
295+
296+
def unblock(self, other: Thread):
297+
if not other.ready_func:
298+
other.task.inst.store.waiting.append(other)
299+
other.ready_func = lambda: True
258300

259301

260302
### Lifting and Lowering Context
@@ -432,22 +474,6 @@ def write(self, vs):
432474
assert(all(v == () for v in vs))
433475
self.progress += len(vs)
434476

435-
#### Context-Local Storage
436-
437-
class ContextLocalStorage:
438-
LENGTH = 1
439-
array: list[int]
440-
441-
def __init__(self):
442-
self.array = [0] * ContextLocalStorage.LENGTH
443-
444-
def set(self, i, v):
445-
assert(types_match_values(['i32'], [v]))
446-
self.array[i] = v
447-
448-
def get(self, i):
449-
return self.array[i]
450-
451477
#### Waitable State
452478

453479
class EventCode(IntEnum):
@@ -458,6 +484,7 @@ class EventCode(IntEnum):
458484
FUTURE_READ = 4
459485
FUTURE_WRITE = 5
460486
TASK_CANCELLED = 6
487+
THREAD_RESUMED = 7
461488

462489
EventTuple = tuple[EventCode, int, int]
463490

@@ -530,11 +557,8 @@ class State(Enum):
530557
ft: FuncType
531558
supertask: Optional[Task]
532559
on_resolve: OnResolve
533-
thread: Optional[Thread]
534-
cancellable: bool
535-
waiting_for_callback: bool
560+
threads: list[Thread]
536561
num_borrows: int
537-
context: ContextLocalStorage
538562

539563
def __init__(self, opts, inst, ft, supertask, on_resolve):
540564
self.state = Task.State.INITIAL
@@ -543,11 +567,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
543567
self.ft = ft
544568
self.supertask = supertask
545569
self.on_resolve = on_resolve
546-
self.thread = None
547-
self.cancellable = False
548-
self.waiting_for_callback = False
570+
self.threads = []
549571
self.num_borrows = 0
550-
self.context = ContextLocalStorage()
551572

552573
def trap_if_on_the_stack(self, inst):
553574
c = self.supertask
@@ -559,15 +580,13 @@ def needs_exclusive(self):
559580
return self.opts.sync or self.opts.callback
560581

561582
def enter(self, thread):
562-
assert(thread is self.thread and thread.task is self)
583+
assert(thread in self.threads and thread.task is self)
563584
def has_backpressure():
564585
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
565586
if has_backpressure() or self.inst.pending_tasks > 0:
566587
self.inst.pending_tasks += 1
567-
self.cancellable = True
568-
thread.suspend_until(lambda: not has_backpressure())
588+
thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
569589
self.inst.pending_tasks -= 1
570-
self.cancellable = False
571590
if self.deliver_cancel():
572591
self.cancel()
573592
return False
@@ -586,27 +605,28 @@ def deliver_cancel(self) -> bool:
586605
def request_cancellation(self):
587606
assert(self.state == Task.State.INITIAL)
588607
self.state = Task.State.PENDING_CANCEL
589-
if self.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
590-
self.thread.resume()
608+
if not DETERMINISTIC_PROFILE:
609+
random.shuffle(self.threads)
610+
for thread in self.threads:
611+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
612+
thread.resume()
613+
break
591614

592615
def wait_until(self, ready_func, thread, cancellable, for_callback):
593-
assert(thread is self.thread and thread.task is self)
616+
assert(thread in self.threads and thread.task is self)
594617
if cancellable and self.deliver_cancel():
595618
return True
596-
assert(not self.cancellable)
597-
self.cancellable = cancellable
598619
if for_callback:
599620
assert(self.inst.exclusive)
600621
self.inst.exclusive = False
601-
self.waiting_for_callback = True
622+
thread.waiting_for_callback = True
602623
def ready_and_allowed():
603624
return ready_func() and not (for_callback and self.inst.exclusive)
604-
thread.suspend_until(ready_and_allowed)
625+
thread.suspend_until(ready_and_allowed, cancellable)
605626
if for_callback:
606627
assert(not self.inst.exclusive)
607628
self.inst.exclusive = True
608-
self.waiting_for_callback = False
609-
self.cancellable = False
629+
thread.waiting_for_callback = False
610630
if cancellable and self.deliver_cancel():
611631
return True
612632
return False
@@ -615,13 +635,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
615635
wset.num_waiting += 1
616636
cancelled = self.wait_until(wset.has_pending_event, thread, cancellable, for_callback)
617637
wset.num_waiting -= 1
638+
# TODO: somehow get a THREAD_RESUME event...
618639
if cancelled:
619640
return (EventCode.TASK_CANCELLED, 0, 0)
620641
else:
621642
return wset.get_pending_event()
622643

623644
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
624645
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
646+
# TODO: somehow get a THREAD_RESUME event...
625647
if cancelled:
626648
return (EventCode.TASK_CANCELLED, 0, 0)
627649
else:
@@ -631,6 +653,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
631653
wset.num_waiting += 1
632654
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
633655
wset.num_waiting -= 1
656+
# TODO: somehow get a THREAD_RESUME event...
634657
if cancelled:
635658
return (EventCode.TASK_CANCELLED, 0, 0)
636659
elif wset.has_pending_event():
@@ -652,20 +675,21 @@ def cancel(self):
652675
self.state = Task.State.RESOLVED
653676

654677
def exit(self):
655-
assert(self.thread is not None)
678+
assert(len(self.threads) > 0)
656679
if self.needs_exclusive():
657680
assert(self.inst.exclusive)
658681
self.inst.exclusive = False
659682

660683
def thread_start(self, thread):
661-
assert(self.thread is None and thread.task is self)
662-
self.thread = thread
684+
assert(thread not in self.threads and thread.task is self)
685+
self.threads.append(thread)
663686

664687
def thread_stop(self, thread):
665-
assert(thread is self.thread and thread.task is self)
666-
self.thread = None
667-
trap_if(self.state != Task.State.RESOLVED)
668-
assert(self.num_borrows == 0)
688+
assert(thread in self.threads and thread.task is self)
689+
self.threads.remove(thread)
690+
if len(self.threads) == 0:
691+
trap_if(self.state != Task.State.RESOLVED)
692+
assert(self.num_borrows == 0)
669693

670694
#### Subtask State
671695

@@ -2081,25 +2105,80 @@ def canon_resource_rep(rt, thread, i):
20812105
trap_if(h.rt is not rt)
20822106
return [h.rep]
20832107

2108+
### 🧵 `canon thread.index`
2109+
2110+
def canon_thread_index(shared, thread):
2111+
assert(not shared)
2112+
return [thread.index]
2113+
2114+
### 🧵 `canon thread.new`
2115+
2116+
def canon_thread_new(ft, ftbl, thread, i, c):
2117+
task = thread.task
2118+
trap_if(not task.inst.may_leave)
2119+
f = task.inst.ftbl.get(i)
2120+
trap_if(f.type != ft)
2121+
new_thread = Thread(task, f(c))
2122+
return [new_thread.index]
2123+
2124+
### 🧵 `canon thread.switch-to`
2125+
2126+
def canon_thread_switch_to(thread, cancellable, i):
2127+
trap_if(not thread.task.inst.may_leave)
2128+
other = thread.task.inst.table.get(i)
2129+
trap_if(not isinstance(other, Thread))
2130+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2131+
cancelled = thread.switch_to(cancellable, other)
2132+
return [ 1 if cancelled else 0 ]
2133+
2134+
### 🧵 `canon thread.yield-to`
2135+
2136+
def canon_thread_yield_to(thread, cancellable, i):
2137+
trap_if(not thread.task.inst.may_leave)
2138+
other = thread.task.inst.table.get(i)
2139+
trap_if(not isinstance(other, Thread))
2140+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2141+
other.yield_to(cancellable, other)
2142+
return []
2143+
2144+
### 🧵 `canon thread.unblock`
2145+
2146+
def canon_thread_unblock(thread, i):
2147+
trap_if(not thread.task.inst.may_leave)
2148+
other = thread.task.inst.table.get(i)
2149+
trap_if(not isinstance(other, Thread))
2150+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2151+
thread.unblock()
2152+
return []
2153+
2154+
### 🧵 `canon thread.block`
2155+
2156+
def canon_thread_block(thread, cancellable, i):
2157+
trap_if(not thread.task.inst.may_leave)
2158+
other = thread.task.inst.table.get(i)
2159+
trap_if(not isinstance(other, Thread))
2160+
cancelled = thread.block(cancellable)
2161+
return [ 1 if cancelled else 0 ]
2162+
20842163
### 🔀 `canon context.get`
20852164

20862165
def canon_context_get(t, i, thread):
20872166
assert(t == 'i32')
2088-
assert(i < ContextLocalStorage.LENGTH)
2089-
return [thread.task.context.get(i)]
2167+
assert(i < Thread.CONTEXT_LENGTH)
2168+
return [thread.context[i]]
20902169

20912170
### 🔀 `canon context.set`
20922171

20932172
def canon_context_set(t, i, thread, v):
20942173
assert(t == 'i32')
2095-
assert(i < ContextLocalStorage.LENGTH)
2096-
thread.task.context.set(i, v)
2174+
assert(i < Thread.CONTEXT_LENGTH)
2175+
thread.context[i] = v
20972176
return []
20982177

20992178
### 🔀 `canon backpressure.set`
21002179

21012180
def canon_backpressure_set(thread, flat_args):
2102-
trap_if(thread.task.opts.sync)
2181+
# TODO: remove trap_if(thread.task.opts.sync)
21032182
assert(len(flat_args) == 1)
21042183
thread.task.inst.backpressure = bool(flat_args[0])
21052184
return []

0 commit comments

Comments
 (0)