Skip to content

Commit 9d26f4c

Browse files
committed
Add cooperative threads
1 parent e574cae commit 9d26f4c

2 files changed

Lines changed: 149 additions & 76 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,28 @@ def tick(self):
215215

216216
class Thread:
217217
task: Task
218+
index: int
218219
future: Optional[asyncio.Future]
219220
on_resume: Optional[asyncio.Future]
220221
on_suspend_or_exit: Optional[asyncio.Future]
222+
context: list[int]
223+
224+
CONTEXT_LENGTH = 1
221225

222226
def __init__(self, task, thread_func):
223227
self.task = task
228+
self.index = task.inst.table.add(self)
224229
self.future = None
225230
self.on_resume = asyncio.Future()
226231
self.on_suspend_or_exit = None
232+
self.context = [0] * Thread.CONTEXT_LENGTH
227233
async def thread_start():
228234
await self.on_resume
229235
self.on_resume = None
230236
await thread_func(task, self)
231237
self.on_suspend_or_exit.set_result(None)
232238
self.task.thread = None
239+
self.task.inst.table.remove(self.index)
233240
asyncio.create_task(thread_start())
234241

235242
async def resume(self):
@@ -254,6 +261,30 @@ async def suspend(self, future):
254261
await self.on_resume
255262
self.on_resume = None
256263

264+
async def switch(self, other: Thread) -> Cancelled:
265+
assert(not self.future and not other.future)
266+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
267+
other.on_suspend_or_exit = self.on_suspend_or_exit
268+
self.on_suspend_or_exit = None
269+
other.on_resume.set_result(Cancelled.FALSE)
270+
assert(not self.on_resume)
271+
self.on_resume = asyncio.Future()
272+
cancelled = await self.on_resume
273+
self.on_resume = None
274+
return cancelled
275+
276+
def yield_to(self, other: Thread) -> Cancelled:
277+
# deterministically switch to other, but leave this thread unblocked
278+
TODO
279+
280+
def block(self) -> Cancelled:
281+
# perform just the first half of switch
282+
TODO
283+
284+
def unblock(self, other: Thread):
285+
# unblock other, but deterministically keep running here
286+
TODO
287+
257288

258289
### Lifting and Lowering Context
259290

@@ -431,22 +462,6 @@ def write(self, vs):
431462
assert(all(v == () for v in vs))
432463
self.progress += len(vs)
433464

434-
#### Context-Local Storage
435-
436-
class ContextLocalStorage:
437-
LENGTH = 1
438-
array: list[int]
439-
440-
def __init__(self):
441-
self.array = [0] * ContextLocalStorage.LENGTH
442-
443-
def set(self, i, v):
444-
assert(types_match_values(['i32'], [v]))
445-
self.array[i] = v
446-
447-
def get(self, i):
448-
return self.array[i]
449-
450465
#### Waitable State
451466

452467
class EventCode(IntEnum):
@@ -457,6 +472,7 @@ class EventCode(IntEnum):
457472
FUTURE_READ = 4
458473
FUTURE_WRITE = 5
459474
TASK_CANCELLED = 6
475+
THREAD_RESUMED = 7
460476

461477
EventTuple = tuple[EventCode, int, int]
462478

@@ -546,7 +562,6 @@ class State(Enum):
546562
thread: Thread
547563
cancellable: bool
548564
num_borrows: int
549-
context: ContextLocalStorage
550565

551566
def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
552567
self.state = Task.State.INITIAL
@@ -558,7 +573,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
558573
self.thread = Thread(self, thread_func)
559574
self.cancellable = False
560575
self.num_borrows = 0
561-
self.context = ContextLocalStorage()
562576

563577
def trap_if_on_the_stack(self, inst):
564578
c = self.supertask
@@ -638,6 +652,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
638652
waitable_set.num_waiting += 1
639653
e = None
640654
while not e:
655+
# TODO: somehow get a THREAD_RESUME event...
641656
maybe_event = waitable_set.maybe_has_pending_event.wait()
642657
await self.block_on(thread, maybe_event, cancellable, unlock)
643658
if self.deliver_cancel():
@@ -650,6 +665,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
650665
assert(self.thread is thread and self is thread.task)
651666
if cancellable and self.deliver_cancel():
652667
return (EventCode.TASK_CANCELLED, 0, 0)
668+
# TODO: somehow get a THREAD_RESUME event...
653669
await self.block_on(thread, asyncio.sleep(0), cancellable, unlock)
654670
if cancellable and self.deliver_cancel():
655671
return (EventCode.TASK_CANCELLED, 0, 0)
@@ -681,7 +697,7 @@ def cancel(self):
681697
self.state = Task.State.RESOLVED
682698

683699
def exit(self):
684-
trap_if(self.state != Task.State.RESOLVED)
700+
trap_if(self.state != Task.State.RESOLVED) # TODO: move this to empty-threads case
685701
assert(self.num_borrows == 0)
686702
if self.needs_lock():
687703
self.inst.lock.release()
@@ -2102,25 +2118,76 @@ async def canon_resource_rep(rt, thread, i):
21022118
trap_if(h.rt is not rt)
21032119
return [h.rep]
21042120

2121+
### 🧵 `canon thread.index`
2122+
2123+
async def canon_thread_index(shared, thread):
2124+
assert(not shared)
2125+
return [thread.index]
2126+
2127+
### 🧵 `canon thread.new_indirect`
2128+
2129+
async def canon_thread_new_indirect(ft, ftbl, thread, i, c):
2130+
trap_if(not thread.task.inst.may_leave)
2131+
f = thread.task.inst.ftbl.get(i)
2132+
trap_if(f.type != ft)
2133+
thread = Thread(thread.task, f(c))
2134+
return [thread.index]
2135+
2136+
### 🧵 `canon thread.switch`
2137+
2138+
async def canon_thread_switch(thread, i):
2139+
trap_if(not thread.task.inst.may_leave)
2140+
other = thread.task.inst.table.get(i)
2141+
trap_if(not isinstance(other, Thread))
2142+
cancelled = await thread.switch(other)
2143+
return [ 1 if cancelled else 0 ]
2144+
2145+
### 🧵 `canon thread.yield-to`
2146+
2147+
async def canon_thread_yield_to(thread, i):
2148+
trap_if(not thread.task.inst.may_leave)
2149+
other = thread.task.inst.table.get(i)
2150+
trap_if(not isinstance(other, Thread))
2151+
other.yield_to(other)
2152+
return []
2153+
2154+
### 🧵 `canon thread.block`
2155+
2156+
async def canon_thread_block(thread, i):
2157+
trap_if(not thread.task.inst.may_leave)
2158+
other = thread.task.inst.table.get(i)
2159+
trap_if(not isinstance(other, Thread))
2160+
cancelled = await thread.block()
2161+
return [ 1 if cancelled else 0 ]
2162+
2163+
### 🧵 `canon thread.unblock`
2164+
2165+
async def canon_thread_unblock(thread, i):
2166+
trap_if(not thread.task.inst.may_leave)
2167+
other = thread.task.inst.table.get(i)
2168+
trap_if(not isinstance(other, Thread))
2169+
thread.unblock()
2170+
return []
2171+
21052172
### 🔀 `canon context.get`
21062173

21072174
async def canon_context_get(t, i, thread):
21082175
assert(t == 'i32')
2109-
assert(i < ContextLocalStorage.LENGTH)
2110-
return [thread.task.context.get(i)]
2176+
assert(i < Thread.CONTEXT_LENGTH)
2177+
return [thread.context[i]]
21112178

21122179
### 🔀 `canon context.set`
21132180

21142181
async def canon_context_set(t, i, thread, v):
21152182
assert(t == 'i32')
2116-
assert(i < ContextLocalStorage.LENGTH)
2117-
thread.task.context.set(i, v)
2183+
assert(i < Thread.CONTEXT_LENGTH)
2184+
thread.context[i] = v
21182185
return []
21192186

21202187
### 🔀 `canon backpressure.set`
21212188

21222189
async def canon_backpressure_set(thread, flat_args):
2123-
trap_if(thread.task.opts.sync)
2190+
# TODO: remove trap_if(thread.task.opts.sync)
21242191
assert(len(flat_args) == 1)
21252192
if flat_args[0] == 0:
21262193
thread.task.inst.no_backpressure.set()

0 commit comments

Comments
 (0)