Skip to content

Commit 09b13d9

Browse files
committed
Add cooperative threads
1 parent 04e68a4 commit 09b13d9

2 files changed

Lines changed: 151 additions & 74 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 93 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,6 @@ def write(self, vs):
358358
assert(all(v == () for v in vs))
359359
self.progress += len(vs)
360360

361-
#### Context-Local Storage
362-
363-
class ContextLocalStorage:
364-
LENGTH = 1
365-
array: list[int]
366-
367-
def __init__(self):
368-
self.array = [0] * ContextLocalStorage.LENGTH
369-
370-
def set(self, i, v):
371-
assert(types_match_values(['i32'], [v]))
372-
self.array[i] = v
373-
374-
def get(self, i):
375-
return self.array[i]
376-
377361
#### Waitable State
378362

379363
class EventCode(IntEnum):
@@ -475,7 +459,6 @@ class State(Enum):
475459
supertask: Optional[Task]
476460
on_resolve: Callable[[Optional[list[any]]], None]
477461
num_borrows: int
478-
context: ContextLocalStorage
479462

480463
def __init__(self, opts, inst, ft, supertask, on_resolve):
481464
self.state = Task.State.INITIAL
@@ -485,7 +468,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
485468
self.supertask = supertask
486469
self.on_resolve = on_resolve
487470
self.num_borrows = 0
488-
self.context = ContextLocalStorage()
489471

490472
async def enter(self, thread):
491473
self.trap_if_on_the_stack(self.inst)
@@ -529,6 +511,7 @@ async def wait_sync(self, thread, awaitable) -> None:
529511
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
530512
return
531513
assert(self.inst.interruptible.is_set())
514+
# TODO: only clear interruptible if thread == task.main_thread
532515
self.inst.interruptible.clear()
533516
if await thread.suspend(awaitable) == Cancelled.TRUE:
534517
assert(self.state == Task.State.INITIAL)
@@ -889,13 +872,19 @@ def drop(self):
889872

890873
class Thread:
891874
task: Task
875+
index: int
876+
context: list[int]
892877
awaitable: Optional[Awaitable]
893878
on_resume: Optional[asyncio.Future]
894879
on_suspend_or_exit: Optional[asyncio.Future]
895880
returned: bool
896881

882+
CONTEXT_LENGTH = 1
883+
897884
def __init__(self, task, coro):
898885
self.task = task
886+
self.index = task.inst.table.add(self)
887+
self.context = [0] * Thread.CONTEXT_LENGTH
899888
self.awaitable = None
900889
self.on_resume = asyncio.Future()
901890
self.on_suspend_or_exit = None
@@ -905,6 +894,7 @@ async def async_impl():
905894
self.on_resume = None
906895
await coro
907896
self.on_suspend_or_exit.set_result(None)
897+
self.task.inst.table.remove(self.index)
908898
self.returned = True
909899
asyncio.create_task(async_impl())
910900

@@ -932,6 +922,30 @@ async def suspend(self, awaitable) -> Cancelled:
932922
self.on_resume = None
933923
return cancelled
934924

925+
async def switch(self, other: Thread) -> Cancelled:
926+
assert(not self.awaitable and not other.awaitable)
927+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
928+
other.on_suspend_or_exit = self.on_suspend_or_exit
929+
self.on_suspend_or_exit = None
930+
other.on_resume.set_result(Cancelled.FALSE)
931+
assert(not self.on_resume)
932+
self.on_resume = asyncio.Future()
933+
cancelled = await self.on_resume
934+
self.on_resume = None
935+
return cancelled
936+
937+
def yield_(self, other: Thread):
938+
# deterministically switch to other, but leave this thread unblocked
939+
TODO
940+
941+
def unblock(self, other: Thread):
942+
# unblock other, but deterministically keep running here
943+
TODO
944+
945+
def wait(self) -> Cancelled:
946+
# perform just the first half of switch
947+
TODO
948+
935949
#### Store State / Embedding API
936950

937951
class Store:
@@ -2110,19 +2124,76 @@ async def canon_resource_rep(rt, thread, i):
21102124
trap_if(h.rt is not rt)
21112125
return [h.rep]
21122126

2127+
### 🧵 `canon thread.index`
2128+
2129+
async def canon_thread_index(shared, thread):
2130+
assert(not shared)
2131+
return [thread.index]
2132+
2133+
### 🧵 `canon thread.new_indirect`
2134+
2135+
async def canon_thread_new_indirect(shared, ft, ftbl, thread, i, c):
2136+
assert(not shared)
2137+
inst = thread.task.inst
2138+
trap_if(not inst.may_leave)
2139+
f = ftbl.get(i)
2140+
trap_if(f is None)
2141+
trap_if(f.type != ft)
2142+
thread = Thread(thread.task, f(c))
2143+
return [thread.index]
2144+
2145+
### 🧵 `canon thread.switch`
2146+
2147+
async def canon_thread_switch(shared, thread, i):
2148+
assert(not shared)
2149+
trap_if(not thread.task.inst.may_leave)
2150+
other = thread.task.inst.table.get(i)
2151+
trap_if(not isinstance(other, Thread))
2152+
cancelled = await thread.switch(other)
2153+
return [ 1 if cancelled else 0 ]
2154+
2155+
### 🧵 `canon thread.yield`
2156+
2157+
async def canon_thread_yield(shared, thread, i):
2158+
assert(not shared)
2159+
trap_if(not thread.task.inst.may_leave)
2160+
other = thread.task.inst.table.get(i)
2161+
trap_if(not isinstance(other, Thread))
2162+
other.yield_(other)
2163+
return []
2164+
2165+
### 🧵 `canon thread.unblock`
2166+
2167+
async def canon_thread_unblock(shared, thread, i):
2168+
trap_if(not thread.task.inst.may_leave)
2169+
other = thread.task.inst.table.get(i)
2170+
trap_if(not isinstance(other, Thread))
2171+
thread.unblock()
2172+
return []
2173+
2174+
### 🧵 `canon thread.wait`
2175+
2176+
async def canon_thread_wait(shared, thread, i):
2177+
assert(not shared)
2178+
trap_if(not thread.task.inst.may_leave)
2179+
other = thread.task.inst.table.get(i)
2180+
trap_if(not isinstance(other, Thread))
2181+
cancelled = await thread.suspend()
2182+
return [ 1 if cancelled else 0 ]
2183+
21132184
### 🔀 `canon context.get`
21142185

21152186
async def canon_context_get(t, i, thread):
21162187
assert(t == 'i32')
2117-
assert(i < ContextLocalStorage.LENGTH)
2118-
return [thread.task.context.get(i)]
2188+
assert(i < Thread.CONTEXT_LENGTH)
2189+
return [thread.context[i]]
21192190

21202191
### 🔀 `canon context.set`
21212192

21222193
async def canon_context_set(t, i, thread, v):
21232194
assert(t == 'i32')
2124-
assert(i < ContextLocalStorage.LENGTH)
2125-
thread.task.context.set(i, v)
2195+
assert(i < Thread.CONTEXT_LENGTH)
2196+
thread.context[i] = v
21262197
return []
21272198

21282199
### 🔀 `canon backpressure.set`

0 commit comments

Comments
 (0)