Skip to content

Commit 3bdbcfc

Browse files
committed
CABI: fix may_block to not use the current task
1 parent e1d1138 commit 3bdbcfc

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ class ComponentInstance:
259259
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
260260
threads: Table[Thread]
261261
may_leave: bool
262+
may_block: bool
262263
backpressure: int
263264
exclusive: bool
264265
num_waiting_to_enter: int
@@ -270,6 +271,7 @@ def __init__(self, store, parent = None):
270271
self.handles = Table()
271272
self.threads = Table()
272273
self.may_leave = True
274+
self.may_block = True
273275
self.backpressure = 0
274276
self.exclusive = False
275277
self.num_waiting_to_enter = 0
@@ -490,7 +492,7 @@ def resume(self, cancelled):
490492
cancelled = Cancelled.FALSE
491493

492494
def suspend(self, cancellable) -> Cancelled:
493-
assert(self.running() and self.task.may_block())
495+
assert(self.running() and self.task.inst.may_block)
494496
if self.task.deliver_pending_cancel(cancellable):
495497
return Cancelled.TRUE
496498
self.cancellable = cancellable
@@ -503,7 +505,7 @@ def suspend(self, cancellable) -> Cancelled:
503505
return cancelled
504506

505507
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
506-
assert(self.running() and self.task.may_block())
508+
assert(self.running() and self.task.inst.may_block)
507509
if self.task.deliver_pending_cancel(cancellable):
508510
return Cancelled.TRUE
509511
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -514,7 +516,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
514516

515517
def yield_until(self, ready_func, cancellable) -> Cancelled:
516518
assert(self.running())
517-
if self.task.may_block():
519+
if self.task.inst.may_block:
518520
return self.wait_until(ready_func, cancellable)
519521
else:
520522
assert(ready_func())
@@ -672,12 +674,14 @@ def thread_stop(self, thread):
672674
def needs_exclusive(self):
673675
return not self.opts.async_ or self.opts.callback
674676

675-
def may_block(self):
676-
return self.ft.async_ or self.state == Task.State.RESOLVED
677-
678677
def enter(self, thread):
679678
assert(thread in self.threads and thread.task is self)
680679
if not self.ft.async_:
680+
# TODO: what makes this true? where is the assert or trap_if? specifically
681+
# for sibling reentrance. maybe need to add back may_enter and only clear
682+
# at cooperative yield points (based on type)
683+
assert(self.inst.may_block)
684+
self.inst.may_block = False
681685
return True
682686
def has_backpressure():
683687
return self.inst.backpressure > 0 or (self.needs_exclusive() and self.inst.exclusive)
@@ -696,6 +700,7 @@ def has_backpressure():
696700
def exit(self):
697701
assert(len(self.threads) > 0)
698702
if not self.ft.async_:
703+
assert(self.inst.may_block)
699704
return
700705
if self.needs_exclusive():
701706
assert(self.inst.exclusive)
@@ -720,13 +725,17 @@ def deliver_pending_cancel(self, cancellable) -> bool:
720725
def return_(self, result):
721726
trap_if(self.state == Task.State.RESOLVED)
722727
trap_if(self.num_borrows > 0)
728+
if not self.ft.async_:
729+
assert(not self.inst.may_block)
730+
self.inst.may_block = True
723731
assert(result is not None)
724732
self.on_resolve(result)
725733
self.state = Task.State.RESOLVED
726734

727735
def cancel(self):
728736
trap_if(self.state != Task.State.CANCEL_DELIVERED)
729737
trap_if(self.num_borrows > 0)
738+
assert(self.ft.async_)
730739
self.on_resolve(None)
731740
self.state = Task.State.RESOLVED
732741

@@ -2053,7 +2062,7 @@ def thread_func(thread):
20532062
else:
20542063
event = (EventCode.NONE, 0, 0)
20552064
case CallbackCode.WAIT:
2056-
trap_if(not task.may_block())
2065+
trap_if(not inst.may_block)
20572066
wset = inst.handles.get(si)
20582067
trap_if(not isinstance(wset, WaitableSet))
20592068
event = wset.wait_until(lambda: not inst.exclusive, thread, cancellable = True)
@@ -2069,6 +2078,7 @@ def thread_func(thread):
20692078

20702079
thread = Thread(task, thread_func)
20712080
thread.resume(Cancelled.FALSE)
2081+
assert(ft.async_ or task.state == Task.State.RESOLVED)
20722082
return task
20732083

20742084
class CallbackCode(IntEnum):
@@ -2095,7 +2105,7 @@ def call_and_trap_on_throw(callee, thread, args):
20952105

20962106
def canon_lower(opts, ft, callee: FuncInst, thread, flat_args):
20972107
trap_if(not thread.task.inst.may_leave)
2098-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
2108+
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
20992109

21002110
subtask = Subtask()
21012111
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2272,7 +2282,7 @@ def canon_waitable_set_new(thread):
22722282

22732283
def canon_waitable_set_wait(cancellable, mem, thread, si, ptr):
22742284
trap_if(not thread.task.inst.may_leave)
2275-
trap_if(not thread.task.may_block())
2285+
trap_if(not thread.task.inst.may_block)
22762286
wset = thread.task.inst.handles.get(si)
22772287
trap_if(not isinstance(wset, WaitableSet))
22782288
event = wset.wait(thread, cancellable)
@@ -2323,7 +2333,7 @@ def canon_waitable_join(thread, wi, si):
23232333

23242334
def canon_subtask_cancel(async_, thread, i):
23252335
trap_if(not thread.task.inst.may_leave)
2326-
trap_if(not thread.task.may_block() and not async_)
2336+
trap_if(not thread.task.inst.may_block and not async_)
23272337
subtask = thread.task.inst.handles.get(i)
23282338
trap_if(not isinstance(subtask, Subtask))
23292339
trap_if(subtask.resolve_delivered())
@@ -2380,7 +2390,7 @@ def canon_stream_write(stream_t, opts, thread, i, ptr, n):
23802390

23812391
def stream_copy(EndT, BufferT, event_code, stream_t, opts, thread, i, ptr, n):
23822392
trap_if(not thread.task.inst.may_leave)
2383-
trap_if(not thread.task.may_block() and not opts.async_)
2393+
trap_if(not thread.task.inst.may_block and not opts.async_)
23842394

23852395
e = thread.task.inst.handles.get(i)
23862396
trap_if(not isinstance(e, EndT))
@@ -2434,7 +2444,7 @@ def canon_future_write(future_t, opts, thread, i, ptr):
24342444

24352445
def future_copy(EndT, BufferT, event_code, future_t, opts, thread, i, ptr):
24362446
trap_if(not thread.task.inst.may_leave)
2437-
trap_if(not thread.task.may_block() and not opts.async_)
2447+
trap_if(not thread.task.inst.may_block and not opts.async_)
24382448

24392449
e = thread.task.inst.handles.get(i)
24402450
trap_if(not isinstance(e, EndT))
@@ -2486,7 +2496,7 @@ def canon_future_cancel_write(future_t, async_, thread, i):
24862496

24872497
def cancel_copy(EndT, event_code, stream_or_future_t, async_, thread, i):
24882498
trap_if(not thread.task.inst.may_leave)
2489-
trap_if(not thread.task.may_block() and not async_)
2499+
trap_if(not thread.task.inst.may_block and not async_)
24902500
e = thread.task.inst.handles.get(i)
24912501
trap_if(not isinstance(e, EndT))
24922502
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2563,7 +2573,7 @@ def canon_thread_resume_later(thread, i):
25632573

25642574
def canon_thread_suspend(cancellable, thread):
25652575
trap_if(not thread.task.inst.may_leave)
2566-
trap_if(not thread.task.may_block())
2576+
trap_if(not thread.task.inst.may_block)
25672577
cancelled = thread.suspend(cancellable)
25682578
return [cancelled]
25692579

0 commit comments

Comments
 (0)