Skip to content

Commit 0c983ce

Browse files
committed
CABI: fix may_block to not use the current task
1 parent 12cb1b4 commit 0c983ce

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,21 +185,29 @@ class FutureType(ValType):
185185

186186
class Store:
187187
waiting: list[Thread]
188+
nesting_depth: int
188189

189190
def __init__(self):
190191
self.waiting = []
192+
self.nesting_depth = 0
191193

192194
def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call:
193195
host_caller = Supertask()
194196
host_caller.inst = None
195197
host_caller.supertask = caller
196-
return f(host_caller, on_start, on_resolve)
198+
self.nesting_depth += 1
199+
assert(self.nesting_depth == host_caller.num_host_callers())
200+
call = f(host_caller, on_start, on_resolve)
201+
self.nesting_depth -= 1
197202

198203
def tick(self):
204+
assert(self.nesting_depth == 0)
199205
random.shuffle(self.waiting)
200206
for thread in self.waiting:
201207
if thread.ready():
208+
self.nesting_depth = 1
202209
thread.resume(Cancelled.FALSE)
210+
self.nesting_depth = 0
203211
return
204212

205213
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -211,6 +219,15 @@ class Supertask:
211219
inst: Optional[ComponentInstance]
212220
supertask: Optional[Supertask]
213221

222+
def num_host_callers(self):
223+
n = 0
224+
t = self
225+
while t is not None:
226+
if t.inst is None:
227+
n += 1
228+
t = t.supertask
229+
return n
230+
214231
class Call:
215232
request_cancellation: Callable[[], None]
216233

@@ -286,6 +303,7 @@ class ComponentInstance:
286303
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
287304
threads: Table[Thread]
288305
may_leave: bool
306+
may_block: bool
289307
backpressure: int
290308
exclusive: Optional[Task]
291309
num_waiting_to_enter: int
@@ -297,6 +315,7 @@ def __init__(self, store, parent = None):
297315
self.handles = Table()
298316
self.threads = Table()
299317
self.may_leave = True
318+
self.may_block = True
300319
self.backpressure = 0
301320
self.exclusive = None
302321
self.num_waiting_to_enter = 0
@@ -509,7 +528,7 @@ def resume(self, cancelled):
509528
cancelled = Cancelled.FALSE
510529

511530
def suspend(self, cancellable) -> Cancelled:
512-
assert(self.running() and self.task.may_block())
531+
assert(self.running() and self.task.inst.may_block)
513532
if self.task.deliver_pending_cancel(cancellable):
514533
return Cancelled.TRUE
515534
self.cancellable = cancellable
@@ -518,7 +537,7 @@ def suspend(self, cancellable) -> Cancelled:
518537
return cancelled
519538

520539
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
521-
assert(self.running() and self.task.may_block())
540+
assert(self.running() and self.task.inst.may_block)
522541
if self.task.deliver_pending_cancel(cancellable):
523542
return Cancelled.TRUE
524543
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -529,7 +548,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
529548

530549
def yield_until(self, ready_func, cancellable) -> Cancelled:
531550
assert(self.running())
532-
if self.task.may_block():
551+
if self.task.inst.may_block:
533552
return self.wait_until(ready_func, cancellable)
534553
else:
535554
assert(ready_func())
@@ -684,13 +703,12 @@ def thread_stop(self, thread):
684703
def needs_exclusive(self):
685704
return not self.opts.async_ or self.opts.callback
686705

687-
def may_block(self):
688-
return self.ft.async_ or self.state == Task.State.RESOLVED
689-
690706
def enter(self):
691707
thread = current_thread()
692708
assert(thread in self.threads and thread.task is self)
693709
if not self.ft.async_:
710+
assert(self.inst.may_block)
711+
self.inst.may_block = False
694712
return True
695713
def has_backpressure():
696714
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -741,13 +759,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
741759
def return_(self, result):
742760
trap_if(self.state == Task.State.RESOLVED)
743761
trap_if(self.num_borrows > 0)
762+
if not self.ft.async_:
763+
assert(not self.inst.may_block)
764+
self.inst.may_block = True
744765
assert(result is not None)
745766
self.on_resolve(result)
746767
self.state = Task.State.RESOLVED
747768

748769
def cancel(self):
749770
trap_if(self.state != Task.State.CANCEL_DELIVERED)
750771
trap_if(self.num_borrows > 0)
772+
assert(self.ft.async_)
751773
self.on_resolve(None)
752774
self.state = Task.State.RESOLVED
753775

@@ -2084,7 +2106,7 @@ def thread_func():
20842106
else:
20852107
event = (EventCode.NONE, 0, 0)
20862108
case CallbackCode.WAIT:
2087-
trap_if(not task.may_block())
2109+
trap_if(not inst.may_block)
20882110
wset = inst.handles.get(si)
20892111
trap_if(not isinstance(wset, WaitableSet))
20902112
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2100,6 +2122,7 @@ def thread_func():
21002122

21012123
thread = Thread(task, thread_func)
21022124
thread.resume(Cancelled.FALSE)
2125+
assert(ft.async_ or task.state == Task.State.RESOLVED)
21032126
return task
21042127

21052128
class CallbackCode(IntEnum):
@@ -2127,7 +2150,7 @@ def call_and_trap_on_throw(callee, args):
21272150
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21282151
thread = current_thread()
21292152
trap_if(not thread.task.inst.may_leave)
2130-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2153+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
21312154

21322155
subtask = Subtask()
21332156
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2313,7 +2336,7 @@ def canon_waitable_set_new():
23132336
def canon_waitable_set_wait(cancellable, mem, si, ptr):
23142337
thread = current_thread()
23152338
trap_if(not thread.task.inst.may_leave)
2316-
trap_if(not thread.task.may_block())
2339+
trap_if(not thread.task.inst.may_block)
23172340
wset = thread.task.inst.handles.get(si)
23182341
trap_if(not isinstance(wset, WaitableSet))
23192342
event = wset.wait(cancellable)
@@ -2368,7 +2391,7 @@ def canon_waitable_join(wi, si):
23682391
def canon_subtask_cancel(async_, i):
23692392
thread = current_thread()
23702393
trap_if(not thread.task.inst.may_leave)
2371-
trap_if(not thread.task.may_block() and not async_)
2394+
trap_if(not thread.task.inst.may_block and not async_)
23722395
subtask = thread.task.inst.handles.get(i)
23732396
trap_if(not isinstance(subtask, Subtask))
23742397
trap_if(subtask.resolve_delivered())
@@ -2429,7 +2452,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24292452
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24302453
thread = current_thread()
24312454
trap_if(not thread.task.inst.may_leave)
2432-
trap_if(not thread.task.may_block() and not opts.async_)
2455+
trap_if(not thread.task.inst.may_block and not opts.async_)
24332456

24342457
e = thread.task.inst.handles.get(i)
24352458
trap_if(not isinstance(e, EndT))
@@ -2484,7 +2507,7 @@ def canon_future_write(future_t, opts, i, ptr):
24842507
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
24852508
thread = current_thread()
24862509
trap_if(not thread.task.inst.may_leave)
2487-
trap_if(not thread.task.may_block() and not opts.async_)
2510+
trap_if(not thread.task.inst.may_block and not opts.async_)
24882511

24892512
e = thread.task.inst.handles.get(i)
24902513
trap_if(not isinstance(e, EndT))
@@ -2537,7 +2560,7 @@ def canon_future_cancel_write(future_t, async_, i):
25372560
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25382561
thread = current_thread()
25392562
trap_if(not thread.task.inst.may_leave)
2540-
trap_if(not thread.task.may_block() and not async_)
2563+
trap_if(not thread.task.inst.may_block and not async_)
25412564
e = thread.task.inst.handles.get(i)
25422565
trap_if(not isinstance(e, EndT))
25432566
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2619,7 +2642,7 @@ def canon_thread_switch_to(cancellable, i):
26192642
def canon_thread_suspend(cancellable):
26202643
thread = current_thread()
26212644
trap_if(not thread.task.inst.may_leave)
2622-
trap_if(not thread.task.may_block())
2645+
trap_if(not thread.task.inst.may_block)
26232646
cancelled = thread.suspend(cancellable)
26242647
return [cancelled]
26252648

0 commit comments

Comments
 (0)