Skip to content

Commit 32c0f2e

Browse files
committed
CABI: fix may_block to not use the current task
1 parent a920c0c commit 32c0f2e

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ class ComponentInstance:
258258
parent: Optional[ComponentInstance]
259259
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
260260
threads: Table[Thread]
261+
may_enter: bool
261262
may_leave: bool
263+
may_block: bool
262264
backpressure: int
263265
exclusive: Optional[Task]
264266
num_waiting_to_enter: int
@@ -269,7 +271,9 @@ def __init__(self, store, parent = None):
269271
self.parent = parent
270272
self.handles = Table()
271273
self.threads = Table()
274+
self.may_enter = True
272275
self.may_leave = True
276+
self.may_block = True
273277
self.backpressure = 0
274278
self.exclusive = None
275279
self.num_waiting_to_enter = 0
@@ -467,6 +471,8 @@ def resume(self, cancelled):
467471
self.ready_func = None
468472
self.task.inst.store.waiting.remove(self)
469473
assert(self.cancellable or not cancelled)
474+
assert(self.task.inst.may_enter)
475+
self.task.inst.may_enter = False
470476
thread = self
471477
while True:
472478
cont = thread.cont
@@ -479,9 +485,11 @@ def resume(self, cancelled):
479485
break
480486
thread = switch_to_thread
481487
cancelled = Cancelled.FALSE
488+
assert(not self.task.inst.may_enter)
489+
self.task.inst.may_enter = True
482490

483491
def suspend(self, cancellable) -> Cancelled:
484-
assert(self.running() and self.task.may_block())
492+
assert(self.running() and self.task.inst.may_block)
485493
if self.task.deliver_pending_cancel(cancellable):
486494
return Cancelled.TRUE
487495
self.cancellable = cancellable
@@ -490,7 +498,7 @@ def suspend(self, cancellable) -> Cancelled:
490498
return cancelled
491499

492500
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
493-
assert(self.running() and self.task.may_block())
501+
assert(self.running() and self.task.inst.may_block)
494502
if self.task.deliver_pending_cancel(cancellable):
495503
return Cancelled.TRUE
496504
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -501,7 +509,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
501509

502510
def yield_until(self, ready_func, cancellable) -> Cancelled:
503511
assert(self.running())
504-
if self.task.may_block():
512+
if self.task.inst.may_block:
505513
return self.wait_until(ready_func, cancellable)
506514
else:
507515
assert(ready_func())
@@ -656,12 +664,11 @@ def thread_stop(self, thread):
656664
def needs_exclusive(self):
657665
return not self.opts.async_ or self.opts.callback
658666

659-
def may_block(self):
660-
return self.ft.async_ or self.state == Task.State.RESOLVED
661-
662667
def enter(self, thread):
663668
assert(thread in self.threads and thread.task is self)
664669
if not self.ft.async_:
670+
assert(self.inst.may_block) # TODO: why
671+
self.inst.may_block = False
665672
return True
666673
def has_backpressure():
667674
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -712,13 +719,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
712719
def return_(self, result):
713720
trap_if(self.state == Task.State.RESOLVED)
714721
trap_if(self.num_borrows > 0)
722+
if not self.ft.async_:
723+
assert(not self.inst.may_block)
724+
self.inst.may_block = True
715725
assert(result is not None)
716726
self.on_resolve(result)
717727
self.state = Task.State.RESOLVED
718728

719729
def cancel(self):
720730
trap_if(self.state != Task.State.CANCEL_DELIVERED)
721731
trap_if(self.num_borrows > 0)
732+
assert(self.ft.async_)
722733
self.on_resolve(None)
723734
self.state = Task.State.RESOLVED
724735

@@ -2000,7 +2011,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20002011
### `canon lift`
20012012

20022013
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
2014+
trap_if(not inst.may_enter)
20032015
trap_if(call_might_be_recursive(caller, inst))
2016+
20042017
task = Task(opts, inst, ft, caller, on_resolve)
20052018
def thread_func(thread):
20062019
if not task.enter(thread):
@@ -2045,7 +2058,7 @@ def thread_func(thread):
20452058
else:
20462059
event = (EventCode.NONE, 0, 0)
20472060
case CallbackCode.WAIT:
2048-
trap_if(not task.may_block())
2061+
trap_if(not inst.may_block)
20492062
wset = inst.handles.get(si)
20502063
trap_if(not isinstance(wset, WaitableSet))
20512064
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2061,6 +2074,7 @@ def thread_func(thread):
20612074

20622075
thread = Thread(task, thread_func)
20632076
thread.resume(Cancelled.FALSE)
2077+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20642078
return task
20652079

20662080
class CallbackCode(IntEnum):
@@ -2087,7 +2101,7 @@ def call_and_trap_on_throw(callee, thread, args):
20872101

20882102
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
20892103
trap_if(not thread.task.inst.may_leave)
2090-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2104+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
20912105

20922106
subtask = Subtask()
20932107
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2262,7 +2276,7 @@ def canon_waitable_set_new(thread):
22622276

22632277
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
22642278
trap_if(not thread.task.inst.may_leave)
2265-
trap_if(not thread.task.may_block())
2279+
trap_if(not thread.task.inst.may_block)
22662280
wset = thread.task.inst.handles.get(si)
22672281
trap_if(not isinstance(wset, WaitableSet))
22682282
event = wset.wait(thread, cancellable)
@@ -2313,7 +2327,7 @@ def canon_waitable_join(thread, wi, si):
23132327

23142328
def canon_subtask_cancel(async_, thread, i):
23152329
trap_if(not thread.task.inst.may_leave)
2316-
trap_if(not thread.task.may_block() and not async_)
2330+
trap_if(not thread.task.inst.may_block and not async_)
23172331
subtask = thread.task.inst.handles.get(i)
23182332
trap_if(not isinstance(subtask, Subtask))
23192333
trap_if(subtask.resolve_delivered())
@@ -2370,7 +2384,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23702384

23712385
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
23722386
trap_if(not thread.task.inst.may_leave)
2373-
trap_if(not thread.task.may_block() and not opts.async_)
2387+
trap_if(not thread.task.inst.may_block and not opts.async_)
23742388

23752389
e = thread.task.inst.handles.get(i)
23762390
trap_if(not isinstance(e, EndT))
@@ -2424,7 +2438,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24242438

24252439
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24262440
trap_if(not thread.task.inst.may_leave)
2427-
trap_if(not thread.task.may_block() and not opts.async_)
2441+
trap_if(not thread.task.inst.may_block and not opts.async_)
24282442

24292443
e = thread.task.inst.handles.get(i)
24302444
trap_if(not isinstance(e, EndT))
@@ -2476,7 +2490,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24762490

24772491
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
24782492
trap_if(not thread.task.inst.may_leave)
2479-
trap_if(not thread.task.may_block() and not async_)
2493+
trap_if(not thread.task.inst.may_block and not async_)
24802494
e = thread.task.inst.handles.get(i)
24812495
trap_if(not isinstance(e, EndT))
24822496
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2553,7 +2567,7 @@ def canon_thread_resume_later(thread, i):
25532567

25542568
def canon_thread_suspend(cancellable, thread):
25552569
trap_if(not thread.task.inst.may_leave)
2556-
trap_if(not thread.task.may_block())
2570+
trap_if(not thread.task.inst.may_block)
25572571
cancelled = thread.suspend(cancellable)
25582572
return [cancelled]
25592573

0 commit comments

Comments
 (0)