Skip to content

Commit cf5e473

Browse files
committed
CABI: fix may_block to not use the current task
1 parent aec7316 commit cf5e473

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,21 +185,29 @@ class FutureType(ValType):
185185

186186
class Store:
187187
waiting: list[Thread]
188+
nesting_depth: int
188189

189190
def __init__(self):
190191
self.waiting = []
192+
self.nesting_depth = 0
191193

192194
def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call:
193195
host_caller = Supertask()
194196
host_caller.inst = None
195197
host_caller.supertask = caller
196-
return f(host_caller, on_start, on_resolve)
198+
self.nesting_depth += 1
199+
assert(self.nesting_depth == host_caller.num_host_callers())
200+
call = f(host_caller, on_start, on_resolve)
201+
self.nesting_depth -= 1
197202

198203
def tick(self):
204+
assert(self.nesting_depth == 0)
199205
random.shuffle(self.waiting)
200206
for thread in self.waiting:
201207
if thread.ready():
208+
self.nesting_depth = 1
202209
thread.resume(Cancelled.FALSE)
210+
self.nesting_depth = 0
203211
return
204212

205213
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
@@ -211,6 +219,15 @@ class Supertask:
211219
inst: Optional[ComponentInstance]
212220
supertask: Optional[Supertask]
213221

222+
def num_host_callers(self):
223+
n = 0
224+
t = self
225+
while t is not None:
226+
if t.inst is None:
227+
n += 1
228+
t = t.supertask
229+
return n
230+
214231
class Call:
215232
request_cancellation: Callable[[], None]
216233

@@ -285,7 +302,9 @@ class ComponentInstance:
285302
parent: Optional[ComponentInstance]
286303
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
287304
threads: Table[Thread]
305+
may_enter: bool
288306
may_leave: bool
307+
may_block: bool
289308
backpressure: int
290309
exclusive: Optional[Task]
291310
num_waiting_to_enter: int
@@ -296,7 +315,9 @@ def __init__(self, store, parent = None):
296315
self.parent = parent
297316
self.handles = Table()
298317
self.threads = Table()
318+
self.may_enter = True
299319
self.may_leave = True
320+
self.may_block = True
300321
self.backpressure = 0
301322
self.exclusive = None
302323
self.num_waiting_to_enter = 0
@@ -489,7 +510,8 @@ def resume_later(self):
489510

490511
def resume(self, cancelled):
491512
assert(self.cancellable or not cancelled)
492-
assert(not self.running())
513+
assert(not self.running() and self.task.inst.may_enter)
514+
self.task.inst.may_enter = False
493515
if self.waiting():
494516
assert(cancelled or self.ready())
495517
self.ready_func = None
@@ -506,9 +528,11 @@ def resume(self, cancelled):
506528
break
507529
thread = switch_to_thread
508530
cancelled = Cancelled.FALSE
531+
assert(not self.task.inst.may_enter)
532+
self.task.inst.may_enter = True
509533

510534
def suspend(self, cancellable) -> Cancelled:
511-
assert(self.running() and self.task.may_block())
535+
assert(self.running() and self.task.inst.may_block)
512536
if self.task.deliver_pending_cancel(cancellable):
513537
return Cancelled.TRUE
514538
self.cancellable = cancellable
@@ -517,7 +541,7 @@ def suspend(self, cancellable) -> Cancelled:
517541
return cancelled
518542

519543
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
520-
assert(self.running() and self.task.may_block())
544+
assert(self.running() and self.task.inst.may_block)
521545
if self.task.deliver_pending_cancel(cancellable):
522546
return Cancelled.TRUE
523547
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -528,7 +552,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
528552

529553
def yield_until(self, ready_func, cancellable) -> Cancelled:
530554
assert(self.running())
531-
if self.task.may_block():
555+
if self.task.inst.may_block:
532556
return self.wait_until(ready_func, cancellable)
533557
else:
534558
assert(ready_func())
@@ -683,12 +707,11 @@ def thread_stop(self, thread):
683707
def needs_exclusive(self):
684708
return not self.opts.async_ or self.opts.callback
685709

686-
def may_block(self):
687-
return self.ft.async_ or self.state == Task.State.RESOLVED
688-
689710
def enter(self, thread):
690711
assert(thread in self.threads and thread.task is self)
691712
if not self.ft.async_:
713+
assert(self.inst.may_block)
714+
self.inst.may_block = False
692715
return True
693716
def has_backpressure():
694717
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -739,13 +762,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
739762
def return_(self, result):
740763
trap_if(self.state == Task.State.RESOLVED)
741764
trap_if(self.num_borrows > 0)
765+
if not self.ft.async_:
766+
assert(not self.inst.may_block)
767+
self.inst.may_block = True
742768
assert(result is not None)
743769
self.on_resolve(result)
744770
self.state = Task.State.RESOLVED
745771

746772
def cancel(self):
747773
trap_if(self.state != Task.State.CANCEL_DELIVERED)
748774
trap_if(self.num_borrows > 0)
775+
assert(self.ft.async_)
749776
self.on_resolve(None)
750777
self.state = Task.State.RESOLVED
751778

@@ -2038,6 +2065,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20382065

20392066
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
20402067
trap_if(call_might_be_recursive(caller, inst))
2068+
assert(inst.may_enter) # is it actually guaranteed by `call_might_be_recursive`?
2069+
20412070
task = Task(opts, inst, ft, caller, on_resolve)
20422071
def thread_func(thread):
20432072
if not task.enter(thread):
@@ -2082,7 +2111,7 @@ def thread_func(thread):
20822111
else:
20832112
event = (EventCode.NONE, 0, 0)
20842113
case CallbackCode.WAIT:
2085-
trap_if(not task.may_block())
2114+
trap_if(not inst.may_block)
20862115
wset = inst.handles.get(si)
20872116
trap_if(not isinstance(wset, WaitableSet))
20882117
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2098,6 +2127,7 @@ def thread_func(thread):
20982127

20992128
thread = Thread(task, thread_func)
21002129
thread.resume(Cancelled.FALSE)
2130+
assert(ft.async_ or task.state == Task.State.RESOLVED)
21012131
return task
21022132

21032133
class CallbackCode(IntEnum):
@@ -2124,7 +2154,7 @@ def call_and_trap_on_throw(callee, thread, args):
21242154

21252155
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
21262156
trap_if(not thread.task.inst.may_leave)
2127-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2157+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
21282158

21292159
subtask = Subtask()
21302160
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2304,7 +2334,7 @@ def canon_waitable_set_new(thread):
23042334

23052335
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
23062336
trap_if(not thread.task.inst.may_leave)
2307-
trap_if(not thread.task.may_block())
2337+
trap_if(not thread.task.inst.may_block)
23082338
wset = thread.task.inst.handles.get(si)
23092339
trap_if(not isinstance(wset, WaitableSet))
23102340
event = wset.wait(thread, cancellable)
@@ -2355,7 +2385,7 @@ def canon_waitable_join(thread, wi, si):
23552385

23562386
def canon_subtask_cancel(async_, thread, i):
23572387
trap_if(not thread.task.inst.may_leave)
2358-
trap_if(not thread.task.may_block() and not async_)
2388+
trap_if(not thread.task.inst.may_block and not async_)
23592389
subtask = thread.task.inst.handles.get(i)
23602390
trap_if(not isinstance(subtask, Subtask))
23612391
trap_if(subtask.resolve_delivered())
@@ -2412,7 +2442,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24122442

24132443
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
24142444
trap_if(not thread.task.inst.may_leave)
2415-
trap_if(not thread.task.may_block() and not opts.async_)
2445+
trap_if(not thread.task.inst.may_block and not opts.async_)
24162446

24172447
e = thread.task.inst.handles.get(i)
24182448
trap_if(not isinstance(e, EndT))
@@ -2466,7 +2496,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24662496

24672497
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24682498
trap_if(not thread.task.inst.may_leave)
2469-
trap_if(not thread.task.may_block() and not opts.async_)
2499+
trap_if(not thread.task.inst.may_block and not opts.async_)
24702500

24712501
e = thread.task.inst.handles.get(i)
24722502
trap_if(not isinstance(e, EndT))
@@ -2518,7 +2548,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25182548

25192549
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
25202550
trap_if(not thread.task.inst.may_leave)
2521-
trap_if(not thread.task.may_block() and not async_)
2551+
trap_if(not thread.task.inst.may_block and not async_)
25222552
e = thread.task.inst.handles.get(i)
25232553
trap_if(not isinstance(e, EndT))
25242554
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2595,7 +2625,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25952625

25962626
def canon_thread_suspend(cancellable, thread):
25972627
trap_if(not thread.task.inst.may_leave)
2598-
trap_if(not thread.task.may_block())
2628+
trap_if(not thread.task.inst.may_block)
25992629
cancelled = thread.suspend(cancellable)
26002630
return [cancelled]
26012631

0 commit comments

Comments
 (0)