Skip to content

Commit 1f9efb5

Browse files
committed
Add cooperative threads
1 parent f243c74 commit 1f9efb5

3 files changed

Lines changed: 212 additions & 102 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 151 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -213,39 +213,58 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: Optional[int]
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: Optional[threading.Lock]
219221
stack: threading.Thread
220222
cancellable: bool
221223
cancelled: bool
224+
waiting_for_callback: bool
225+
226+
CONTEXT_LENGTH = 2
222227

223228
def __init__(self, task, thread_func):
224229
self.task = task
230+
self.index = None
231+
self.context = [0] * Thread.CONTEXT_LENGTH
225232
self.ready_func = None
226233
self.run_lock = threading.Lock()
227234
self.run_lock.acquire()
228235
self.resume_lock = None
229236
self.cancellable = False
230237
self.cancelled = False
238+
self.waiting_for_callback = False
231239
def thread_stack_base():
232240
self.run_lock.acquire()
233241
thread_func(self)
234242
self.task.thread_stop(self)
243+
if self.index is not None:
244+
self.task.inst.table.remove(self.index)
235245
self.resume_lock.release()
236246
self.stack = threading.Thread(target = thread_stack_base)
237247
self.stack.start()
238248
self.task.thread_start(self)
249+
assert(self.suspended())
250+
251+
def suspended(self):
252+
return self.ready_func is None and self.resume_lock is None
253+
254+
def pending(self):
255+
return self.ready_func is not None and self.resume_lock is None
239256

240257
def ready(self):
258+
assert(self.pending())
241259
return self.ready_func()
242260

243261
def resume(self, cancel = False):
262+
assert(self.suspended() or self.pending())
244263
if cancel:
245264
assert(self.cancellable and not self.cancelled)
246265
self.cancelled = True
247-
if self.ready_func:
248-
assert(cancel or self.ready_func())
266+
if self.pending():
267+
assert(cancel or self.ready())
249268
self.ready_func = None
250269
self.task.inst.store.pending.remove(self)
251270
assert(not self.resume_lock)
@@ -255,22 +274,52 @@ def resume(self, cancel = False):
255274
self.resume_lock.acquire()
256275
self.resume_lock = None
257276

277+
def suspend(self, cancellable) -> bool:
278+
assert(not self.cancellable and not self.cancelled)
279+
self.cancellable = cancellable
280+
self.resume_lock.release()
281+
self.run_lock.acquire()
282+
assert(self.cancellable or not self.cancelled)
283+
self.cancellable = False
284+
completed = not self.cancelled
285+
self.cancelled = False
286+
return completed
287+
258288
def suspend_until(self, ready_func, cancellable = False) -> bool:
259-
assert(not self.ready_func)
289+
assert(not self.pending())
260290
if not DETERMINISTIC_PROFILE and ready_func():
261291
return True
262292
self.ready_func = ready_func
263293
self.task.inst.store.pending.append(self)
264-
assert(not self.cancellable and not self.cancelled)
294+
return self.suspend(cancellable)
295+
296+
def switch_to(self, cancellable, other: Thread) -> bool:
297+
assert(other.suspended())
298+
assert(not self.cancellable)
265299
self.cancellable = cancellable
266-
self.resume_lock.release()
300+
assert(self.resume_lock and not other.resume_lock)
301+
other.resume_lock = self.resume_lock
302+
self.resume_lock = None
303+
assert(self.suspended())
304+
other.run_lock.release()
267305
self.run_lock.acquire()
268-
assert(self.cancellable or not self.cancelled)
269306
self.cancellable = False
270307
completed = not self.cancelled
271308
self.cancelled = False
272309
return completed
273310

311+
def yield_to(self, cancellable, other: Thread) -> bool:
312+
assert(other.suspended())
313+
assert(not self.ready_func)
314+
self.ready_func = lambda: True
315+
self.task.inst.store.pending.append(self)
316+
return self.switch_to(cancellable, other)
317+
318+
def resume_later(self, other: Thread):
319+
assert(other.suspended())
320+
other.ready_func = lambda: True
321+
other.task.inst.store.pending.append(other)
322+
274323

275324
### Lifting and Lowering Context
276325

@@ -447,22 +496,6 @@ def write(self, vs):
447496
assert(all(v == () for v in vs))
448497
self.progress += len(vs)
449498

450-
#### Context-Local Storage
451-
452-
class ContextLocalStorage:
453-
LENGTH = 1
454-
array: list[int]
455-
456-
def __init__(self):
457-
self.array = [0] * ContextLocalStorage.LENGTH
458-
459-
def set(self, i, v):
460-
assert(types_match_values(['i32'], [v]))
461-
self.array[i] = v
462-
463-
def get(self, i):
464-
return self.array[i]
465-
466499
#### Waitable State
467500

468501
class EventCode(IntEnum):
@@ -545,10 +578,8 @@ class State(Enum):
545578
ft: FuncType
546579
supertask: Optional[Task]
547580
on_resolve: OnResolve
548-
thread: Optional[Thread]
581+
threads: list[Thread]
549582
num_borrows: int
550-
waiting_for_callback: bool
551-
context: ContextLocalStorage
552583

553584
def __init__(self, opts, inst, ft, supertask, on_resolve):
554585
self.state = Task.State.INITIAL
@@ -557,10 +588,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
557588
self.ft = ft
558589
self.supertask = supertask
559590
self.on_resolve = on_resolve
560-
self.thread = None
591+
self.threads = []
561592
self.num_borrows = 0
562-
self.waiting_for_callback = False
563-
self.context = ContextLocalStorage()
564593

565594
def trap_if_on_the_stack(self, inst):
566595
c = self.supertask
@@ -572,7 +601,7 @@ def needs_exclusive(self):
572601
return self.opts.sync or self.opts.callback
573602

574603
def enter(self, thread):
575-
assert(thread is self.thread and thread.task is self)
604+
assert(thread in self.threads and thread.task is self)
576605
def has_backpressure():
577606
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
578607
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -589,28 +618,31 @@ def has_backpressure():
589618

590619
def request_cancellation(self):
591620
assert(self.state == Task.State.INITIAL)
592-
if self.thread.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
593-
self.state = Task.State.CANCEL_DELIVERED
594-
self.thread.resume(cancel = True)
595-
else:
596-
self.state = Task.State.PENDING_CANCEL
621+
if not DETERMINISTIC_PROFILE:
622+
random.shuffle(self.threads)
623+
for thread in self.threads:
624+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
625+
self.state = Task.State.CANCEL_DELIVERED
626+
thread.resume(cancel = True)
627+
return
628+
self.state = Task.State.PENDING_CANCEL
597629

598630
def wait_until(self, ready_func, thread, cancellable, for_callback) -> bool:
599-
assert(thread is self.thread and thread.task is self)
631+
assert(thread in self.threads and thread.task is self)
600632
if cancellable and self.state == Task.State.PENDING_CANCEL:
601633
self.state = Task.State.CANCEL_DELIVERED
602634
return False
603635
if for_callback:
604636
assert(self.inst.exclusive)
605637
self.inst.exclusive = False
606-
self.waiting_for_callback = True
638+
thread.waiting_for_callback = True
607639
def ready_and_uncontended():
608640
return ready_func() and not (for_callback and self.inst.exclusive)
609641
completed = thread.suspend_until(ready_and_uncontended, cancellable)
610642
if for_callback:
611643
assert(not self.inst.exclusive)
612644
self.inst.exclusive = True
613-
self.waiting_for_callback = False
645+
thread.waiting_for_callback = False
614646
return completed
615647

616648
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
@@ -653,20 +685,21 @@ def cancel(self):
653685
self.state = Task.State.RESOLVED
654686

655687
def exit(self):
656-
assert(self.thread is not None)
688+
assert(len(self.threads) > 0)
657689
if self.needs_exclusive():
658690
assert(self.inst.exclusive)
659691
self.inst.exclusive = False
660692

661693
def thread_start(self, thread):
662-
assert(self.thread is None and thread.task is self)
663-
self.thread = thread
694+
assert(thread not in self.threads and thread.task is self)
695+
self.threads.append(thread)
664696

665697
def thread_stop(self, thread):
666-
assert(thread is self.thread and thread.task is self)
667-
self.thread = None
668-
trap_if(self.state != Task.State.RESOLVED)
669-
assert(self.num_borrows == 0)
698+
assert(thread in self.threads and thread.task is self)
699+
self.threads.remove(thread)
700+
if len(self.threads) == 0:
701+
trap_if(self.state != Task.State.RESOLVED)
702+
assert(self.num_borrows == 0)
670703

671704
#### Subtask State
672705

@@ -1902,6 +1935,9 @@ def thread_func(thread):
19021935
if not task.enter(thread):
19031936
return
19041937

1938+
assert(thread.index is None)
1939+
thread.index = thread.task.inst.table.add(thread)
1940+
19051941
cx = LiftLowerContext(opts, inst, task)
19061942
args = on_start()
19071943
flat_args = lower_flat_values(cx, MAX_FLAT_PARAMS, args, ft.param_types())
@@ -2082,25 +2118,91 @@ def canon_resource_rep(rt, thread, i):
20822118
trap_if(h.rt is not rt)
20832119
return [h.rep]
20842120

2121+
### 🧵 `canon thread.index`
2122+
2123+
def canon_thread_index(shared, thread):
2124+
assert(not shared)
2125+
assert(thread.index is not None)
2126+
return [thread.index]
2127+
2128+
### 🧵 `canon thread.new`
2129+
2130+
def canon_thread_new(ft, ftbl, thread, i, c):
2131+
task = thread.task
2132+
trap_if(not task.inst.may_leave)
2133+
f = task.inst.ftbl.get(i)
2134+
trap_if(f.type != ft)
2135+
thread_func = partial(f, c)
2136+
new_thread = Thread(task, thread_func)
2137+
assert(new_thread.suspended())
2138+
new_thread.index = task.inst.table.add(thread)
2139+
return [new_thread.index]
2140+
2141+
### 🧵 `canon thread.resume-later`
2142+
2143+
def canon_thread_resume_later(thread, i):
2144+
trap_if(not thread.task.inst.may_leave)
2145+
other_thread = thread.task.inst.table.get(i)
2146+
trap_if(not isinstance(other_thread, Thread))
2147+
trap_if(not other_thread.suspended())
2148+
thread.resume_later(other_thread)
2149+
return []
2150+
2151+
### 🧵 `canon thread.switch-to`
2152+
2153+
def canon_thread_switch_to(thread, cancellable, i):
2154+
trap_if(not thread.task.inst.may_leave)
2155+
other_thread = thread.task.inst.table.get(i)
2156+
trap_if(not isinstance(other_thread, Thread))
2157+
trap_if(not other_thread.suspended())
2158+
if not thread.switch_to(cancellable, other_thread):
2159+
assert(cancellable)
2160+
return [0]
2161+
else:
2162+
return [1]
2163+
2164+
### 🧵 `canon thread.yield-to`
2165+
2166+
def canon_thread_yield_to(thread, cancellable, i):
2167+
trap_if(not thread.task.inst.may_leave)
2168+
other_thread = thread.task.inst.table.get(i)
2169+
trap_if(not isinstance(other_thread, Thread))
2170+
trap_if(not other_thread.suspended())
2171+
if not other_thread.yield_to(cancellable, other_thread):
2172+
assert(cancellable)
2173+
return [0]
2174+
else:
2175+
return [1]
2176+
2177+
### 🧵 `canon thread.suspend`
2178+
2179+
def canon_thread_suspend(thread, cancellable):
2180+
trap_if(not thread.task.inst.may_leave)
2181+
if not thread.suspend(cancellable):
2182+
assert(cancellable)
2183+
return [0]
2184+
else:
2185+
return [1]
2186+
20852187
### 🔀 `canon context.get`
20862188

20872189
def canon_context_get(t, i, thread):
20882190
assert(t == 'i32')
2089-
assert(i < ContextLocalStorage.LENGTH)
2090-
return [thread.task.context.get(i)]
2191+
assert(i < Thread.CONTEXT_LENGTH)
2192+
return [thread.context[i]]
20912193

20922194
### 🔀 `canon context.set`
20932195

20942196
def canon_context_set(t, i, thread, v):
20952197
assert(t == 'i32')
2096-
assert(i < ContextLocalStorage.LENGTH)
2097-
thread.task.context.set(i, v)
2198+
assert(i < Thread.CONTEXT_LENGTH)
2199+
thread.context[i] = v
20982200
return []
20992201

21002202
### 🔀 `canon backpressure.set`
21012203

21022204
def canon_backpressure_set(thread, flat_args):
2103-
trap_if(thread.task.opts.sync)
2205+
# TODO: remove trap_if(thread.task.opts.sync)
21042206
assert(len(flat_args) == 1)
21052207
thread.task.inst.backpressure = bool(flat_args[0])
21062208
return []

0 commit comments

Comments
 (0)