Skip to content

Commit cff91f9

Browse files
committed
CABI: fix may_block to not use the current task
1 parent 299ae3a commit cff91f9

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,22 +185,31 @@ class FutureType(ValType):
185185

186186
class Store:
187187
waiting: list[Thread]
188+
nesting_depth: int
188189

189190
def __init__(self):
190191
self.waiting = []
192+
self.nesting_depth = 0
191193

192194
def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call:
193195
host_caller = Supertask()
194196
host_caller.inst = None
195197
host_caller.supertask = caller
196-
return f(host_caller, on_start, on_resolve)
198+
self.nesting_depth += 1
199+
assert(self.nesting_depth == host_caller.num_host_callers())
200+
call = f(host_caller, on_start, on_resolve)
201+
self.nesting_depth -= 1
202+
return call
197203

198204
def tick(self):
205+
assert(self.nesting_depth == 0)
206+
self.nesting_depth = 1
199207
random.shuffle(self.waiting)
200208
for thread in self.waiting:
201209
if thread.ready():
202210
thread.resume(Cancelled.FALSE)
203-
return
211+
break
212+
self.nesting_depth = 0
204213

205214
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
206215

@@ -211,6 +220,15 @@ class Supertask:
211220
inst: Optional[ComponentInstance]
212221
supertask: Optional[Supertask]
213222

223+
def num_host_callers(self):
224+
n = 0
225+
t = self
226+
while t is not None:
227+
if t.inst is None:
228+
n += 1
229+
t = t.supertask
230+
return n
231+
214232
class Call:
215233
request_cancellation: Callable[[], None]
216234

@@ -286,6 +304,7 @@ class ComponentInstance:
286304
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
287305
threads: Table[Thread]
288306
may_leave: bool
307+
may_block: bool
289308
backpressure: int
290309
exclusive: Optional[Task]
291310
num_waiting_to_enter: int
@@ -297,6 +316,7 @@ def __init__(self, store, parent = None):
297316
self.handles = Table()
298317
self.threads = Table()
299318
self.may_leave = True
319+
self.may_block = True
300320
self.backpressure = 0
301321
self.exclusive = None
302322
self.num_waiting_to_enter = 0
@@ -504,7 +524,7 @@ def resume(self, cancelled):
504524
cancelled = Cancelled.FALSE
505525

506526
def suspend(self, cancellable) -> Cancelled:
507-
assert(self.running() and self.task.may_block())
527+
assert(self.running() and self.task.inst.may_block)
508528
if self.task.deliver_pending_cancel(cancellable):
509529
return Cancelled.TRUE
510530
self.cancellable = cancellable
@@ -513,7 +533,7 @@ def suspend(self, cancellable) -> Cancelled:
513533
return cancelled
514534

515535
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
516-
assert(self.running() and self.task.may_block())
536+
assert(self.running() and self.task.inst.may_block)
517537
if self.task.deliver_pending_cancel(cancellable):
518538
return Cancelled.TRUE
519539
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -524,7 +544,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
524544

525545
def yield_until(self, ready_func, cancellable) -> Cancelled:
526546
assert(self.running())
527-
if self.task.may_block():
547+
if self.task.inst.may_block:
528548
return self.wait_until(ready_func, cancellable)
529549
else:
530550
assert(ready_func())
@@ -669,9 +689,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
669689
def needs_exclusive(self):
670690
return not self.opts.async_ or self.opts.callback
671691

672-
def may_block(self):
673-
return self.ft.async_ or self.state == Task.State.RESOLVED
674-
675692
def enter(self):
676693
thread = current_thread()
677694
if self.ft.async_:
@@ -689,6 +706,9 @@ def has_backpressure():
689706
if self.needs_exclusive():
690707
assert(self.inst.exclusive is None)
691708
self.inst.exclusive = self
709+
else:
710+
assert(self.inst.may_block)
711+
self.inst.may_block = False
692712
self.register_thread(thread)
693713
return True
694714

@@ -738,13 +758,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
738758
def return_(self, result):
739759
trap_if(self.state == Task.State.RESOLVED)
740760
trap_if(self.num_borrows > 0)
761+
if not self.ft.async_:
762+
assert(not self.inst.may_block)
763+
self.inst.may_block = True
741764
assert(result is not None)
742765
self.on_resolve(result)
743766
self.state = Task.State.RESOLVED
744767

745768
def cancel(self):
746769
trap_if(self.state != Task.State.CANCEL_DELIVERED)
747770
trap_if(self.num_borrows > 0)
771+
assert(self.ft.async_)
748772
self.on_resolve(None)
749773
self.state = Task.State.RESOLVED
750774

@@ -2078,7 +2102,7 @@ def thread_func():
20782102
else:
20792103
event = (EventCode.NONE, 0, 0)
20802104
case CallbackCode.WAIT:
2081-
trap_if(not task.may_block())
2105+
trap_if(not inst.may_block)
20822106
wset = inst.handles.get(si)
20832107
trap_if(not isinstance(wset, WaitableSet))
20842108
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2094,6 +2118,7 @@ def thread_func():
20942118

20952119
thread = Thread(task, thread_func)
20962120
thread.resume(Cancelled.FALSE)
2121+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20972122
return task
20982123

20992124
class CallbackCode(IntEnum):
@@ -2121,7 +2146,7 @@ def call_and_trap_on_throw(callee, args):
21212146
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21222147
thread = current_thread()
21232148
trap_if(not thread.task.inst.may_leave)
2124-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2149+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
21252150

21262151
subtask = Subtask()
21272152
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2309,7 +2334,7 @@ def canon_waitable_set_new():
23092334
def canon_waitable_set_wait(cancellable, mem, si, ptr):
23102335
thread = current_thread()
23112336
trap_if(not thread.task.inst.may_leave)
2312-
trap_if(not thread.task.may_block())
2337+
trap_if(not thread.task.inst.may_block)
23132338
wset = thread.task.inst.handles.get(si)
23142339
trap_if(not isinstance(wset, WaitableSet))
23152340
event = wset.wait(cancellable)
@@ -2364,7 +2389,7 @@ def canon_waitable_join(wi, si):
23642389
def canon_subtask_cancel(async_, i):
23652390
thread = current_thread()
23662391
trap_if(not thread.task.inst.may_leave)
2367-
trap_if(not thread.task.may_block() and not async_)
2392+
trap_if(not thread.task.inst.may_block and not async_)
23682393
subtask = thread.task.inst.handles.get(i)
23692394
trap_if(not isinstance(subtask, Subtask))
23702395
trap_if(subtask.resolve_delivered())
@@ -2425,7 +2450,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24252450
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24262451
thread = current_thread()
24272452
trap_if(not thread.task.inst.may_leave)
2428-
trap_if(not thread.task.may_block() and not opts.async_)
2453+
trap_if(not thread.task.inst.may_block and not opts.async_)
24292454

24302455
e = thread.task.inst.handles.get(i)
24312456
trap_if(not isinstance(e, EndT))
@@ -2480,7 +2505,7 @@ def canon_future_write(future_t, opts, i, ptr):
24802505
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
24812506
thread = current_thread()
24822507
trap_if(not thread.task.inst.may_leave)
2483-
trap_if(not thread.task.may_block() and not opts.async_)
2508+
trap_if(not thread.task.inst.may_block and not opts.async_)
24842509

24852510
e = thread.task.inst.handles.get(i)
24862511
trap_if(not isinstance(e, EndT))
@@ -2533,7 +2558,7 @@ def canon_future_cancel_write(future_t, async_, i):
25332558
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25342559
thread = current_thread()
25352560
trap_if(not thread.task.inst.may_leave)
2536-
trap_if(not thread.task.may_block() and not async_)
2561+
trap_if(not thread.task.inst.may_block and not async_)
25372562
e = thread.task.inst.handles.get(i)
25382563
trap_if(not isinstance(e, EndT))
25392564
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2616,7 +2641,7 @@ def canon_thread_switch_to(cancellable, i):
26162641
def canon_thread_suspend(cancellable):
26172642
thread = current_thread()
26182643
trap_if(not thread.task.inst.may_leave)
2619-
trap_if(not thread.task.may_block())
2644+
trap_if(not thread.task.inst.may_block)
26202645
cancelled = thread.suspend(cancellable)
26212646
return [cancelled]
26222647

0 commit comments

Comments
 (0)