Skip to content

Commit 86352ee

Browse files
committed
Add cooperative threads
1 parent 3f51807 commit 86352ee

3 files changed

Lines changed: 160 additions & 88 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,19 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: int
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: threading.Lock
219221
stack: threading.Thread
220222

223+
CONTEXT_LENGTH = 2
224+
221225
def __init__(self, task, thread_func):
222226
self.task = task
227+
self.index = task.inst.table.add(self)
228+
self.context = [0] * Thread.CONTEXT_LENGTH
223229
self.ready_func = None
224230
self.run_lock = threading.Lock()
225231
self.run_lock.acquire()
@@ -230,6 +236,7 @@ def thread_stack_base():
230236
assert(self.resume_lock.locked())
231237
thread_func(self)
232238
self.task.thread_stop(self)
239+
self.task.inst.table.remove(self.index)
233240
self.resume_lock.release()
234241
self.stack = threading.Thread(target = thread_stack_base)
235242
self.stack.start()
@@ -255,6 +262,22 @@ def suspend_until(self, ready_func):
255262
self.resume_lock.release()
256263
self.run_lock.acquire()
257264

265+
async def switch_to(self, cancellable, other: Thread):
266+
# deterministically switch to other, leave this blocked
267+
TODO
268+
269+
def yield_to(self, cancellable, other: Thread):
270+
# deterministically switch to other, but leave this thread unblocked
271+
TODO
272+
273+
def block(self, cancellable):
274+
# perform just the first half of switch
275+
TODO
276+
277+
def unblock(self, other: Thread):
278+
# unblock other, but deterministically keep running here
279+
TODO
280+
258281

259282
### Lifting and Lowering Context
260283

@@ -431,22 +454,6 @@ def write(self, vs):
431454
assert(all(v == () for v in vs))
432455
self.progress += len(vs)
433456

434-
#### Context-Local Storage
435-
436-
class ContextLocalStorage:
437-
LENGTH = 1
438-
array: list[int]
439-
440-
def __init__(self):
441-
self.array = [0] * ContextLocalStorage.LENGTH
442-
443-
def set(self, i, v):
444-
assert(types_match_values(['i32'], [v]))
445-
self.array[i] = v
446-
447-
def get(self, i):
448-
return self.array[i]
449-
450457
#### Waitable State
451458

452459
class EventCode(IntEnum):
@@ -457,6 +464,7 @@ class EventCode(IntEnum):
457464
FUTURE_READ = 4
458465
FUTURE_WRITE = 5
459466
TASK_CANCELLED = 6
467+
THREAD_RESUMED = 7
460468

461469
EventTuple = tuple[EventCode, int, int]
462470

@@ -529,11 +537,10 @@ class State(Enum):
529537
ft: FuncType
530538
supertask: Optional[Task]
531539
on_resolve: OnResolve
532-
thread: Optional[Thread]
540+
threads: list[Thread]
533541
cancellable: bool
534542
waiting_for_callback: bool
535543
num_borrows: int
536-
context: ContextLocalStorage
537544

538545
def __init__(self, opts, inst, ft, supertask, on_resolve):
539546
self.state = Task.State.INITIAL
@@ -542,11 +549,10 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
542549
self.ft = ft
543550
self.supertask = supertask
544551
self.on_resolve = on_resolve
545-
self.thread = None
552+
self.threads = []
546553
self.cancellable = False
547554
self.waiting_for_callback = False
548555
self.num_borrows = 0
549-
self.context = ContextLocalStorage()
550556

551557
def trap_if_on_the_stack(self, inst):
552558
c = self.supertask
@@ -558,7 +564,7 @@ def needs_exclusive(self):
558564
return self.opts.sync or self.opts.callback
559565

560566
def enter(self, thread):
561-
assert(thread is self.thread and thread.task is self)
567+
assert(thread in self.threads and thread.task is self)
562568
def has_backpressure():
563569
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
564570
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -585,11 +591,13 @@ def deliver_cancel(self) -> bool:
585591
def request_cancellation(self):
586592
assert(self.state == Task.State.INITIAL)
587593
self.state = Task.State.PENDING_CANCEL
594+
# TODO: move cancellability to the Thread and then search
595+
# for a cancellable one here...
588596
if self.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
589-
self.thread.resume()
597+
self.threads[0].resume()
590598

591599
def wait_until(self, ready_func, thread, cancellable, for_callback):
592-
assert(thread is self.thread and thread.task is self)
600+
assert(thread in self.threads and thread.task is self)
593601
if cancellable and self.deliver_cancel():
594602
return True
595603
assert(not self.cancellable)
@@ -612,13 +620,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
612620
wset.num_waiting += 1
613621
cancelled = self.wait_until(wset.has_pending_event, thread, cancellable, for_callback)
614622
wset.num_waiting -= 1
623+
# TODO: somehow get a THREAD_RESUME event...
615624
if cancelled:
616625
return (EventCode.TASK_CANCELLED, 0, 0)
617626
else:
618627
return wset.get_pending_event()
619628

620629
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
621630
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
631+
# TODO: somehow get a THREAD_RESUME event...
622632
if cancelled:
623633
return (EventCode.TASK_CANCELLED, 0, 0)
624634
else:
@@ -628,6 +638,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
628638
wset.num_waiting += 1
629639
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
630640
wset.num_waiting -= 1
641+
# TODO: somehow get a THREAD_RESUME event...
631642
if cancelled:
632643
return (EventCode.TASK_CANCELLED, 0, 0)
633644
elif wset.has_pending_event():
@@ -649,20 +660,21 @@ def cancel(self):
649660
self.state = Task.State.RESOLVED
650661

651662
def exit(self):
652-
assert(self.thread is not None)
663+
assert(len(self.threads) > 0)
653664
if self.needs_exclusive():
654665
assert(self.inst.exclusive)
655666
self.inst.exclusive = False
656667

657668
def thread_start(self, thread):
658-
assert(self.thread is None and thread.task is self)
659-
self.thread = thread
669+
assert(thread not in self.threads and thread.task is self)
670+
self.threads.append(thread)
660671

661672
def thread_stop(self, thread):
662-
assert(thread is self.thread and thread.task is self)
663-
self.thread = None
664-
trap_if(self.state != Task.State.RESOLVED)
665-
assert(self.num_borrows == 0)
673+
assert(thread in self.threads and thread.task is self)
674+
self.threads.remove(thread)
675+
if len(self.threads) == 0:
676+
trap_if(self.state != Task.State.RESOLVED)
677+
assert(self.num_borrows == 0)
666678

667679
#### Subtask State
668680

@@ -2078,25 +2090,77 @@ def canon_resource_rep(rt, thread, i):
20782090
trap_if(h.rt is not rt)
20792091
return [h.rep]
20802092

2093+
### 🧵 `canon thread.index`
2094+
2095+
def canon_thread_index(shared, thread):
2096+
assert(not shared)
2097+
return [thread.index]
2098+
2099+
### 🧵 `canon thread.new`
2100+
2101+
def canon_thread_new(ft, ftbl, thread, i, c):
2102+
task = thread.task
2103+
trap_if(not task.inst.may_leave)
2104+
f = task.inst.ftbl.get(i)
2105+
trap_if(f.type != ft)
2106+
new_thread = Thread(task, f(c))
2107+
return [new_thread.index]
2108+
2109+
### 🧵 `canon thread.switch-to`
2110+
2111+
def canon_thread_switch_to(thread, cancellable, i):
2112+
trap_if(not thread.task.inst.may_leave)
2113+
other = thread.task.inst.table.get(i)
2114+
trap_if(not isinstance(other, Thread))
2115+
cancelled = thread.switch_to(cancellable, other)
2116+
return [ 1 if cancelled else 0 ]
2117+
2118+
### 🧵 `canon thread.yield-to`
2119+
2120+
def canon_thread_yield_to(thread, cancellable, i):
2121+
trap_if(not thread.task.inst.may_leave)
2122+
other = thread.task.inst.table.get(i)
2123+
trap_if(not isinstance(other, Thread))
2124+
other.yield_to(cancellable, other)
2125+
return []
2126+
2127+
### 🧵 `canon thread.block`
2128+
2129+
def canon_thread_block(thread, cancellable, i):
2130+
trap_if(not thread.task.inst.may_leave)
2131+
other = thread.task.inst.table.get(i)
2132+
trap_if(not isinstance(other, Thread))
2133+
cancelled = thread.block(cancellable)
2134+
return [ 1 if cancelled else 0 ]
2135+
2136+
### 🧵 `canon thread.unblock`
2137+
2138+
def canon_thread_unblock(thread, i):
2139+
trap_if(not thread.task.inst.may_leave)
2140+
other = thread.task.inst.table.get(i)
2141+
trap_if(not isinstance(other, Thread))
2142+
thread.unblock()
2143+
return []
2144+
20812145
### 🔀 `canon context.get`
20822146

20832147
def canon_context_get(t, i, thread):
20842148
assert(t == 'i32')
2085-
assert(i < ContextLocalStorage.LENGTH)
2086-
return [thread.task.context.get(i)]
2149+
assert(i < Thread.CONTEXT_LENGTH)
2150+
return [thread.context[i]]
20872151

20882152
### 🔀 `canon context.set`
20892153

20902154
def canon_context_set(t, i, thread, v):
20912155
assert(t == 'i32')
2092-
assert(i < ContextLocalStorage.LENGTH)
2093-
thread.task.context.set(i, v)
2156+
assert(i < Thread.CONTEXT_LENGTH)
2157+
thread.context[i] = v
20942158
return []
20952159

20962160
### 🔀 `canon backpressure.set`
20972161

20982162
def canon_backpressure_set(thread, flat_args):
2099-
trap_if(thread.task.opts.sync)
2163+
# TODO: remove trap_if(thread.task.opts.sync)
21002164
assert(len(flat_args) == 1)
21012165
thread.task.inst.backpressure = bool(flat_args[0])
21022166
return []

0 commit comments

Comments
 (0)