Skip to content

Commit ee685af

Browse files
committed
Add cooperative threads
1 parent 08e4335 commit ee685af

2 files changed

Lines changed: 152 additions & 76 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -354,22 +354,6 @@ def write(self, vs):
354354
assert(all(v == () for v in vs))
355355
self.progress += len(vs)
356356

357-
#### Context-Local Storage
358-
359-
class ContextLocalStorage:
360-
LENGTH = 1
361-
array: list[int]
362-
363-
def __init__(self):
364-
self.array = [0] * ContextLocalStorage.LENGTH
365-
366-
def set(self, i, v):
367-
assert(types_match_values(['i32'], [v]))
368-
self.array[i] = v
369-
370-
def get(self, i):
371-
return self.array[i]
372-
373357
#### Waitable State
374358

375359
class EventCode(IntEnum):
@@ -380,6 +364,7 @@ class EventCode(IntEnum):
380364
FUTURE_READ = 4
381365
FUTURE_WRITE = 5
382366
TASK_CANCELLED = 6
367+
THREAD_RESUMED = 7
383368

384369
EventTuple = tuple[EventCode, int, int]
385370

@@ -475,7 +460,6 @@ class State(Enum):
475460
on_resolve: OnResolve
476461
thread: Thread
477462
num_borrows: int
478-
context: ContextLocalStorage
479463

480464
def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
481465
self.state = Task.State.INITIAL
@@ -486,7 +470,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, thread_func):
486470
self.on_resolve = on_resolve
487471
self.thread = Thread(self, thread_func)
488472
self.num_borrows = 0
489-
self.context = ContextLocalStorage()
490473

491474
def trap_if_on_the_stack(self, inst):
492475
c = self.supertask
@@ -526,13 +509,16 @@ async def block_on(self, thread, awaitable, cancellable = False, unlock = False)
526509
return Cancelled.FALSE
527510

528511
if unlock and (self.opts.sync or self.opts.callback):
512+
# assert(thread is thread.task.main_thread)
529513
self.inst.lock.release()
530514

515+
# TODO: maybe pass 'cancellable' into 'suspend'
531516
cancelled = await thread.suspend(f)
532517
if cancelled and not cancellable:
533518
assert(await thread.suspend(f) == Cancelled.FALSE)
534519

535520
if unlock and (self.opts.sync or self.opts.callback):
521+
# assert(thread is thread.task.main_thread)
536522
acquired = asyncio.create_task(self.inst.lock.acquire())
537523
if await thread.suspend(acquired) == Cancelled.TRUE:
538524
assert(thread.suspend(acquired) == Cancelled.FALSE)
@@ -559,6 +545,7 @@ async def wait_for_event(self, thread, waitable_set, cancellable, unlock) -> Eve
559545
e = None
560546
while not e:
561547
maybe_event = waitable_set.maybe_has_pending_event.wait()
548+
# TODO: return EventCode.THREAD_RESUME
562549
if await self.block_on(thread, maybe_event, cancellable, unlock) == Cancelled.TRUE:
563550
return (EventCode.TASK_CANCELLED, 0, 0)
564551
e = waitable_set.poll()
@@ -570,6 +557,7 @@ async def yield_(self, thread, cancellable, unlock) -> EventTuple:
570557
if self.state == Task.State.PENDING_CANCEL and cancellable:
571558
self.state = Task.State.CANCEL_DELIVERED
572559
return (EventCode.TASK_CANCELLED, 0, 0)
560+
# TODO: return EventCode.THREAD_RESUME
573561
elif await self.block_on(thread, asyncio.sleep(0), cancellable, unlock) == Cancelled.TRUE:
574562
return (EventCode.TASK_CANCELLED, 0, 0)
575563
else:
@@ -601,7 +589,7 @@ def cancel(self):
601589
self.state = Task.State.RESOLVED
602590

603591
def exit(self):
604-
trap_if(self.state != Task.State.RESOLVED)
592+
trap_if(self.state != Task.State.RESOLVED) # TODO: move this to empty-threads case
605593
assert(self.num_borrows == 0)
606594
if self.opts.sync or self.opts.callback:
607595
self.inst.lock.release()
@@ -882,21 +870,28 @@ def drop(self):
882870

883871
class Thread:
884872
task: Task
873+
index: int
885874
future: Optional[asyncio.Future]
886875
on_resume: Optional[asyncio.Future]
887876
on_suspend_or_exit: Optional[asyncio.Future]
877+
context: list[int]
878+
879+
CONTEXT_LENGTH = 1
888880

889881
def __init__(self, task, thread_func):
890882
self.task = task
883+
self.index = task.inst.table.add(self)
891884
self.future = None
892885
self.on_resume = asyncio.Future()
893886
self.on_suspend_or_exit = None
887+
self.context = [0] * Thread.CONTEXT_LENGTH
894888
async def thread_start():
895889
assert(await self.on_resume == Cancelled.FALSE)
896890
self.on_resume = None
897891
await thread_func(task, self)
898892
self.on_suspend_or_exit.set_result(None)
899893
self.task.thread = None
894+
self.task.inst.table.remove(self.index)
900895
asyncio.create_task(thread_start())
901896

902897
async def resume(self, cancelled = Cancelled.FALSE):
@@ -923,6 +918,30 @@ async def suspend(self, future) -> Cancelled:
923918
self.on_resume = None
924919
return cancelled
925920

921+
async def switch(self, other: Thread) -> Cancelled:
922+
assert(not self.future and not other.future)
923+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
924+
other.on_suspend_or_exit = self.on_suspend_or_exit
925+
self.on_suspend_or_exit = None
926+
other.on_resume.set_result(Cancelled.FALSE)
927+
assert(not self.on_resume)
928+
self.on_resume = asyncio.Future()
929+
cancelled = await self.on_resume
930+
self.on_resume = None
931+
return cancelled
932+
933+
def yield_to(self, other: Thread) -> Cancelled:
934+
# deterministically switch to other, but leave this thread unblocked
935+
TODO
936+
937+
def block(self) -> Cancelled:
938+
# perform just the first half of switch
939+
TODO
940+
941+
def unblock(self, other: Thread):
942+
# unblock other, but deterministically keep running here
943+
TODO
944+
926945
#### Store State / Embedding API
927946

928947
class Store:
@@ -2095,25 +2114,76 @@ async def canon_resource_rep(rt, thread, i):
20952114
trap_if(h.rt is not rt)
20962115
return [h.rep]
20972116

2117+
### 🧵 `canon thread.index`
2118+
2119+
async def canon_thread_index(shared, thread):
2120+
assert(not shared)
2121+
return [thread.index]
2122+
2123+
### 🧵 `canon thread.new_indirect`
2124+
2125+
async def canon_thread_new_indirect(ft, ftbl, thread, i, c):
2126+
trap_if(not thread.task.inst.may_leave)
2127+
f = thread.task.inst.ftbl.get(i)
2128+
trap_if(f.type != ft)
2129+
thread = Thread(thread.task, f(c))
2130+
return [thread.index]
2131+
2132+
### 🧵 `canon thread.switch`
2133+
2134+
async def canon_thread_switch(thread, i):
2135+
trap_if(not thread.task.inst.may_leave)
2136+
other = thread.task.inst.table.get(i)
2137+
trap_if(not isinstance(other, Thread))
2138+
cancelled = await thread.switch(other)
2139+
return [ 1 if cancelled else 0 ]
2140+
2141+
### 🧵 `canon thread.yield-to`
2142+
2143+
async def canon_thread_yield_to(thread, i):
2144+
trap_if(not thread.task.inst.may_leave)
2145+
other = thread.task.inst.table.get(i)
2146+
trap_if(not isinstance(other, Thread))
2147+
other.yield_to(other)
2148+
return []
2149+
2150+
### 🧵 `canon thread.block`
2151+
2152+
async def canon_thread_block(thread, i):
2153+
trap_if(not thread.task.inst.may_leave)
2154+
other = thread.task.inst.table.get(i)
2155+
trap_if(not isinstance(other, Thread))
2156+
cancelled = await thread.block()
2157+
return [ 1 if cancelled else 0 ]
2158+
2159+
### 🧵 `canon thread.unblock`
2160+
2161+
async def canon_thread_unblock(thread, i):
2162+
trap_if(not thread.task.inst.may_leave)
2163+
other = thread.task.inst.table.get(i)
2164+
trap_if(not isinstance(other, Thread))
2165+
thread.unblock()
2166+
return []
2167+
20982168
### 🔀 `canon context.get`
20992169

21002170
async def canon_context_get(t, i, thread):
21012171
assert(t == 'i32')
2102-
assert(i < ContextLocalStorage.LENGTH)
2103-
return [thread.task.context.get(i)]
2172+
assert(i < Thread.CONTEXT_LENGTH)
2173+
return [thread.context[i]]
21042174

21052175
### 🔀 `canon context.set`
21062176

21072177
async def canon_context_set(t, i, thread, v):
21082178
assert(t == 'i32')
2109-
assert(i < ContextLocalStorage.LENGTH)
2110-
thread.task.context.set(i, v)
2179+
assert(i < Thread.CONTEXT_LENGTH)
2180+
thread.context[i] = v
21112181
return []
21122182

21132183
### 🔀 `canon backpressure.set`
21142184

21152185
async def canon_backpressure_set(thread, flat_args):
2116-
trap_if(thread.task.opts.sync)
2186+
# TODO: remove trap_if(thread.task.opts.sync)
21172187
assert(len(flat_args) == 1)
21182188
if flat_args[0] == 0:
21192189
thread.task.inst.no_backpressure.set()

0 commit comments

Comments
 (0)