Skip to content

Commit 2742c71

Browse files
committed
Add cooperative threads
1 parent f243c74 commit 2742c71

3 files changed

Lines changed: 210 additions & 102 deletions

File tree

design/mvp/Async.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ as a parameter.
240240

241241
### Context-Local Storage
242242

243+
TODO: update (also there are 2 now)
244+
243245
Each task contains a distinct mutable **context-local storage** array. The
244246
current task's context-local storage can be read and written from core wasm
245247
code by calling the [`context.get`] and [`context.set`] built-ins.

design/mvp/canonical-abi/definitions.py

Lines changed: 149 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -213,39 +213,57 @@ def tick(self):
213213

214214
class Thread:
215215
task: Task
216+
index: Optional[int]
217+
context: list[int]
216218
ready_func: Optional[Callable[[], bool]]
217219
run_lock: threading.Lock
218220
resume_lock: Optional[threading.Lock]
219221
stack: threading.Thread
220222
cancellable: bool
221223
cancelled: bool
224+
waiting_for_callback: bool
225+
226+
CONTEXT_LENGTH = 2
222227

223228
def __init__(self, task, thread_func):
224229
self.task = task
230+
self.index = None
231+
self.context = [0] * Thread.CONTEXT_LENGTH
225232
self.ready_func = None
226233
self.run_lock = threading.Lock()
227234
self.run_lock.acquire()
228235
self.resume_lock = None
229236
self.cancellable = False
230237
self.cancelled = False
238+
self.waiting_for_callback = False
231239
def thread_stack_base():
232240
self.run_lock.acquire()
233241
thread_func(self)
234242
self.task.thread_stop(self)
243+
if self.index is not None:
244+
self.task.inst.table.remove(self.index)
235245
self.resume_lock.release()
236246
self.stack = threading.Thread(target = thread_stack_base)
237247
self.stack.start()
238248
self.task.thread_start(self)
249+
assert(self.suspended())
250+
251+
def suspended(self):
252+
return self.ready_func is None
253+
254+
def pending(self):
255+
return self.ready_func is not None
239256

240257
def ready(self):
258+
assert(self.pending())
241259
return self.ready_func()
242260

243261
def resume(self, cancel = False):
244262
if cancel:
245263
assert(self.cancellable and not self.cancelled)
246264
self.cancelled = True
247-
if self.ready_func:
248-
assert(cancel or self.ready_func())
265+
if self.pending():
266+
assert(cancel or self.ready())
249267
self.ready_func = None
250268
self.task.inst.store.pending.remove(self)
251269
assert(not self.resume_lock)
@@ -255,22 +273,51 @@ def resume(self, cancel = False):
255273
self.resume_lock.acquire()
256274
self.resume_lock = None
257275

276+
def suspend(self, cancellable) -> bool:
277+
assert(not self.cancellable and not self.cancelled)
278+
self.cancellable = cancellable
279+
self.resume_lock.release()
280+
self.run_lock.acquire()
281+
assert(self.cancellable or not self.cancelled)
282+
self.cancellable = False
283+
completed = not self.cancelled
284+
self.cancelled = False
285+
return completed
286+
258287
def suspend_until(self, ready_func, cancellable = False) -> bool:
259-
assert(not self.ready_func)
288+
assert(not self.pending())
260289
if not DETERMINISTIC_PROFILE and ready_func():
261290
return True
262291
self.ready_func = ready_func
263292
self.task.inst.store.pending.append(self)
264-
assert(not self.cancellable and not self.cancelled)
293+
return self.suspend(cancellable)
294+
295+
def switch_to(self, cancellable, other: Thread) -> bool:
296+
assert(other.suspended())
297+
assert(not self.cancellable)
265298
self.cancellable = cancellable
266-
self.resume_lock.release()
299+
assert(self.resume_lock and not other.resume_lock)
300+
other.resume_lock = self.resume_lock
301+
self.resume_lock = None
302+
other.run_lock.release()
267303
self.run_lock.acquire()
268-
assert(self.cancellable or not self.cancelled)
269304
self.cancellable = False
270305
completed = not self.cancelled
271306
self.cancelled = False
272307
return completed
273308

309+
def yield_to(self, cancellable, other: Thread) -> bool:
310+
assert(other.suspended())
311+
assert(not self.ready_func)
312+
self.ready_func = lambda: True
313+
self.task.inst.store.pending.append(self)
314+
return self.switch_to(cancellable, other)
315+
316+
def resume_later(self, other: Thread):
317+
assert(other.suspended())
318+
other.ready_func = lambda: True
319+
other.task.inst.store.pending.append(other)
320+
274321

275322
### Lifting and Lowering Context
276323

@@ -447,22 +494,6 @@ def write(self, vs):
447494
assert(all(v == () for v in vs))
448495
self.progress += len(vs)
449496

450-
#### Context-Local Storage
451-
452-
class ContextLocalStorage:
453-
LENGTH = 1
454-
array: list[int]
455-
456-
def __init__(self):
457-
self.array = [0] * ContextLocalStorage.LENGTH
458-
459-
def set(self, i, v):
460-
assert(types_match_values(['i32'], [v]))
461-
self.array[i] = v
462-
463-
def get(self, i):
464-
return self.array[i]
465-
466497
#### Waitable State
467498

468499
class EventCode(IntEnum):
@@ -545,10 +576,8 @@ class State(Enum):
545576
ft: FuncType
546577
supertask: Optional[Task]
547578
on_resolve: OnResolve
548-
thread: Optional[Thread]
579+
threads: list[Thread]
549580
num_borrows: int
550-
waiting_for_callback: bool
551-
context: ContextLocalStorage
552581

553582
def __init__(self, opts, inst, ft, supertask, on_resolve):
554583
self.state = Task.State.INITIAL
@@ -557,10 +586,8 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
557586
self.ft = ft
558587
self.supertask = supertask
559588
self.on_resolve = on_resolve
560-
self.thread = None
589+
self.threads = []
561590
self.num_borrows = 0
562-
self.waiting_for_callback = False
563-
self.context = ContextLocalStorage()
564591

565592
def trap_if_on_the_stack(self, inst):
566593
c = self.supertask
@@ -572,7 +599,7 @@ def needs_exclusive(self):
572599
return self.opts.sync or self.opts.callback
573600

574601
def enter(self, thread):
575-
assert(thread is self.thread and thread.task is self)
602+
assert(thread in self.threads and thread.task is self)
576603
def has_backpressure():
577604
return self.inst.backpressure or (self.needs_exclusive() and self.inst.exclusive)
578605
if has_backpressure() or self.inst.pending_tasks > 0:
@@ -589,28 +616,31 @@ def has_backpressure():
589616

590617
def request_cancellation(self):
591618
assert(self.state == Task.State.INITIAL)
592-
if self.thread.cancellable and not (self.waiting_for_callback and self.inst.exclusive):
593-
self.state = Task.State.CANCEL_DELIVERED
594-
self.thread.resume(cancel = True)
595-
else:
596-
self.state = Task.State.PENDING_CANCEL
619+
if not DETERMINISTIC_PROFILE:
620+
random.shuffle(self.threads)
621+
for thread in self.threads:
622+
if thread.cancellable and not (thread.waiting_for_callback and self.inst.exclusive):
623+
self.state = Task.State.CANCEL_DELIVERED
624+
thread.resume(cancel = True)
625+
return
626+
self.state = Task.State.PENDING_CANCEL
597627

598628
def wait_until(self, ready_func, thread, cancellable, for_callback) -> bool:
599-
assert(thread is self.thread and thread.task is self)
629+
assert(thread in self.threads and thread.task is self)
600630
if cancellable and self.state == Task.State.PENDING_CANCEL:
601631
self.state = Task.State.CANCEL_DELIVERED
602632
return False
603633
if for_callback:
604634
assert(self.inst.exclusive)
605635
self.inst.exclusive = False
606-
self.waiting_for_callback = True
636+
thread.waiting_for_callback = True
607637
def ready_and_uncontended():
608638
return ready_func() and not (for_callback and self.inst.exclusive)
609639
completed = thread.suspend_until(ready_and_uncontended, cancellable)
610640
if for_callback:
611641
assert(not self.inst.exclusive)
612642
self.inst.exclusive = True
613-
self.waiting_for_callback = False
643+
thread.waiting_for_callback = False
614644
return completed
615645

616646
def yield_(self, thread, cancellable, for_callback) -> EventTuple:
@@ -653,20 +683,21 @@ def cancel(self):
653683
self.state = Task.State.RESOLVED
654684

655685
def exit(self):
656-
assert(self.thread is not None)
686+
assert(len(self.threads) > 0)
657687
if self.needs_exclusive():
658688
assert(self.inst.exclusive)
659689
self.inst.exclusive = False
660690

661691
def thread_start(self, thread):
662-
assert(self.thread is None and thread.task is self)
663-
self.thread = thread
692+
assert(thread not in self.threads and thread.task is self)
693+
self.threads.append(thread)
664694

665695
def thread_stop(self, thread):
666-
assert(thread is self.thread and thread.task is self)
667-
self.thread = None
668-
trap_if(self.state != Task.State.RESOLVED)
669-
assert(self.num_borrows == 0)
696+
assert(thread in self.threads and thread.task is self)
697+
self.threads.remove(thread)
698+
if len(self.threads) == 0:
699+
trap_if(self.state != Task.State.RESOLVED)
700+
assert(self.num_borrows == 0)
670701

671702
#### Subtask State
672703

@@ -1902,6 +1933,9 @@ def thread_func(thread):
19021933
if not task.enter(thread):
19031934
return
19041935

1936+
assert(thread.index is None)
1937+
thread.index = thread.task.inst.table.add(thread)
1938+
19051939
cx = LiftLowerContext(opts, inst, task)
19061940
args = on_start()
19071941
flat_args = lower_flat_values(cx, MAX_FLAT_PARAMS, args, ft.param_types())
@@ -2082,25 +2116,91 @@ def canon_resource_rep(rt, thread, i):
20822116
trap_if(h.rt is not rt)
20832117
return [h.rep]
20842118

2119+
### 🧵 `canon thread.index`
2120+
2121+
def canon_thread_index(shared, thread):
2122+
assert(not shared)
2123+
assert(thread.index is not None)
2124+
return [thread.index]
2125+
2126+
### 🧵 `canon thread.new`
2127+
2128+
def canon_thread_new(ft, ftbl, thread, i, c):
2129+
task = thread.task
2130+
trap_if(not task.inst.may_leave)
2131+
f = task.inst.ftbl.get(i)
2132+
trap_if(f.type != ft)
2133+
thread_func = partial(f, c)
2134+
new_thread = Thread(task, thread_func)
2135+
assert(new_thread.suspended())
2136+
new_thread.index = task.inst.table.add(thread)
2137+
return [new_thread.index]
2138+
2139+
### 🧵 `canon thread.resume-later`
2140+
2141+
def canon_thread_resume_later(thread, i):
2142+
trap_if(not thread.task.inst.may_leave)
2143+
other_thread = thread.task.inst.table.get(i)
2144+
trap_if(not isinstance(other_thread, Thread))
2145+
trap_if(not other_thread.suspended())
2146+
thread.resume_later(other_thread)
2147+
return []
2148+
2149+
### 🧵 `canon thread.switch-to`
2150+
2151+
def canon_thread_switch_to(thread, cancellable, i):
2152+
trap_if(not thread.task.inst.may_leave)
2153+
other_thread = thread.task.inst.table.get(i)
2154+
trap_if(not isinstance(other_thread, Thread))
2155+
trap_if(not other_thread.suspended())
2156+
if not thread.switch_to(cancellable, other_thread):
2157+
assert(cancellable)
2158+
return [0]
2159+
else:
2160+
return [1]
2161+
2162+
### 🧵 `canon thread.yield-to`
2163+
2164+
def canon_thread_yield_to(thread, cancellable, i):
2165+
trap_if(not thread.task.inst.may_leave)
2166+
other_thread = thread.task.inst.table.get(i)
2167+
trap_if(not isinstance(other_thread, Thread))
2168+
trap_if(not other_thread.suspended())
2169+
if not other_thread.yield_to(cancellable, other_thread):
2170+
assert(cancellable)
2171+
return [0]
2172+
else:
2173+
return [1]
2174+
2175+
### 🧵 `canon thread.suspend`
2176+
2177+
def canon_thread_suspend(thread, cancellable):
2178+
trap_if(not thread.task.inst.may_leave)
2179+
if not thread.suspend(cancellable):
2180+
assert(cancellable)
2181+
return [0]
2182+
else:
2183+
return [1]
2184+
20852185
### 🔀 `canon context.get`
20862186

20872187
def canon_context_get(t, i, thread):
20882188
assert(t == 'i32')
2089-
assert(i < ContextLocalStorage.LENGTH)
2090-
return [thread.task.context.get(i)]
2189+
assert(i < Thread.CONTEXT_LENGTH)
2190+
return [thread.context[i]]
20912191

20922192
### 🔀 `canon context.set`
20932193

20942194
def canon_context_set(t, i, thread, v):
20952195
assert(t == 'i32')
2096-
assert(i < ContextLocalStorage.LENGTH)
2097-
thread.task.context.set(i, v)
2196+
assert(i < Thread.CONTEXT_LENGTH)
2197+
thread.context[i] = v
20982198
return []
20992199

21002200
### 🔀 `canon backpressure.set`
21012201

21022202
def canon_backpressure_set(thread, flat_args):
2103-
trap_if(thread.task.opts.sync)
2203+
# TODO: remove trap_if(thread.task.opts.sync)
21042204
assert(len(flat_args) == 1)
21052205
thread.task.inst.backpressure = bool(flat_args[0])
21062206
return []

0 commit comments

Comments
 (0)