Skip to content

Commit c1e03c0

Browse files
committed
Add cooperative threads
1 parent 446f043 commit c1e03c0

3 files changed

Lines changed: 160 additions & 88 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,19 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: int
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: threading.Lock
219221
stack: threading.Thread
220222

223+
CONTEXT_LENGTH = 2
224+
221225
def __init__(self, task, thread_func):
222226
self.task = task
227+
self.index = task.inst.table.add(self)
228+
self.context = [0] * Thread.CONTEXT_LENGTH
223229
self.ready_func = None
224230
self.run_lock = threading.Lock()
225231
self.run_lock.acquire()
@@ -229,6 +235,7 @@ def thread_stack_base():
229235
self.run_lock.acquire()
230236
thread_func(self)
231237
self.task.thread_stop(self)
238+
self.task.inst.table.remove(self.index)
232239
self.resume_lock.release()
233240
self.stack = threading.Thread(target = thread_stack_base)
234241
self.stack.start()
@@ -252,6 +259,22 @@ def suspend_until(self, ready_func):
252259
self.resume_lock.release()
253260
self.run_lock.acquire()
254261

262+
async def switch_to(self, cancellable, other: Thread):
263+
# deterministically switch to other, leave this blocked
264+
TODO
265+
266+
def yield_to(self, cancellable, other: Thread):
267+
# deterministically switch to other, but leave this thread unblocked
268+
TODO
269+
270+
def block(self, cancellable):
271+
# perform just the first half of switch
272+
TODO
273+
274+
def unblock(self, other: Thread):
275+
# unblock other, but deterministically keep running here
276+
TODO
277+
255278

256279
### Lifting and Lowering Context
257280

@@ -428,22 +451,6 @@ def write(self, vs):
428451
assert(all(v == () for v in vs))
429452
self.progress += len(vs)
430453

431-
#### Context-Local Storage
432-
433-
class ContextLocalStorage:
434-
LENGTH = 1
435-
array: list[int]
436-
437-
def __init__(self):
438-
self.array = [0] * ContextLocalStorage.LENGTH
439-
440-
def set(self, i, v):
441-
assert(types_match_values(['i32'], [v]))
442-
self.array[i] = v
443-
444-
def get(self, i):
445-
return self.array[i]
446-
447454
#### Waitable State
448455

449456
class EventCode(IntEnum):
@@ -454,6 +461,7 @@ class EventCode(IntEnum):
454461
FUTURE_READ = 4
455462
FUTURE_WRITE = 5
456463
TASK_CANCELLED = 6
464+
THREAD_RESUMED = 7
457465

458466
EventTuple = tuple[EventCode, int, int]
459467

@@ -526,11 +534,10 @@ class State(Enum):
526534
ft: FuncType
527535
supertask: Optional[Task]
528536
on_resolve: OnResolve
529-
thread: Optional[Thread]
537+
threads: list[Thread]
530538
cancellable: bool
531539
waiting_for_callback: bool
532540
num_borrows: int
533-
context: ContextLocalStorage
534541

535542
def __init__(self, opts, inst, ft, supertask, on_resolve):
536543
self.state = Task.State.INITIAL
@@ -539,11 +546,10 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
539546
self.ft = ft
540547
self.supertask = supertask
541548
self.on_resolve = on_resolve
542-
self.thread = None
549+
self.threads = []
543550
self.cancellable = False
544551
self.waiting_for_callback = False
545552
self.num_borrows = 0
546-
self.context = ContextLocalStorage()
547553

548554
def trap_if_on_the_stack(self, inst):
549555
c = self.supertask
@@ -555,7 +561,7 @@ def needs_exclusive(self):
555561
return self.opts.sync or self.opts.callback
556562

557563
def enter(self, thread):
558-
assert(thread is self.thread and thread.task is self)
564+
assert(thread in self.threads and thread.task is self)
559565
def has_backpressure():
560566
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
561567
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -582,11 +588,13 @@ def deliver_cancel(self) -> bool:
582588
def request_cancellation(self):
583589
assert(self.state == Task.State.INITIAL)
584590
self.state = Task.State.PENDING_CANCEL
591+
# TODO: move cancellability to the Thread and then search
592+
# for a cancellable one here...
585593
if self.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
586-
self.thread.resume()
594+
self.threads[0].resume()
587595

588596
def wait_until(self, ready_func, thread, cancellable, for_callback):
589-
assert(thread is self.thread and thread.task is self)
597+
assert(thread in self.threads and thread.task is self)
590598
if cancellable and self.deliver_cancel():
591599
return True
592600
assert(not self.cancellable)
@@ -609,13 +617,15 @@ def wait_for_event(self, thread, wset, cancellable, for_callback) -> EventTuple:
609617
wset.num_waiting += 1
610618
cancelled = self.wait_until(wset.has_pending_event, thread, cancellable, for_callback)
611619
wset.num_waiting -= 1
620+
# TODO: somehow get a THREAD_RESUME event...
612621
if cancelled:
613622
return (EventCode.TASK_CANCELLED, 0, 0)
614623
else:
615624
return wset.get_pending_event()
616625

617626
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
618627
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
628+
# TODO: somehow get a THREAD_RESUME event...
619629
if cancelled:
620630
return (EventCode.TASK_CANCELLED, 0, 0)
621631
else:
@@ -625,6 +635,7 @@ def poll_for_event(self, thread, wset, cancellable, for_callback) -> Optional[Ev
625635
wset.num_waiting += 1
626636
cancelled = self.wait_until(lambda: True, thread, cancellable, for_callback)
627637
wset.num_waiting -= 1
638+
# TODO: somehow get a THREAD_RESUME event...
628639
if cancelled:
629640
return (EventCode.TASK_CANCELLED, 0, 0)
630641
elif wset.has_pending_event():
@@ -646,20 +657,21 @@ def cancel(self):
646657
self.state = Task.State.RESOLVED
647658

648659
def exit(self):
649-
assert(self.thread is not None)
660+
assert(len(self.threads) > 0)
650661
if self.needs_exclusive():
651662
assert(self.inst.exclusive)
652663
self.inst.exclusive = False
653664

654665
def thread_start(self, thread):
655-
assert(self.thread is None and thread.task is self)
656-
self.thread = thread
666+
assert(thread not in self.threads and thread.task is self)
667+
self.threads.append(thread)
657668

658669
def thread_stop(self, thread):
659-
assert(thread is self.thread and thread.task is self)
660-
self.thread = None
661-
trap_if(self.state != Task.State.RESOLVED)
662-
assert(self.num_borrows == 0)
670+
assert(thread in self.threads and thread.task is self)
671+
self.threads.remove(thread)
672+
if len(self.threads) == 0:
673+
trap_if(self.state != Task.State.RESOLVED)
674+
assert(self.num_borrows == 0)
663675

664676
#### Subtask State
665677

@@ -2075,25 +2087,77 @@ def canon_resource_rep(rt, thread, i):
20752087
trap_if(h.rt is not rt)
20762088
return [h.rep]
20772089

2090+
### 🧵 `canon thread.index`
2091+
2092+
def canon_thread_index(shared, thread):
2093+
assert(not shared)
2094+
return [thread.index]
2095+
2096+
### 🧵 `canon thread.new`
2097+
2098+
def canon_thread_new(ft, ftbl, thread, i, c):
2099+
task = thread.task
2100+
trap_if(not task.inst.may_leave)
2101+
f = task.inst.ftbl.get(i)
2102+
trap_if(f.type != ft)
2103+
new_thread = Thread(task, f(c))
2104+
return [new_thread.index]
2105+
2106+
### 🧵 `canon thread.switch-to`
2107+
2108+
def canon_thread_switch_to(thread, cancellable, i):
2109+
trap_if(not thread.task.inst.may_leave)
2110+
other = thread.task.inst.table.get(i)
2111+
trap_if(not isinstance(other, Thread))
2112+
cancelled = thread.switch_to(cancellable, other)
2113+
return [ 1 if cancelled else 0 ]
2114+
2115+
### 🧵 `canon thread.yield-to`
2116+
2117+
def canon_thread_yield_to(thread, cancellable, i):
2118+
trap_if(not thread.task.inst.may_leave)
2119+
other = thread.task.inst.table.get(i)
2120+
trap_if(not isinstance(other, Thread))
2121+
other.yield_to(cancellable, other)
2122+
return []
2123+
2124+
### 🧵 `canon thread.block`
2125+
2126+
def canon_thread_block(thread, cancellable, i):
2127+
trap_if(not thread.task.inst.may_leave)
2128+
other = thread.task.inst.table.get(i)
2129+
trap_if(not isinstance(other, Thread))
2130+
cancelled = thread.block(cancellable)
2131+
return [ 1 if cancelled else 0 ]
2132+
2133+
### 🧵 `canon thread.unblock`
2134+
2135+
def canon_thread_unblock(thread, i):
2136+
trap_if(not thread.task.inst.may_leave)
2137+
other = thread.task.inst.table.get(i)
2138+
trap_if(not isinstance(other, Thread))
2139+
thread.unblock()
2140+
return []
2141+
20782142
### 🔀 `canon context.get`
20792143

20802144
def canon_context_get(t, i, thread):
20812145
assert(t == 'i32')
2082-
assert(i < ContextLocalStorage.LENGTH)
2083-
return [thread.task.context.get(i)]
2146+
assert(i < Thread.CONTEXT_LENGTH)
2147+
return [thread.context[i]]
20842148

20852149
### 🔀 `canon context.set`
20862150

20872151
def canon_context_set(t, i, thread, v):
20882152
assert(t == 'i32')
2089-
assert(i < ContextLocalStorage.LENGTH)
2090-
thread.task.context.set(i, v)
2153+
assert(i < Thread.CONTEXT_LENGTH)
2154+
thread.context[i] = v
20912155
return []
20922156

20932157
### 🔀 `canon backpressure.set`
20942158

20952159
def canon_backpressure_set(thread, flat_args):
2096-
trap_if(thread.task.opts.sync)
2160+
# TODO: remove trap_if(thread.task.opts.sync)
20972161
assert(len(flat_args) == 1)
20982162
thread.task.inst.backpressure = bool(flat_args[0])
20992163
return []

0 commit comments

Comments
 (0)