Skip to content

Commit 4cdba5a

Browse files
committed
CABI: fix may_block to not use the current task
1 parent aec7316 commit 4cdba5a

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve)
198198
def tick(self):
199199
random.shuffle(self.waiting)
200200
for thread in self.waiting:
201-
if thread.ready():
201+
if thread.ready() and thread.task.inst.may_enter:
202202
thread.resume(Cancelled.FALSE)
203203
return
204204

@@ -285,7 +285,9 @@ class ComponentInstance:
285285
parent: Optional[ComponentInstance]
286286
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
287287
threads: Table[Thread]
288+
may_enter: bool
288289
may_leave: bool
290+
may_block: bool
289291
backpressure: int
290292
exclusive: Optional[Task]
291293
num_waiting_to_enter: int
@@ -296,7 +298,9 @@ def __init__(self, store, parent = None):
296298
self.parent = parent
297299
self.handles = Table()
298300
self.threads = Table()
301+
self.may_enter = True
299302
self.may_leave = True
303+
self.may_block = True
300304
self.backpressure = 0
301305
self.exclusive = None
302306
self.num_waiting_to_enter = 0
@@ -489,7 +493,8 @@ def resume_later(self):
489493

490494
def resume(self, cancelled):
491495
assert(self.cancellable or not cancelled)
492-
assert(not self.running())
496+
assert(not self.running() and self.task.inst.may_enter)
497+
self.task.inst.may_enter = False
493498
if self.waiting():
494499
assert(cancelled or self.ready())
495500
self.ready_func = None
@@ -506,9 +511,11 @@ def resume(self, cancelled):
506511
break
507512
thread = switch_to_thread
508513
cancelled = Cancelled.FALSE
514+
assert(not self.task.inst.may_enter)
515+
self.task.inst.may_enter = True
509516

510517
def suspend(self, cancellable) -> Cancelled:
511-
assert(self.running() and self.task.may_block())
518+
assert(self.running() and self.task.inst.may_block)
512519
if self.task.deliver_pending_cancel(cancellable):
513520
return Cancelled.TRUE
514521
self.cancellable = cancellable
@@ -517,7 +524,7 @@ def suspend(self, cancellable) -> Cancelled:
517524
return cancelled
518525

519526
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
520-
assert(self.running() and self.task.may_block())
527+
assert(self.running() and self.task.inst.may_block)
521528
if self.task.deliver_pending_cancel(cancellable):
522529
return Cancelled.TRUE
523530
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -528,7 +535,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
528535

529536
def yield_until(self, ready_func, cancellable) -> Cancelled:
530537
assert(self.running())
531-
if self.task.may_block():
538+
if self.task.inst.may_block:
532539
return self.wait_until(ready_func, cancellable)
533540
else:
534541
assert(ready_func())
@@ -683,12 +690,11 @@ def thread_stop(self, thread):
683690
def needs_exclusive(self):
684691
return not self.opts.async_ or self.opts.callback
685692

686-
def may_block(self):
687-
return self.ft.async_ or self.state == Task.State.RESOLVED
688-
689693
def enter(self, thread):
690694
assert(thread in self.threads and thread.task is self)
691695
if not self.ft.async_:
696+
assert(self.inst.may_block)
697+
self.inst.may_block = False
692698
return True
693699
def has_backpressure():
694700
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
@@ -739,13 +745,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
739745
def return_(self, result):
740746
trap_if(self.state == Task.State.RESOLVED)
741747
trap_if(self.num_borrows > 0)
748+
if not self.ft.async_:
749+
assert(not self.inst.may_block)
750+
self.inst.may_block = True
742751
assert(result is not None)
743752
self.on_resolve(result)
744753
self.state = Task.State.RESOLVED
745754

746755
def cancel(self):
747756
trap_if(self.state != Task.State.CANCEL_DELIVERED)
748757
trap_if(self.num_borrows > 0)
758+
assert(self.ft.async_)
749759
self.on_resolve(None)
750760
self.state = Task.State.RESOLVED
751761

@@ -2038,6 +2048,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20382048

20392049
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
20402050
trap_if(call_might_be_recursive(caller, inst))
2051+
assert(inst.may_enter)
2052+
20412053
task = Task(opts, inst, ft, caller, on_resolve)
20422054
def thread_func(thread):
20432055
if not task.enter(thread):
@@ -2082,7 +2094,7 @@ def thread_func(thread):
20822094
else:
20832095
event = (EventCode.NONE, 0, 0)
20842096
case CallbackCode.WAIT:
2085-
trap_if(not task.may_block())
2097+
trap_if(not inst.may_block)
20862098
wset = inst.handles.get(si)
20872099
trap_if(not isinstance(wset, WaitableSet))
20882100
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2098,6 +2110,7 @@ def thread_func(thread):
20982110

20992111
thread = Thread(task, thread_func)
21002112
thread.resume(Cancelled.FALSE)
2113+
assert(ft.async_ or task.state == Task.State.RESOLVED)
21012114
return task
21022115

21032116
class CallbackCode(IntEnum):
@@ -2124,7 +2137,7 @@ def call_and_trap_on_throw(callee, thread, args):
21242137

21252138
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
21262139
trap_if(not thread.task.inst.may_leave)
2127-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2140+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
21282141

21292142
subtask = Subtask()
21302143
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2304,7 +2317,7 @@ def canon_waitable_set_new(thread):
23042317

23052318
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
23062319
trap_if(not thread.task.inst.may_leave)
2307-
trap_if(not thread.task.may_block())
2320+
trap_if(not thread.task.inst.may_block)
23082321
wset = thread.task.inst.handles.get(si)
23092322
trap_if(not isinstance(wset, WaitableSet))
23102323
event = wset.wait(thread, cancellable)
@@ -2355,7 +2368,7 @@ def canon_waitable_join(thread, wi, si):
23552368

23562369
def canon_subtask_cancel(async_, thread, i):
23572370
trap_if(not thread.task.inst.may_leave)
2358-
trap_if(not thread.task.may_block() and not async_)
2371+
trap_if(not thread.task.inst.may_block and not async_)
23592372
subtask = thread.task.inst.handles.get(i)
23602373
trap_if(not isinstance(subtask, Subtask))
23612374
trap_if(subtask.resolve_delivered())
@@ -2412,7 +2425,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
24122425

24132426
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
24142427
trap_if(not thread.task.inst.may_leave)
2415-
trap_if(not thread.task.may_block() and not opts.async_)
2428+
trap_if(not thread.task.inst.may_block and not opts.async_)
24162429

24172430
e = thread.task.inst.handles.get(i)
24182431
trap_if(not isinstance(e, EndT))
@@ -2466,7 +2479,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24662479

24672480
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24682481
trap_if(not thread.task.inst.may_leave)
2469-
trap_if(not thread.task.may_block() and not opts.async_)
2482+
trap_if(not thread.task.inst.may_block and not opts.async_)
24702483

24712484
e = thread.task.inst.handles.get(i)
24722485
trap_if(not isinstance(e, EndT))
@@ -2518,7 +2531,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
25182531

25192532
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
25202533
trap_if(not thread.task.inst.may_leave)
2521-
trap_if(not thread.task.may_block() and not async_)
2534+
trap_if(not thread.task.inst.may_block and not async_)
25222535
e = thread.task.inst.handles.get(i)
25232536
trap_if(not isinstance(e, EndT))
25242537
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2595,7 +2608,7 @@ def canon_thread_switch_to(cancellable, thread, i):
25952608

25962609
def canon_thread_suspend(cancellable, thread):
25972610
trap_if(not thread.task.inst.may_leave)
2598-
trap_if(not thread.task.may_block())
2611+
trap_if(not thread.task.inst.may_block)
25992612
cancelled = thread.suspend(cancellable)
26002613
return [cancelled]
26012614

0 commit comments

Comments
 (0)