Skip to content

Commit e7a73b4

Browse files
committed
CABI: fix may_block to not use the current task
1 parent 53b74a3 commit e7a73b4

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ class ComponentInstance:
258258
parent: Optional[ComponentInstance]
259259
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
260260
threads: Table[Thread]
261+
may_enter: bool
261262
may_leave: bool
263+
may_block: bool
262264
backpressure: int
263265
exclusive: Optional[Task]
264266
num_waiting_to_enter: int
@@ -269,7 +271,9 @@ def __init__(self, store, parent = None):
269271
self.parent = parent
270272
self.handles = Table()
271273
self.threads = Table()
274+
self.may_enter = True
272275
self.may_leave = True
276+
self.may_block = True
273277
self.backpressure = 0
274278
self.exclusive = None
275279
self.num_waiting_to_enter = 0
@@ -469,6 +473,8 @@ def resume(self, cancelled):
469473
self.ready_func = None
470474
self.task.inst.store.waiting.remove(self)
471475
assert(self.cancellable or not cancelled)
476+
assert(self.task.inst.may_enter)
477+
self.task.inst.may_enter = False
472478
thread = self
473479
while True:
474480
cont = thread.cont
@@ -481,9 +487,11 @@ def resume(self, cancelled):
481487
break
482488
thread = switch_to_thread
483489
cancelled = Cancelled.FALSE
490+
assert(not self.task.inst.may_enter)
491+
self.task.inst.may_enter = True
484492

485493
def suspend(self, cancellable) -> Cancelled:
486-
assert(self.running() and self.task.may_block())
494+
assert(self.running() and self.task.inst.may_block)
487495
if self.task.deliver_pending_cancel(cancellable):
488496
return Cancelled.TRUE
489497
self.cancellable = cancellable
@@ -492,7 +500,7 @@ def suspend(self, cancellable) -> Cancelled:
492500
return cancelled
493501

494502
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
495-
assert(self.running() and self.task.may_block())
503+
assert(self.running() and self.task.inst.may_block)
496504
if self.task.deliver_pending_cancel(cancellable):
497505
return Cancelled.TRUE
498506
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -503,7 +511,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
503511

504512
def yield_until(self, ready_func, cancellable) -> Cancelled:
505513
assert(self.running())
506-
if self.task.may_block():
514+
if self.task.inst.may_block:
507515
return self.wait_until(ready_func, cancellable)
508516
else:
509517
assert(ready_func())
@@ -658,12 +666,11 @@ def thread_stop(self, thread):
658666
def needs_exclusive(self):
659667
return not self.opts.async_ or self.opts.callback
660668

661-
def may_block(self):
662-
return self.ft.async_ or self.state == Task.State.RESOLVED
663-
664669
def enter(self, thread):
665670
assert(thread in self.threads and thread.task is self)
666671
if not self.ft.async_:
672+
assert(self.inst.may_block) # TODO: why
673+
self.inst.may_block = False
667674
return True
668675
def has_backpressure():
669676
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -714,13 +721,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
714721
def return_(self, result):
715722
trap_if(self.state == Task.State.RESOLVED)
716723
trap_if(self.num_borrows > 0)
724+
if not self.ft.async_:
725+
assert(not self.inst.may_block)
726+
self.inst.may_block = True
717727
assert(result is not None)
718728
self.on_resolve(result)
719729
self.state = Task.State.RESOLVED
720730

721731
def cancel(self):
722732
trap_if(self.state != Task.State.CANCEL_DELIVERED)
723733
trap_if(self.num_borrows > 0)
734+
assert(self.ft.async_)
724735
self.on_resolve(None)
725736
self.state = Task.State.RESOLVED
726737

@@ -2002,7 +2013,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20022013
### `canon lift`
20032014

20042015
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
2016+
trap_if(not inst.may_enter)
20052017
trap_if(call_might_be_recursive(caller, inst))
2018+
20062019
task = Task(opts, inst, ft, caller, on_resolve)
20072020
def thread_func(thread):
20082021
if not task.enter(thread):
@@ -2047,7 +2060,7 @@ def thread_func(thread):
20472060
else:
20482061
event = (EventCode.NONE, 0, 0)
20492062
case CallbackCode.WAIT:
2050-
trap_if(not task.may_block())
2063+
trap_if(not inst.may_block)
20512064
wset = inst.handles.get(si)
20522065
trap_if(not isinstance(wset, WaitableSet))
20532066
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2063,6 +2076,7 @@ def thread_func(thread):
20632076

20642077
thread = Thread(task, thread_func)
20652078
thread.resume(Cancelled.FALSE)
2079+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20662080
return task
20672081

20682082
class CallbackCode(IntEnum):
@@ -2089,7 +2103,7 @@ def call_and_trap_on_throw(callee, thread, args):
20892103

20902104
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
20912105
trap_if(not thread.task.inst.may_leave)
2092-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2106+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
20932107

20942108
subtask = Subtask()
20952109
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2264,7 +2278,7 @@ def canon_waitable_set_new(thread):
22642278

22652279
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
22662280
trap_if(not thread.task.inst.may_leave)
2267-
trap_if(not thread.task.may_block())
2281+
trap_if(not thread.task.inst.may_block)
22682282
wset = thread.task.inst.handles.get(si)
22692283
trap_if(not isinstance(wset, WaitableSet))
22702284
event = wset.wait(thread, cancellable)
@@ -2315,7 +2329,7 @@ def canon_waitable_join(thread, wi, si):
23152329

23162330
def canon_subtask_cancel(async_, thread, i):
23172331
trap_if(not thread.task.inst.may_leave)
2318-
trap_if(not thread.task.may_block() and not async_)
2332+
trap_if(not thread.task.inst.may_block and not async_)
23192333
subtask = thread.task.inst.handles.get(i)
23202334
trap_if(not isinstance(subtask, Subtask))
23212335
trap_if(subtask.resolve_delivered())
@@ -2372,7 +2386,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23722386

23732387
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
23742388
trap_if(not thread.task.inst.may_leave)
2375-
trap_if(not thread.task.may_block() and not opts.async_)
2389+
trap_if(not thread.task.inst.may_block and not opts.async_)
23762390

23772391
e = thread.task.inst.handles.get(i)
23782392
trap_if(not isinstance(e, EndT))
@@ -2426,7 +2440,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24262440

24272441
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24282442
trap_if(not thread.task.inst.may_leave)
2429-
trap_if(not thread.task.may_block() and not opts.async_)
2443+
trap_if(not thread.task.inst.may_block and not opts.async_)
24302444

24312445
e = thread.task.inst.handles.get(i)
24322446
trap_if(not isinstance(e, EndT))
@@ -2478,7 +2492,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24782492

24792493
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
24802494
trap_if(not thread.task.inst.may_leave)
2481-
trap_if(not thread.task.may_block() and not async_)
2495+
trap_if(not thread.task.inst.may_block and not async_)
24822496
e = thread.task.inst.handles.get(i)
24832497
trap_if(not isinstance(e, EndT))
24842498
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2555,7 +2569,7 @@ def canon_thread_resume_later(thread, i):
25552569

25562570
def canon_thread_suspend(cancellable, thread):
25572571
trap_if(not thread.task.inst.may_leave)
2558-
trap_if(not thread.task.may_block())
2572+
trap_if(not thread.task.inst.may_block)
25592573
cancelled = thread.suspend(cancellable)
25602574
return [cancelled]
25612575

0 commit comments

Comments
 (0)