Skip to content

Commit b6a9e96

Browse files
committed
Add cooperative threads
1 parent 7ea8ea8 commit b6a9e96

3 files changed

Lines changed: 197 additions & 104 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 136 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,32 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: Optional[int]
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: Optional[threading.Lock]
219221
stack: threading.Thread
222+
cancellable: bool
223+
waiting_for_callback: bool
224+
225+
CONTEXT_LENGTH = 2
220226

221227
def __init__(self, task, thread_func):
222228
self.task = task
229+
self.index = None
230+
self.context = [0] * Thread.CONTEXT_LENGTH
223231
self.ready_func = None
224232
self.run_lock = threading.Lock()
225233
self.resume_lock = None
234+
self.cancellable = False
235+
self.waiting_for_callback = False
226236
def thread_stack_base():
227237
self.run_lock.acquire()
228238
thread_func(self)
229239
self.task.thread_stop(self)
240+
if self.index is not None:
241+
self.task.inst.table.remove(self.index)
230242
self.resume_lock.release()
231243
self.stack = threading.Thread(target = thread_stack_base)
232244
self.run_lock.acquire()
@@ -247,14 +259,45 @@ def resume(self):
247259
self.resume_lock.acquire()
248260
self.resume_lock = None
249261

250-
def suspend_until(self, ready_func):
262+
def block(self, cancellable):
263+
assert(not self.cancellable)
264+
self.cancellable = cancellable
265+
self.resume_lock.release()
266+
self.run_lock.acquire()
267+
self.cancellable = False
268+
269+
def suspend_until(self, ready_func, cancellable = False):
251270
assert(not self.ready_func)
252271
if not DETERMINISTIC_PROFILE and ready_func():
253272
return
254273
self.ready_func = ready_func
255274
self.task.inst.store.waiting.append(self)
256-
self.resume_lock.release()
275+
self.block(cancellable)
276+
277+
async def switch_to(self, cancellable, other: Thread):
278+
assert(self.task.inst is other.task.inst)
279+
if other.ready_func:
280+
other.ready_func = None
281+
other.task.inst.store.waiting.remove(other)
282+
assert(not self.cancellable)
283+
self.cancellable = cancellable
284+
assert(self.resume_lock and not other.resume_lock)
285+
other.resume_lock = self.resume_lock
286+
self.resume_lock = None
287+
other.run_lock.release()
257288
self.run_lock.acquire()
289+
self.cancellable = False
290+
291+
def yield_to(self, cancellable, other: Thread):
292+
assert(not self.ready_func)
293+
self.ready_func = lambda: True
294+
self.task.inst.store.waiting.append(self)
295+
self.switch_to(cancellable, other)
296+
297+
def unblock(self, other: Thread):
298+
if not other.ready_func:
299+
other.task.inst.store.waiting.append(other)
300+
other.ready_func = lambda: True
258301

259302

260303
### Lifting and Lowering Context
@@ -432,22 +475,6 @@ def write(self, vs):
432475
assert(all(v == () for v in vs))
433476
self.progress += len(vs)
434477

435-
#### Context-Local Storage
436-
437-
class ContextLocalStorage:
438-
LENGTH = 1
439-
array: list[int]
440-
441-
def __init__(self):
442-
self.array = [0] * ContextLocalStorage.LENGTH
443-
444-
def set(self, i, v):
445-
assert(types_match_values(['i32'], [v]))
446-
self.array[i] = v
447-
448-
def get(self, i):
449-
return self.array[i]
450-
451478
#### Waitable State
452479

453480
class EventCode(IntEnum):
@@ -458,6 +485,7 @@ class EventCode(IntEnum):
458485
FUTURE_READ = 4
459486
FUTURE_WRITE = 5
460487
TASK_CANCELLED = 6
488+
THREAD_RESUMED = 7
461489

462490
EventTuple = tuple[EventCode, int, int]
463491

@@ -530,11 +558,8 @@ class State(Enum):
530558
ft: FuncType
531559
supertask: Optional[Task]
532560
on_resolve: OnResolve
533-
thread: Optional[Thread]
534-
cancellable: bool
535-
waiting_for_callback: bool
561+
threads: list[Thread]
536562
num_borrows: int
537-
context: ContextLocalStorage
538563

539564
def __init__(self, opts, inst, ft, supertask, on_resolve):
540565
self.state = Task.State.INITIAL
@@ -543,11 +568,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
543568
self.ft = ft
544569
self.supertask = supertask
545570
self.on_resolve = on_resolve
546-
self.thread = None
547-
self.cancellable = False
548-
self.waiting_for_callback = False
571+
self.threads = []
549572
self.num_borrows = 0
550-
self.context = ContextLocalStorage()
551573

552574
def trap_if_on_the_stack(self, inst):
553575
c = self.supertask
@@ -559,15 +581,13 @@ def needs_exclusive(self):
559581
return self.opts.sync or self.opts.callback
560582

561583
def enter(self, thread):
562-
assert(thread is self.thread and thread.task is self)
584+
assert(thread in self.threads and thread.task is self)
563585
def has_backpressure():
564586
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
565587
if has_backpressure() or self.inst.pending_tasks > 0:
566588
self.inst.pending_tasks += 1
567-
self.cancellable = True
568-
thread.suspend_until(lambda: not has_backpressure())
589+
thread.suspend_until(lambda: not has_backpressure(), cancellable = True)
569590
self.inst.pending_tasks -= 1
570-
self.cancellable = False
571591
if self.deliver_cancel():
572592
self.cancel()
573593
return False
@@ -586,27 +606,28 @@ def deliver_cancel(self) -> bool:
586606
def request_cancellation(self):
587607
assert(self.state == Task.State.INITIAL)
588608
self.state = Task.State.PENDING_CANCEL
589-
if self.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
590-
self.thread.resume()
609+
if not DETERMINISTIC_PROFILE:
610+
random.shuffle(self.threads)
611+
for thread in self.threads:
612+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
613+
thread.resume()
614+
break
591615

592616
def wait_until(self, ready_func, thread, cancellable, for_callback):
593-
assert(thread is self.thread and thread.task is self)
617+
assert(thread in self.threads and thread.task is self)
594618
if cancellable and self.deliver_cancel():
595619
return True
596-
assert(not self.cancellable)
597-
self.cancellable = cancellable
598620
if for_callback:
599621
assert(self.inst.exclusive)
600622
self.inst.exclusive = False
601-
self.waiting_for_callback = True
623+
thread.waiting_for_callback = True
602624
def ready_and_allowed():
603625
return ready_func() and not (for_callback and self.inst.exclusive)
604-
thread.suspend_until(ready_and_allowed)
626+
thread.suspend_until(ready_and_allowed, cancellable)
605627
if for_callback:
606628
assert(not self.inst.exclusive)
607629
self.inst.exclusive = True
608-
self.waiting_for_callback = False
609-
self.cancellable = False
630+
thread.waiting_for_callback = False
610631
if cancellable and self.deliver_cancel():
611632
return True
612633
return False
@@ -615,13 +636,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
615636
wset.num_waiting += 1
616637
cancelled = self.wait_until(wset.has_pending_event, thread, cancellable, for_callback)
617638
wset.num_waiting -= 1
639+
# TODO: somehow get a THREAD_RESUME event...
618640
if cancelled:
619641
return (EventCode.TASK_CANCELLED, 0, 0)
620642
else:
621643
return wset.get_pending_event()
622644

623645
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
624646
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
647+
# TODO: somehow get a THREAD_RESUME event...
625648
if cancelled:
626649
return (EventCode.TASK_CANCELLED, 0, 0)
627650
else:
@@ -631,6 +654,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
631654
wset.num_waiting += 1
632655
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
633656
wset.num_waiting -= 1
657+
# TODO: somehow get a THREAD_RESUME event...
634658
if cancelled:
635659
return (EventCode.TASK_CANCELLED, 0, 0)
636660
elif wset.has_pending_event():
@@ -652,20 +676,21 @@ def cancel(self):
652676
self.state = Task.State.RESOLVED
653677

654678
def exit(self):
655-
assert(self.thread is not None)
679+
assert(len(self.threads) > 0)
656680
if self.needs_exclusive():
657681
assert(self.inst.exclusive)
658682
self.inst.exclusive = False
659683

660684
def thread_start(self, thread):
661-
assert(self.thread is None and thread.task is self)
662-
self.thread = thread
685+
assert(thread not in self.threads and thread.task is self)
686+
self.threads.append(thread)
663687

664688
def thread_stop(self, thread):
665-
assert(thread is self.thread and thread.task is self)
666-
self.thread = None
667-
trap_if(self.state != Task.State.RESOLVED)
668-
assert(self.num_borrows == 0)
689+
assert(thread in self.threads and thread.task is self)
690+
self.threads.remove(thread)
691+
if len(self.threads) == 0:
692+
trap_if(self.state != Task.State.RESOLVED)
693+
assert(self.num_borrows == 0)
669694

670695
#### Subtask State
671696

@@ -1901,6 +1926,9 @@ def thread_func(thread):
19011926
if not task.enter(thread):
19021927
return
19031928

1929+
assert(thread.index is None)
1930+
thread.index = thread.task.inst.table.add(thread)
1931+
19041932
cx = LiftLowerContext(opts, inst, task)
19051933
args = on_start()
19061934
flat_args = lower_flat_values(cx, MAX_FLAT_PARAMS, args, ft.param_types())
@@ -2081,25 +2109,82 @@ def canon_resource_rep(rt, thread, i):
20812109
trap_if(h.rt is not rt)
20822110
return [h.rep]
20832111

2112+
### 🧵 `canon thread.index`
2113+
2114+
def canon_thread_index(shared, thread):
2115+
assert(not shared)
2116+
assert(thread.index is not None)
2117+
return [thread.index]
2118+
2119+
### 🧵 `canon thread.new`
2120+
2121+
def canon_thread_new(ft, ftbl, thread, i, c):
2122+
task = thread.task
2123+
trap_if(not task.inst.may_leave)
2124+
f = task.inst.ftbl.get(i)
2125+
trap_if(f.type != ft)
2126+
thread_func = partial(f, c)
2127+
i = task.inst.table.add(Thread(task, thread_func))
2128+
return [i]
2129+
2130+
### 🧵 `canon thread.switch-to`
2131+
2132+
def canon_thread_switch_to(thread, cancellable, i):
2133+
trap_if(not thread.task.inst.may_leave)
2134+
other = thread.task.inst.table.get(i)
2135+
trap_if(not isinstance(other, Thread))
2136+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2137+
cancelled = thread.switch_to(cancellable, other)
2138+
return [ 1 if cancelled else 0 ]
2139+
2140+
### 🧵 `canon thread.yield-to`
2141+
2142+
def canon_thread_yield_to(thread, cancellable, i):
2143+
trap_if(not thread.task.inst.may_leave)
2144+
other = thread.task.inst.table.get(i)
2145+
trap_if(not isinstance(other, Thread))
2146+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2147+
other.yield_to(cancellable, other)
2148+
return []
2149+
2150+
### 🧵 `canon thread.unblock`
2151+
2152+
def canon_thread_unblock(thread, i):
2153+
trap_if(not thread.task.inst.may_leave)
2154+
other = thread.task.inst.table.get(i)
2155+
trap_if(not isinstance(other, Thread))
2156+
trap_if(not other.cancellable) # TODO: what about waiting_for_callback and exclusive
2157+
thread.unblock()
2158+
return []
2159+
2160+
### 🧵 `canon thread.block`
2161+
2162+
def canon_thread_block(thread, cancellable, i):
2163+
trap_if(not thread.task.inst.may_leave)
2164+
other = thread.task.inst.table.get(i)
2165+
trap_if(not isinstance(other, Thread))
2166+
cancelled = thread.block(cancellable)
2167+
return [ 1 if cancelled else 0 ]
2168+
20842169
### 🔀 `canon context.get`
20852170

20862171
def canon_context_get(t, i, thread):
20872172
assert(t == 'i32')
2088-
assert(i < ContextLocalStorage.LENGTH)
2089-
return [thread.task.context.get(i)]
2173+
assert(i < Thread.CONTEXT_LENGTH)
2174+
return [thread.context[i]]
20902175

20912176
### 🔀 `canon context.set`
20922177

20932178
def canon_context_set(t, i, thread, v):
20942179
assert(t == 'i32')
2095-
assert(i < ContextLocalStorage.LENGTH)
2096-
thread.task.context.set(i, v)
2180+
assert(i < Thread.CONTEXT_LENGTH)
2181+
thread.context[i] = v
20972182
return []
20982183

20992184
### 🔀 `canon backpressure.set`
21002185

21012186
def canon_backpressure_set(thread, flat_args):
2102-
trap_if(thread.task.opts.sync)
2187+
# TODO: remove trap_if(thread.task.opts.sync)
21032188
assert(len(flat_args) == 1)
21042189
thread.task.inst.backpressure = bool(flat_args[0])
21052190
return []

0 commit comments

Comments
 (0)