Skip to content

Commit 38bd747

Browse files
committed
CABI: fix may_block to not use the current task
1 parent 41fd17b commit 38bd747

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ class ComponentInstance:
259259
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
260260
threads: Table[Thread]
261261
may_leave: bool
262+
may_block: bool
262263
backpressure: int
263264
exclusive: bool
264265
num_waiting_to_enter: int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271
self.handles = Table()
271272
self.threads = Table()
272273
self.may_leave = True
274+
self.may_block = True
273275
self.backpressure = 0
274276
self.exclusive = False
275277
self.num_waiting_to_enter = 0
@@ -485,7 +487,7 @@ def resume(self, cancelled):
485487
cancelled = Cancelled.FALSE
486488

487489
def suspend(self, cancellable) -> Cancelled:
488-
assert(self.running() and self.task.may_block())
490+
assert(self.running() and self.task.inst.may_block)
489491
if self.task.deliver_pending_cancel(cancellable):
490492
return Cancelled.TRUE
491493
self.cancellable = cancellable
@@ -495,7 +497,7 @@ def suspend(self, cancellable) -> Cancelled:
495497
return cancelled
496498

497499
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
498-
assert(self.running() and self.task.may_block())
500+
assert(self.running() and self.task.inst.may_block)
499501
if self.task.deliver_pending_cancel(cancellable):
500502
return Cancelled.TRUE
501503
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -506,7 +508,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
506508

507509
def yield_until(self, ready_func, cancellable) -> Cancelled:
508510
assert(self.running())
509-
if self.task.may_block():
511+
if self.task.inst.may_block:
510512
return self.wait_until(ready_func, cancellable)
511513
else:
512514
assert(ready_func())
@@ -661,12 +663,14 @@ def thread_stop(self, thread):
661663
def needs_exclusive(self):
662664
return not self.opts.async_ or self.opts.callback
663665

664-
def may_block(self):
665-
return self.ft.async_ or self.state == Task.State.RESOLVED
666-
667666
def enter(self, thread):
668667
assert(thread in self.threads and thread.task is self)
669668
if not self.ft.async_:
669+
# TODO: what makes this true? where is the assert or trap_if? specifically
670+
# for sibling reentrance. maybe need to add back may_enter and only clear
671+
# at cooperative yield points (based on type)
672+
assert(self.inst.may_block)
673+
self.inst.may_block = False
670674
return True
671675
def has_backpressure():
672676
return self.inst.backpressure > 0 or (self.needs_exclusive() and self.inst.exclusive)
@@ -685,6 +689,7 @@ def has_backpressure():
685689
def exit(self):
686690
assert(len(self.threads) > 0)
687691
if not self.ft.async_:
692+
assert(self.inst.may_block)
688693
return
689694
if self.needs_exclusive():
690695
assert(self.inst.exclusive)
@@ -709,13 +714,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
709714
def return_(self, result):
710715
trap_if(self.state == Task.State.RESOLVED)
711716
trap_if(self.num_borrows > 0)
717+
if not self.ft.async_:
718+
assert(not self.inst.may_block)
719+
self.inst.may_block = True
712720
assert(result is not None)
713721
self.on_resolve(result)
714722
self.state = Task.State.RESOLVED
715723

716724
def cancel(self):
717725
trap_if(self.state != Task.State.CANCEL_DELIVERED)
718726
trap_if(self.num_borrows > 0)
727+
assert(self.ft.async_)
719728
self.on_resolve(None)
720729
self.state = Task.State.RESOLVED
721730

@@ -2042,7 +2051,7 @@ def thread_func(thread):
20422051
else:
20432052
event = (EventCode.NONE, 0, 0)
20442053
case CallbackCode.WAIT:
2045-
trap_if(not task.may_block())
2054+
trap_if(not inst.may_block)
20462055
wset = inst.handles.get(si)
20472056
trap_if(not isinstance(wset, WaitableSet))
20482057
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2058,6 +2067,7 @@ def thread_func(thread):
20582067

20592068
thread = Thread(task, thread_func)
20602069
thread.resume(Cancelled.FALSE)
2070+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20612071
return task
20622072

20632073
class CallbackCode(IntEnum):
@@ -2084,7 +2094,7 @@ def call_and_trap_on_throw(callee, thread, args):
20842094

20852095
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
20862096
trap_if(not thread.task.inst.may_leave)
2087-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2097+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
20882098

20892099
subtask = Subtask()
20902100
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2261,7 +2271,7 @@ def canon_waitable_set_new(thread):
22612271

22622272
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
22632273
trap_if(not thread.task.inst.may_leave)
2264-
trap_if(not thread.task.may_block())
2274+
trap_if(not thread.task.inst.may_block)
22652275
wset = thread.task.inst.handles.get(si)
22662276
trap_if(not isinstance(wset, WaitableSet))
22672277
event = wset.wait(thread, cancellable)
@@ -2312,7 +2322,7 @@ def canon_waitable_join(thread, wi, si):
23122322

23132323
def canon_subtask_cancel(async_, thread, i):
23142324
trap_if(not thread.task.inst.may_leave)
2315-
trap_if(not thread.task.may_block() and not async_)
2325+
trap_if(not thread.task.inst.may_block and not async_)
23162326
subtask = thread.task.inst.handles.get(i)
23172327
trap_if(not isinstance(subtask, Subtask))
23182328
trap_if(subtask.resolve_delivered())
@@ -2369,7 +2379,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23692379

23702380
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
23712381
trap_if(not thread.task.inst.may_leave)
2372-
trap_if(not thread.task.may_block() and not opts.async_)
2382+
trap_if(not thread.task.inst.may_block and not opts.async_)
23732383

23742384
e = thread.task.inst.handles.get(i)
23752385
trap_if(not isinstance(e, EndT))
@@ -2423,7 +2433,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24232433

24242434
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24252435
trap_if(not thread.task.inst.may_leave)
2426-
trap_if(not thread.task.may_block() and not opts.async_)
2436+
trap_if(not thread.task.inst.may_block and not opts.async_)
24272437

24282438
e = thread.task.inst.handles.get(i)
24292439
trap_if(not isinstance(e, EndT))
@@ -2475,7 +2485,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24752485

24762486
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
24772487
trap_if(not thread.task.inst.may_leave)
2478-
trap_if(not thread.task.may_block() and not async_)
2488+
trap_if(not thread.task.inst.may_block and not async_)
24792489
e = thread.task.inst.handles.get(i)
24802490
trap_if(not isinstance(e, EndT))
24812491
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2552,7 +2562,7 @@ def canon_thread_resume_later(thread, i):
25522562

25532563
def canon_thread_suspend(cancellable, thread):
25542564
trap_if(not thread.task.inst.may_leave)
2555-
trap_if(not thread.task.may_block())
2565+
trap_if(not thread.task.inst.may_block)
25562566
cancelled = thread.suspend(cancellable)
25572567
return [cancelled]
25582568

0 commit comments

Comments
 (0)