Skip to content

Commit 47c997e

Browse files
committed
CABI: fix may_block to not use the current task
1 parent f35b72d commit 47c997e

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ class ComponentInstance:
259259
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
260260
threads: Table[Thread]
261261
may_leave: bool
262+
may_block: bool
262263
backpressure: int
263264
exclusive: bool
264265
num_waiting_to_enter: int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271
self.handles = Table()
271272
self.threads = Table()
272273
self.may_leave = True
274+
self.may_block = True
273275
self.backpressure = 0
274276
self.exclusive = False
275277
self.num_waiting_to_enter = 0
@@ -489,7 +491,7 @@ def resume(self, cancelled):
489491
cancelled = Cancelled.FALSE
490492

491493
def suspend(self, cancellable) -> Cancelled:
492-
assert(self.running() and self.task.may_block())
494+
assert(self.running() and self.task.inst.may_block)
493495
if self.task.deliver_pending_cancel(cancellable):
494496
return Cancelled.TRUE
495497
self.cancellable = cancellable
@@ -502,7 +504,7 @@ def suspend(self, cancellable) -> Cancelled:
502504
return cancelled
503505

504506
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
505-
assert(self.running() and self.task.may_block())
507+
assert(self.running() and self.task.inst.may_block)
506508
if self.task.deliver_pending_cancel(cancellable):
507509
return Cancelled.TRUE
508510
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -513,7 +515,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
513515

514516
def yield_until(self, ready_func, cancellable) -> Cancelled:
515517
assert(self.running())
516-
if self.task.may_block():
518+
if self.task.inst.may_block:
517519
return self.wait_until(ready_func, cancellable)
518520
else:
519521
assert(ready_func())
@@ -671,12 +673,14 @@ def thread_stop(self, thread):
671673
def needs_exclusive(self):
672674
return not self.opts.async_ or self.opts.callback
673675

674-
def may_block(self):
675-
return self.ft.async_ or self.state == Task.State.RESOLVED
676-
677676
def enter(self, thread):
678677
assert(thread in self.threads and thread.task is self)
679678
if not self.ft.async_:
679+
# TODO: what makes this true? where is the assert or trap_if? specifically
680+
# for sibling reentrance. maybe need to add back may_enter and only clear
681+
# at cooperative yield points (based on type)
682+
assert(self.inst.may_block)
683+
self.inst.may_block = False
680684
return True
681685
def has_backpressure():
682686
return self.inst.backpressure > 0 or (self.needs_exclusive() and self.inst.exclusive)
@@ -695,6 +699,7 @@ def has_backpressure():
695699
def exit(self):
696700
assert(len(self.threads) > 0)
697701
if not self.ft.async_:
702+
assert(self.inst.may_block)
698703
return
699704
if self.needs_exclusive():
700705
assert(self.inst.exclusive)
@@ -719,13 +724,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
719724
def return_(self, result):
720725
trap_if(self.state == Task.State.RESOLVED)
721726
trap_if(self.num_borrows > 0)
727+
if not self.ft.async_:
728+
assert(not self.inst.may_block)
729+
self.inst.may_block = True
722730
assert(result is not None)
723731
self.on_resolve(result)
724732
self.state = Task.State.RESOLVED
725733

726734
def cancel(self):
727735
trap_if(self.state != Task.State.CANCEL_DELIVERED)
728736
trap_if(self.num_borrows > 0)
737+
assert(self.ft.async_)
729738
self.on_resolve(None)
730739
self.state = Task.State.RESOLVED
731740

@@ -2052,7 +2061,7 @@ def thread_func(thread):
20522061
else:
20532062
event = (EventCode.NONE, 0, 0)
20542063
case CallbackCode.WAIT:
2055-
trap_if(not task.may_block())
2064+
trap_if(not inst.may_block)
20562065
wset = inst.handles.get(si)
20572066
trap_if(not isinstance(wset, WaitableSet))
20582067
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2068,6 +2077,7 @@ def thread_func(thread):
20682077

20692078
thread = Thread(task, thread_func)
20702079
thread.resume(Cancelled.FALSE)
2080+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20712081
return task
20722082

20732083
class CallbackCode(IntEnum):
@@ -2094,7 +2104,7 @@ def call_and_trap_on_throw(callee, thread, args):
20942104

20952105
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
20962106
trap_if(not thread.task.inst.may_leave)
2097-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2107+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
20982108

20992109
subtask = Subtask()
21002110
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2271,7 +2281,7 @@ def canon_waitable_set_new(thread):
22712281

22722282
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
22732283
trap_if(not thread.task.inst.may_leave)
2274-
trap_if(not thread.task.may_block())
2284+
trap_if(not thread.task.inst.may_block)
22752285
wset = thread.task.inst.handles.get(si)
22762286
trap_if(not isinstance(wset, WaitableSet))
22772287
event = wset.wait(thread, cancellable)
@@ -2322,7 +2332,7 @@ def canon_waitable_join(thread, wi, si):
23222332

23232333
def canon_subtask_cancel(async_, thread, i):
23242334
trap_if(not thread.task.inst.may_leave)
2325-
trap_if(not thread.task.may_block() and not async_)
2335+
trap_if(not thread.task.inst.may_block and not async_)
23262336
subtask = thread.task.inst.handles.get(i)
23272337
trap_if(not isinstance(subtask, Subtask))
23282338
trap_if(subtask.resolve_delivered())
@@ -2379,7 +2389,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23792389

23802390
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
23812391
trap_if(not thread.task.inst.may_leave)
2382-
trap_if(not thread.task.may_block() and not opts.async_)
2392+
trap_if(not thread.task.inst.may_block and not opts.async_)
23832393

23842394
e = thread.task.inst.handles.get(i)
23852395
trap_if(not isinstance(e, EndT))
@@ -2433,7 +2443,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24332443

24342444
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24352445
trap_if(not thread.task.inst.may_leave)
2436-
trap_if(not thread.task.may_block() and not opts.async_)
2446+
trap_if(not thread.task.inst.may_block and not opts.async_)
24372447

24382448
e = thread.task.inst.handles.get(i)
24392449
trap_if(not isinstance(e, EndT))
@@ -2485,7 +2495,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24852495

24862496
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
24872497
trap_if(not thread.task.inst.may_leave)
2488-
trap_if(not thread.task.may_block() and not async_)
2498+
trap_if(not thread.task.inst.may_block and not async_)
24892499
e = thread.task.inst.handles.get(i)
24902500
trap_if(not isinstance(e, EndT))
24912501
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2562,7 +2572,7 @@ def canon_thread_resume_later(thread, i):
25622572

25632573
def canon_thread_suspend(cancellable, thread):
25642574
trap_if(not thread.task.inst.may_leave)
2565-
trap_if(not thread.task.may_block())
2575+
trap_if(not thread.task.inst.may_block)
25662576
cancelled = thread.suspend(cancellable)
25672577
return [cancelled]
25682578

0 commit comments

Comments
 (0)